package matrix;
/**
 * A square matrix of {@code double} values with element access, addition,
 * subtraction, the standard cubic-time product and Strassen's product.
 *
 * <p>Elements are addressed as {@code (x, y)}: the first index selects the
 * column and the second the row. {@link #toString()} prints element
 * {@code (x, y)} in row {@code y}, column {@code x}, and
 * {@link #getSubMatrix(int, int, int)} applies its {@code left} offset to the
 * first index and its {@code top} offset to the second. Both products are
 * computed consistently with this convention, so {@code A.stdMult(B)} and
 * {@code A.ssMult(B, n0)} yield the usual product {@code A * B} when read back
 * through the same convention.
 */
public final class Matrix {

    /** Element storage, indexed as {@code matrix[column][row]}; always square. */
    private final double[][] matrix;

    /** Checked exception signalling invalid arguments to a matrix operation. */
    public static class MMException extends Exception {
        private static final long serialVersionUID = 1L;

        /** Creates an exception without a detail message. */
        public MMException() {
        }

        /**
         * Creates an exception with the given detail message.
         *
         * @param message the detail message
         */
        public MMException(String message) {
            super(message);
        }
    }

    /**
     * Creates an {@code n x n} matrix with every element set to zero.
     *
     * @param n the dimension of the matrix
     */
    public Matrix(int n) {
        matrix = new double[n][n];
    }

    /**
     * Renders the matrix one row per line. Each element is formatted with
     * {@code "%1.2f "} using the default locale.
     *
     * @return the textual representation of this matrix
     */
    @Override
    public String toString() {
        int n = matrix.length;
        StringBuilder out = new StringBuilder();
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                out.append(String.format("%1.2f ", matrix[col][row]));
            }
            out.append('\n');
        }
        return out.toString();
    }

    /**
     * Returns the dimension of this matrix.
     *
     * @return {@code n} for an {@code n x n} matrix
     */
    public int getSize() {
        return matrix.length;
    }

    /**
     * Returns the element in column {@code i}, row {@code j}.
     *
     * @param i column index
     * @param j row index
     * @return the element value
     * @throws MMException if either index is outside {@code [0, getSize())}
     */
    public double getElem(int i, int j) throws MMException {
        checkIndex(i, j);
        return matrix[i][j];
    }

    /**
     * Sets the element in column {@code i}, row {@code j}.
     *
     * @param i column index
     * @param j row index
     * @param x the new value
     * @throws MMException if either index is outside {@code [0, getSize())}
     */
    public void setElem(int i, int j, double x) throws MMException {
        checkIndex(i, j);
        matrix[i][j] = x;
    }

    /** Throws if {@code (i, j)} is not a valid element position. */
    private void checkIndex(int i, int j) throws MMException {
        int n = matrix.length;
        if (i < 0 || i >= n || j < 0 || j >= n) {
            throw new MMException("Index (" + i + ", " + j
                    + ") is out of bounds for size " + n + ".");
        }
    }

    /**
     * Returns a copy of the {@code n x n} block whose top-left corner is at
     * column {@code left}, row {@code top}.
     *
     * @param left column offset of the block
     * @param top row offset of the block
     * @param n dimension of the block; values below 1 yield an empty matrix
     * @return the copied block
     * @throws MMException if the block does not lie inside this matrix
     */
    public Matrix getSubMatrix(int left, int top, int n) throws MMException {
        if (n < 1) {
            return new Matrix(0);
        }
        int size = matrix.length;
        // Written as "offset > size - n" so the check cannot overflow.
        if (left < 0 || top < 0 || left > size - n || top > size - n) {
            throw new MMException("Submatrix out of bounds.");
        }
        Matrix result = new Matrix(n);
        for (int x = 0; x < n; x++) {
            System.arraycopy(matrix[x + left], top, result.matrix[x], 0, n);
        }
        return result;
    }

    /** Validates that both operands are non-null and of equal size. */
    private static void checkOperands(Matrix m1, Matrix m2) throws MMException {
        if (m1 == null || m2 == null) {
            throw new MMException("Matrices must not be null.");
        }
        if (m1.getSize() != m2.getSize()) {
            throw new MMException("Matrices must be the same size.");
        }
    }

    /** Element-wise sum {@code m1 + m2}. */
    private static Matrix add(Matrix m1, Matrix m2) throws MMException {
        checkOperands(m1, m2);
        int n = m1.getSize();
        Matrix result = new Matrix(n);
        for (int x = 0; x < n; x++) {
            for (int y = 0; y < n; y++) {
                result.matrix[x][y] = m1.matrix[x][y] + m2.matrix[x][y];
            }
        }
        return result;
    }

    /**
     * Returns {@code this + m2}.
     *
     * @param m2 the other operand; must have the same size
     * @return the element-wise sum
     * @throws MMException if {@code m2} is null or of a different size
     */
    public Matrix add(Matrix m2) throws MMException {
        return add(this, m2);
    }

    /** Element-wise difference {@code m1 - m2}. */
    private static Matrix subtract(Matrix m1, Matrix m2) throws MMException {
        checkOperands(m1, m2);
        int n = m1.getSize();
        Matrix result = new Matrix(n);
        for (int x = 0; x < n; x++) {
            for (int y = 0; y < n; y++) {
                result.matrix[x][y] = m1.matrix[x][y] - m2.matrix[x][y];
            }
        }
        return result;
    }

    /**
     * Returns {@code this - m2}.
     *
     * @param m2 the other operand; must have the same size
     * @return the element-wise difference
     * @throws MMException if {@code m2} is null or of a different size
     */
    public Matrix subtract(Matrix m2) throws MMException {
        return subtract(this, m2);
    }

    /** Standard O(n^3) product {@code m1 * m2}. */
    private static Matrix stdMult(Matrix m1, Matrix m2) throws MMException {
        checkOperands(m1, m2);
        int n = m1.getSize();
        Matrix result = new Matrix(n);
        for (int x = 0; x < n; x++) {
            for (int y = 0; y < n; y++) {
                // Accumulate in double: an int accumulator would truncate.
                double value = 0;
                for (int k = 0; k < n; k++) {
                    // Row y of m1 (elements (k, y)) times column x of m2 (elements (x, k)).
                    value += m1.matrix[k][y] * m2.matrix[x][k];
                }
                result.matrix[x][y] = value;
            }
        }
        return result;
    }

    /**
     * Returns {@code this * m2} using the standard cubic-time algorithm.
     *
     * @param m2 the right operand; must have the same size
     * @return the product
     * @throws MMException if {@code m2} is null or of a different size
     */
    public Matrix stdMult(Matrix m2) throws MMException {
        return stdMult(this, m2);
    }

    /** Copies {@code block} into {@code target} with its top-left corner at (left, top). */
    private static void place(Matrix target, Matrix block, int left, int top) {
        int size = block.getSize();
        for (int x = 0; x < size; x++) {
            System.arraycopy(block.matrix[x], 0, target.matrix[x + left], top, size);
        }
    }

    /**
     * Recursive Strassen product of two matrices whose size is a power of two.
     * Falls back to {@link #stdMult(Matrix, Matrix)} below size {@code n0} or
     * below size 3. When {@code verbose}, every (sub)result computed, including
     * this call's own result, is printed to standard output.
     */
    private static Matrix ssMult(Matrix m1, Matrix m2, int n0, boolean verbose)
            throws MMException {
        int n = m1.getSize();
        Matrix result;
        if (n < n0 || n < 3) {
            result = stdMult(m1, m2);
        } else {
            int half = n / 2;
            // Split both operands into quadrants (left = column offset, top = row offset).
            Matrix m1TopLeft = m1.getSubMatrix(0, 0, half);
            Matrix m1TopRight = m1.getSubMatrix(half, 0, half);
            Matrix m1BottomLeft = m1.getSubMatrix(0, half, half);
            Matrix m1BottomRight = m1.getSubMatrix(half, half, half);
            Matrix m2TopLeft = m2.getSubMatrix(0, 0, half);
            Matrix m2TopRight = m2.getSubMatrix(half, 0, half);
            Matrix m2BottomLeft = m2.getSubMatrix(0, half, half);
            Matrix m2BottomRight = m2.getSubMatrix(half, half, half);

            // The seven Strassen products.
            Matrix a1 = ssMult(subtract(m1TopRight, m1BottomRight),
                    add(m2BottomLeft, m2BottomRight), n0, verbose);
            Matrix a2 = ssMult(add(m1TopLeft, m1BottomRight),
                    add(m2TopLeft, m2BottomRight), n0, verbose);
            Matrix a3 = ssMult(subtract(m1TopLeft, m1BottomLeft),
                    add(m2TopLeft, m2TopRight), n0, verbose);
            Matrix a4 = ssMult(add(m1TopLeft, m1TopRight), m2BottomRight, n0, verbose);
            Matrix a5 = ssMult(m1TopLeft, subtract(m2TopRight, m2BottomRight), n0, verbose);
            Matrix a6 = ssMult(m1BottomRight, subtract(m2BottomLeft, m2TopLeft), n0, verbose);
            Matrix a7 = ssMult(add(m1BottomLeft, m1BottomRight), m2TopLeft, n0, verbose);

            // Combine into the four quadrants of the product.
            Matrix cTopLeft = subtract(add(a1, a2), subtract(a4, a6));
            Matrix cTopRight = add(a4, a5);
            Matrix cBottomLeft = add(a6, a7);
            Matrix cBottomRight = add(subtract(a2, a3), subtract(a5, a7));

            result = new Matrix(n);
            place(result, cTopLeft, 0, 0);
            place(result, cTopRight, half, 0);
            place(result, cBottomLeft, 0, half);
            place(result, cBottomRight, half, half);
        }
        if (verbose) {
            System.out.println(result);
        }
        return result;
    }

    /** Returns true iff {@code n} is a positive power of two. */
    private static boolean isPowerOf2(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }

    /** Returns the smallest power of two that is at least {@code n} (minimum 2). */
    private static int getNextPowerOf2(int n) {
        return n <= 2 ? 2 : Integer.highestOneBit(n - 1) << 1;
    }

    /** Returns a copy of {@code m} zero-padded to {@code size x size}. */
    private static Matrix padTo(Matrix m, int size) {
        Matrix result = new Matrix(size);
        place(result, m, 0, 0);
        return result;
    }

    /**
     * Returns {@code this * m2} using Strassen's algorithm.
     *
     * <p>Matrices smaller than {@code n0} (or than 3) are multiplied with the
     * standard algorithm. Other sizes that are not a power of two are
     * zero-padded to the next power of two, multiplied, and trimmed back.
     *
     * @param m2 the right operand; must have the same size
     * @param n0 size threshold below which the standard algorithm is used
     * @param verbose when true, every intermediate result and the final result
     *     are printed to standard output; when false nothing is printed
     * @return the product
     * @throws MMException if {@code m2} is null or of a different size
     */
    public Matrix ssMult(Matrix m2, int n0, boolean verbose)
            throws MMException {
        checkOperands(this, m2);
        int n = matrix.length;
        if (n < n0 || n < 3) {
            Matrix result = stdMult(this, m2);
            if (verbose) {
                System.out.println(result);
            }
            return result;
        }
        if (isPowerOf2(n)) {
            // The recursion already prints its own (final) result when verbose.
            return ssMult(this, m2, n0, verbose);
        }
        int padded = getNextPowerOf2(n);
        Matrix result = ssMult(padTo(this, padded), padTo(m2, padded), n0, verbose)
                .getSubMatrix(0, 0, n);
        if (verbose) {
            System.out.println(result);
        }
        return result;
    }

    /**
     * Returns {@code this * m2} using Strassen's algorithm in verbose mode.
     *
     * @param m2 the right operand; must have the same size
     * @param n0 size threshold below which the standard algorithm is used
     * @return the product
     * @throws MMException if {@code m2} is null or of a different size
     */
    public Matrix ssMult(Matrix m2, int n0) throws MMException {
        return ssMult(m2, n0, true);
    }
}