Matrix inversion: a Java implementation for finding the inverse of a matrix of any order

My recent studies involve matrices, and computing the inverse, adjugate (adjoint), and determinant of an n × n matrix by hand is a headache. So I looked online for code that computes the inverse and related quantities of a matrix. After searching for a long time, all I found were implementations based on float and double… um, that is not what I want. What I want is to work just like a hand calculation: no decimals, with results expressed as exact fractions where needed.

Forget it, I'd better stay up all night and write one myself! Sometimes you just have to build your own wheel. Maybe you will use it one day, and even if you don't, it is still good practice for your math and programming skills, haha.

The algorithm involved is actually very simple. For example, to find the inverse of a matrix A, append the identity matrix to obtain the augmented matrix (A, E), then use elementary row operations to turn A into the identity matrix E. Once A has become E, the right-hand half (which started as E) has become the inverse of A. As another example, the determinant is computed recursively by cofactor expansion along one row. I won't go into the details here; just read the code.

To represent the fractions that arise during the calculations, I defined a dedicated Fraction class that implements fraction arithmetic.

The idea is simple, but implementing it is still a little troublesome. Still, I finished it, hehe.

Note: The code in this article is entirely original; any similarity is purely coincidental. If you reuse the code from this article, please cite the source.

Attached are the test results:

Test 1: Matrix A:

Calculation results:

Test 2: Matrix A:

Calculation results:

Okay, enough talk! Here is the code (Java implementation):

1. Mat class

import java.util.Arrays;

/**
 * matrix
 * @author gt5b
 * 2023-09-07 02:23
 */
public class Mat {
    private int rows;
    private int cols;
    private Fraction[][] elements;

    /**
     * What is passed in is a one-dimensional array
     * First determine whether it is a square matrix. If it is a square matrix, return the square matrix; otherwise, return a 1 × n matrix.
     * @param array one-dimensional array
     */
    public Mat(int[] array){
        int n = (int)(Math.sqrt(array.length));
        if(n * n == array.length){
            this.rows = n;
            this.cols = n;
        }else{
            this.rows = 1;
            this.cols = array.length;
        }
        elements = new Fraction[rows][cols];
        arrayToElements(array);
    }
    public Mat(int[][] array) {
        this.rows = array.length;
        this.cols = array[0].length;
        elements = new Fraction[rows][cols];
        for (int i = 0; i < rows; i + + ) {
            for (int j = 0; j < cols; j + + ) {
                elements[i][j] = new Fraction(array[i][j], 1);
            }
        }
    }

    /**
     * Create a matrix from a one-dimensional array
     *
     * @param array one-dimensional array
     * @param rows The number of rows in the matrix
     * @param cols The number of columns of the matrix
     */
    public Mat(int[] array, int rows, int cols) {
        if (rows == 0 & amp; & amp; cols == 0) {
            throw new IllegalArgumentException("The number of rows and columns cannot be 0 at the same time");
        }
        if (rows == 0) {
            rows = array.length / cols;
        } else if (cols == 0) {
            cols = array.length / rows;
        }
        if (rows * cols != array.length) {
            throw new IllegalArgumentException("The array length does not match the number of rows and columns");
        }
        this.rows = rows;
        this.cols = cols;
        elements = new Fraction[rows][cols];
        arrayToElements(array);
    }
    private void arrayToElements(int[] array){
        int index = 0;
        for (int i = 0; i < rows; i + + ) {
            for (int j = 0; j < cols; j + + ) {
                elements[i][j] = new Fraction(array[index + + ], 1);
            }
        }
    }
    public Mat(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.elements = new Fraction[rows][cols];
        for (int i = 0; i < rows; i + + ) {
            Arrays.fill(elements[i], new Fraction(0, 1));
        }
    }

    public int getRows() {
        return rows;
    }

    public int getCols() {
        return cols;
    }

    public Fraction getElement(int row, int col) {
        return elements[row][col];
    }

    public void setElement(int row, int col, Fraction value) {
        elements[row][col] = value;
    }

    /**
     * Determine whether it is a square matrix
     * @return square matrix returns true
     */
    public boolean isSquare() {
        return rows == cols;
    }

    /**
     * Find the adjoint matrix
     * @return Returns the adjoint matrix A* of matrix A
     */
    public Mat adjoinMatrix(){
        //Find the inverse of the matrix
        Mat n = this.inverseMatrix();
        Mat adjoinMatrix = new Mat(rows,cols);
        //Find the value of the determinant
        Fraction d = this.determinant();
        //The adjoint matrix is d * its inverse
        for(int i = 0; i < rows; + + i){
            for(int j = 0; j < cols; + + j){
                adjoinMatrix.setElement(i,j,n.getElement(i,j).multiply(d));
            }
        }
        return adjoinMatrix;
    }

    /**
     * Find the value of the determinant of the matrix
     * @return Returns the exact result of the determinant of this matrix
     */
    public Fraction determinant() {
        //Select which row to expand into the sum aij×Aij
        int chooseRow = 0;
        // Check if the matrix is a square matrix
        if (!isSquare()) {
            throw new IllegalStateException("The matrix is not a square matrix");
        }
        // Get the order of the matrix
        int n = getRows();
        // Recursively calculate the determinant
        if (n == 1) {
            return getElement(0, 0);
        } else if (n == 2) {
            // The calculation of the determinant of the 2x2 matrix is relatively simple for the second order, and you can directly use ad-bc to calculate it.
            Fraction a = getElement(0, 0);
            Fraction b = getElement(0, 1);
            Fraction c = getElement(1, 0);
            Fraction d = getElement(1, 1);
            return a.multiply(d).subtract(b.multiply(c));
        } else {
            // Calculation of determinants of third-order and higher-order matrices, implemented here using recursion
            Fraction determinant = new Fraction(0, 1);
            for (int j = 0; j < n; j + + ) {
                // Get the algebraic cofactor Aij of aij
                Fraction cofactor = getCofactor(chooseRow, j);
                // Get aij
                Fraction e = this.getElement(chooseRow,j);
                // Accumulated algebraic cofactor aij×Aij
                determinant = determinant.add(cofactor.multiply(e));
            }
            return determinant;
        }
    }

    /**
     * Solve the algebraic cofactor of aij
     * @param row aij's row
     * @param column aij's column
     * @return the algebraic cofactor Aij of aij
     */
    private Fraction getCofactor(int row, int column) {
        // Calculate the symbols of algebraic cofactors
        int sign = (row + column) % 2 == 0 ? 1 : -1;
        // Calculate the determinant of the submatrix
        Fraction subDeterminant = subMatrix(row, column).determinant();
        //The algebraic cofactor is equal to the submatrix determinant times the sign
        return subDeterminant.multiply(new Fraction(sign, 1));
    }

    /**
     * Get the subformula of aij
     * @param row The row where the element is located
     * @param column the column where the element is located
     * @return Returns its subformula (remove the row and column where it is located, and the resulting matrix is obtained)
     */
    private Mat subMatrix(int row, int column) {
        int n = getRows();
        //Create submatrix
        Mat subMatrix = new Mat(n - 1, n - 1);
        int rowIndex = 0;
        for (int i = 0; i < n; i + + ) {
            if (i == row) {
                continue;
            }
            int colIndex = 0;
            for (int j = 0; j < n; j + + ) {
                if (j == column) {
                    continue;
                }
                //Copy the elements of the original matrix to the submatrix
                subMatrix.setElement(rowIndex, colIndex, getElement(i, j));
                colIndex + + ;
            }
            rowIndex + + ;
        }
        return subMatrix;
    }

    /**
     * Get the inverse of the matrix
     * @return Returns the inverse of this
     */
    public Mat inverseMatrix() {
        // Check if the matrix is invertible
        if (!this.isInvertible()) {
            throw new IllegalArgumentException("Matrix is not invertible");
        }
        // Get the number of rows and columns of the matrix
        int rows = this.getRows();
        int cols = this.getCols();
        //Create the identity matrix
        Mat identityMatrix = createIdentityMatrix(rows);
        // Concatenate the original matrix and the identity matrix
        Mat augmentedMatrix = concatenateMatrices(this, identityMatrix);
        // Use Gauss-Jordan elimination method to solve the inverse matrix
        gaussianJordanElimination(augmentedMatrix);
        // Extract the inverse matrix part. This method obtains the inverse matrix by extracting the rightmost column of the augmented matrix.
        return extractInverseMatrix(augmentedMatrix, cols);
    }

    /**
     * Determine whether the matrix is invertible
     * @return Reversible returns true
     */
    private boolean isInvertible() {
        // Check if the matrix is a square matrix
        if (!this.isSquare()) {
            return false;
        }
        // Check whether the value of the determinant of the matrix is 0
        if (this.determinant().equals(new Fraction(0, 1))) {
            throw new IllegalArgumentException("The matrix determinant value is 0, irreversible");
        }
        return true;
    }

    /**
     * Create an identity matrix
     * @param size The order of the matrix
     * @return Returns an n-order matrix
     */
    public static Mat createIdentityMatrix(int size) {
        Mat identityMatrix = new Mat(size, size);
        for (int i = 0; i < size; i + + ) {
            for (int j = 0; j < size; j + + ) {
                if (i == j) {
                    identityMatrix.setElement(i, j, new Fraction(1, 1));
                } else {
                    identityMatrix.setElement(i, j, new Fraction(0, 1));
                }
            }
        }
        return identityMatrix;
    }

    /**
     * Concatenate matrix A and identity matrix E
     * @param matrix1 original matrix A
     * @param matrix2 Identity matrix of the same order as A
     * @return Return augmented matrix (A,E)
     */
    private static Mat concatenateMatrices(Mat matrix1, Mat matrix2) {
        int rows1 = matrix1.getRows();
        int cols1 = matrix1.getCols();
        int rows2 = matrix2.getRows();
        int cols2 = matrix2.getCols();
        if (rows1 != rows2) {
            throw new IllegalArgumentException("The number of rows of the two matrices is not equal");
        }
        Mat resultMatrix = new Mat(rows1, cols1 + cols2);
        for (int i = 0; i < rows1; i + + ) {
            for (int j = 0; j < cols1; j + + ) {
                resultMatrix.setElement(i, j, matrix1.getElement(i, j));
            }
            for (int j = 0; j < cols2; j + + ) {
                resultMatrix.setElement(i, cols1 + j, matrix2.getElement(i, j));
            }
        }

        return resultMatrix;
    }

    /**
     * Find the inverse of matrix A using Gaussian elimination method
     * @param matrix
     */
    private static void gaussianJordanElimination(Mat matrix) {
        int rows = matrix.getRows();
        int cols = matrix.getCols();
        for (int i = 0; i < rows; i + + ) {
            Fraction pivot = matrix.getElement(i, i);
            //If the element processed in the current row is 0
            if (pivot.equals(new Fraction(0, 1))) {
                //Find the non-zero element below it
                int swapRow = findNonZeroElementRow(matrix, i, i);
                //If it cannot be found, it means that all the following are 0, indicating that the matrix is irreversible
                if (swapRow == -1) {
                    throw new IllegalArgumentException("Matrix is not invertible");
                }
                //Swap these two lines after finding them
                swapRows(matrix, i, swapRow);
                pivot = matrix.getElement(i, i);
            }
            for (int j = 0; j < cols; j + + ) {
                Fraction element = matrix.getElement(i, j);
                matrix.setElement(i, j, element.divide(pivot));
            }
            //The ajk of the j-th row is subtracted from the aik of the i-th row so that the ajk becomes 0
            for (int j = 0; j < rows; j + + ) {
                if (j != i) {
                    Fraction factor = matrix.getElement(j, i);
                    for (int k = 0; k < cols; k + + ) {
                        Fraction element1 = matrix.getElement(j, k);
                        Fraction element2 = matrix.getElement(i, k);
                        matrix.setElement(j, k, element1.subtract(element2.multiply(factor)));
                    }
                }
            }
        }
    }

    /**
     * Find non-zero elements when row transformation
     * @param matrix matrix A
     * @param col aij's column
     * @param startRow The row where aij is located
     * @return Returns the subscript of the non-zero element directly below aij
     */
    private static int findNonZeroElementRow(Mat matrix, int col, int startRow) {
        int rows = matrix.getRows();
        for (int i = startRow; i < rows; i + + ) {
            if (!matrix.getElement(i, col).equals(new Fraction(0, 1))) {
                return i;
            }
        }
        return -1;
    }

    private static void swapRows(Mat matrix, int row1, int row2) {
        int cols = matrix.getCols();
        for (int j = 0; j < cols; j + + ) {
            Fraction temp = matrix.getElement(row1, j);
            matrix.setElement(row1, j, matrix.getElement(row2, j));
            matrix.setElement(row2, j, temp);
        }
    }

    /**
     * After the row transformation is completed, (A, E)-> (E, A inverse), extract the inverse of A
     * @param matrix augmented matrix (E, A inverse)
     * @param cols
     * @return
     */
    private static Mat extractInverseMatrix(Mat matrix, int cols) {
        int rows = matrix.getRows();
        Mat inverseMatrix = new Mat(rows, cols);
        for (int i = 0; i < rows; i + + ) {
            for (int j = 0; j < cols; j + + ) {
                //Because it is (E, A inverse), the inverse of A on the right is obtained, so the column uses j + cols
                inverseMatrix.setElement(i, j, matrix.getElement(i, j + cols));
            }
        }
        return inverseMatrix;
    }

    public void setRows(int rows) {
        this.rows = rows;
    }

    public void setCols(int cols) {
        this.cols = cols;
    }

    public Fraction[][] getElements() {
        return elements;
    }

    public void setElements(Fraction[][] elements) {
        this.elements = elements;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Mat mat = (Mat) o;
        if (rows != mat.rows) return false;
        if (cols != mat.cols) return false;
        return Arrays.deepEquals(elements, mat.elements);
    }

    @Override
    public int hashCode() {
        int result = rows;
        result = 31 * result + cols;
        result = 31 * result + Arrays.deepHashCode(elements);
        return result;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i < rows; + + i){
            sb.append("[");
            for(int j = 0; j < cols; + + j){
                sb.append(getElement(i, j)).append("\t,");
            }
            sb.delete(sb.length()-1,sb.length());
            sb.append("]\
");
        }
        return sb.toString();
    }
}

2. Fraction class

/**
 * An exact rational number {@code numerator / denominator}.
 * <p>
 * Constructors and arithmetic results are always reduced to lowest terms with a
 * positive denominator, and zero is represented as 0/1. Intermediate products are
 * computed in {@code long}, so they cannot silently wrap around. A result that does
 * not fit in {@code int} throws {@link ArithmeticException}.
 *
 * @author gt5b
 * 2023-09-07 01:12
 */
public class Fraction {
    private int numerator;
    private int denominator;

    /** Creates the fraction 0/1. */
    public Fraction() {
        this.numerator = 0;
        this.denominator = 1;
    }

    /**
     * Creates numerator/denominator, reduced to lowest terms with a positive denominator.
     *
     * @param numerator   the numerator
     * @param denominator the denominator
     * @throws IllegalArgumentException if denominator is 0
     * @throws ArithmeticException      if the reduced form does not fit in int
     *                                  (e.g. 1 / Integer.MIN_VALUE)
     */
    public Fraction(int numerator, int denominator) {
        this((long) numerator, (long) denominator);
    }

    /**
     * Shared normalizing constructor working in long, so that products of two ints
     * (used by the arithmetic methods) can be reduced before narrowing back to int.
     */
    private Fraction(long numerator, long denominator) {
        if (denominator == 0) {
            throw new IllegalArgumentException("The denominator cannot be 0");
        }
        if (numerator == 0) {
            // Every zero is stored as 0/1 so that equal values have equal fields
            this.numerator = 0;
            this.denominator = 1;
            return;
        }
        // Move the sign onto the numerator so the denominator is always positive
        if (denominator < 0) {
            numerator = -numerator;
            denominator = -denominator;
        }
        long gcd = gcd(Math.abs(numerator), denominator);
        // toIntExact throws instead of silently truncating an out-of-range result
        this.numerator = Math.toIntExact(numerator / gcd);
        this.denominator = Math.toIntExact(denominator / gcd);
    }

    public int getNumerator() {
        return numerator;
    }

    public int getDenominator() {
        return denominator;
    }

    /**
     * Greatest common divisor by the Euclidean algorithm.
     *
     * @param a a non-negative value
     * @param b a non-negative value
     * @return gcd(a, b); gcd(0, b) == b
     */
    private static long gcd(long a, long b) {
        while (b != 0) {
            long remainder = a % b;
            a = b;
            b = remainder;
        }
        return a;
    }

    /**
     * @return this + other, in lowest terms
     * @throws ArithmeticException if the reduced result does not fit in int
     */
    public Fraction add(Fraction other) {
        long newNumerator = (long) this.numerator * other.denominator + (long) this.denominator * other.numerator;
        long newDenominator = (long) this.denominator * other.denominator;
        return new Fraction(newNumerator, newDenominator);
    }

    /**
     * @return this - other, in lowest terms
     * @throws ArithmeticException if the reduced result does not fit in int
     */
    public Fraction subtract(Fraction other) {
        long newNumerator = (long) this.numerator * other.denominator - (long) this.denominator * other.numerator;
        long newDenominator = (long) this.denominator * other.denominator;
        return new Fraction(newNumerator, newDenominator);
    }

    /**
     * @return this * other, in lowest terms
     * @throws ArithmeticException if the reduced result does not fit in int
     */
    public Fraction multiply(Fraction other) {
        long newNumerator = (long) this.numerator * other.numerator;
        long newDenominator = (long) this.denominator * other.denominator;
        return new Fraction(newNumerator, newDenominator);
    }

    /**
     * @return this / other, in lowest terms
     * @throws IllegalArgumentException if other is zero
     * @throws ArithmeticException      if the reduced result does not fit in int
     */
    public Fraction divide(Fraction other) {
        long newNumerator = (long) this.numerator * other.denominator;
        long newDenominator = (long) this.denominator * other.numerator;
        return new Fraction(newNumerator, newDenominator);
    }

    /**
     * Value equality: a/b equals c/d iff a*d == b*c. This also holds when the public
     * setters have left either fraction unreduced.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Fraction fraction = (Fraction) o;
        return (long) numerator * fraction.denominator == (long) denominator * fraction.numerator;
    }

    /**
     * Hashes the reduced, positive-denominator form, so that values that are equal
     * under {@link #equals(Object)} get equal hashes.
     */
    @Override
    public int hashCode() {
        long n = numerator;
        long d = denominator;
        if (d < 0) {
            n = -n;
            d = -d;
        }
        long g = gcd(Math.abs(n), d); // n == 0 yields g == d, i.e. 0/1
        int result = (int) (n / g);
        result = 31 * result + (int) (d / g);
        return result;
    }

    /**
     * Sets the raw numerator without reducing; the fraction may no longer be in lowest terms.
     */
    public void setNumerator(int numerator) {
        this.numerator = numerator;
    }

    /**
     * Sets the raw denominator without reducing; the fraction may no longer be in lowest terms.
     *
     * @throws IllegalArgumentException if denominator is 0
     */
    public void setDenominator(int denominator) {
        if (denominator == 0) {
            throw new IllegalArgumentException("The denominator cannot be 0");
        }
        this.denominator = denominator;
    }

    /** Renders integers as "n" and other values as "n/d". */
    @Override
    public String toString() {
        if (denominator == 1) {
            return String.valueOf(numerator);
        }
        return numerator + "/" + denominator;
    }
}