package cern.colt.matrix.tfloat.algo.solver.preconditioner;

import cern.colt.Sorting;
import cern.colt.matrix.tfloat.FloatMatrix1D;
import cern.colt.matrix.tfloat.FloatMatrix2D;
import cern.colt.matrix.tfloat.impl.DenseFloatMatrix1D;
import cern.colt.matrix.tfloat.impl.SparseRCFloatMatrix2D;

/* loaded from: input_file:lib/parallelcolt-0.9.4.jar:cern/colt/matrix/tfloat/algo/solver/preconditioner/FloatILU.class */
/**
 * Incomplete LU preconditioner without fill-in, ILU(0).
 *
 * <p>The system matrix is copied into compressed-row (CSR) storage and factored in place. Only
 * positions present in the original sparsity pattern are updated; any fill-in outside that pattern
 * is dropped. Both factors share one CSR structure: entries left of the diagonal hold the unit
 * lower factor {@code L} (its unit diagonal is implicit), while the diagonal and the entries to its
 * right hold the upper factor {@code U}.
 *
 * <p>Instances reuse an internal work vector between calls and are therefore not safe for
 * concurrent use.
 */
public class FloatILU implements FloatPreconditioner {
    /** L and U factors sharing one CSR structure; {@code null} until {@link #setMatrix} succeeds. */
    private SparseRCFloatMatrix2D lu;

    /** CSR position of each row's diagonal entry within {@link #lu}. */
    private int[] diagonalIndexes;

    /**
     * Scratch vector for the triangular solves. It is created here as a fresh (non-view) dense
     * vector, so its backing array can be indexed directly from 0 to n-1.
     */
    private final DenseFloatMatrix1D work;

    /** Order of the square matrices this preconditioner accepts. */
    private final int n;

    /**
     * Creates an ILU(0) preconditioner for square matrices of order {@code n}.
     *
     * @param n number of rows (and columns) of the matrices passed to {@link #setMatrix}
     */
    public FloatILU(int n) {
        this.n = n;
        this.work = new DenseFloatMatrix1D(n);
    }

    /**
     * Solves {@code L U x = b}, applying the inverse of the incomplete factorization to {@code b}.
     *
     * @param b right-hand side of size n; it is only read (unless it is the same object as {@code x})
     * @param x destination of size n, or {@code null} to allocate one with {@code b.like()}
     * @return the vector holding the result ({@code x}, or the newly allocated vector)
     * @throws IllegalStateException if {@link #setMatrix} has not completed successfully
     */
    @Override
    public FloatMatrix1D apply(FloatMatrix1D b, FloatMatrix1D x) {
        ensureFactored();
        if (x == null) {
            x = b.like();
        }
        // Move data in and out through assign() so views and non-dense vectors are handled
        // correctly; the solves themselves run on the work vector's raw array.
        work.assign(b);
        float[] w = work.elements();
        lowerUnitSolve(w);
        upperSolve(w);
        x.assign(work);
        return x;
    }

    /**
     * Solves {@code (L U)^T x = U^T L^T x = b}, applying the inverse of the transposed
     * factorization to {@code b}.
     *
     * @param b right-hand side of size n; it is only read (unless it is the same object as {@code x})
     * @param x destination of size n, or {@code null} to allocate one with {@code b.like()}
     * @return the vector holding the result ({@code x}, or the newly allocated vector)
     * @throws IllegalStateException if {@link #setMatrix} has not completed successfully
     */
    @Override
    public FloatMatrix1D transApply(FloatMatrix1D b, FloatMatrix1D x) {
        ensureFactored();
        if (x == null) {
            x = b.like();
        }
        work.assign(b);
        float[] w = work.elements();
        upperTransSolve(w);
        lowerUnitTransSolve(w);
        x.assign(work);
        return x;
    }

    /**
     * Computes the ILU(0) factorization of {@code A}, replacing any previous factorization.
     *
     * <p>{@code A} is copied and left unmodified. The new factors are installed only after the
     * factorization succeeds, so a failure leaves the previously installed factors intact.
     *
     * @param A square matrix of order n; every row must have a stored diagonal entry
     * @throws IllegalArgumentException if {@code A} is not n-by-n
     * @throws RuntimeException if a diagonal entry is missing or a zero pivot is encountered
     */
    @Override
    public void setMatrix(FloatMatrix2D A) {
        if (A.rows() != n) {
            throw new IllegalArgumentException("A.rows() != n");
        }
        if (A.columns() != n) {
            throw new IllegalArgumentException("A.columns() != n");
        }
        SparseRCFloatMatrix2D factors = new SparseRCFloatMatrix2D(n, n);
        factors.assign(A);
        // Both the diagonal lookup (binary search) and the merge-style elimination need
        // column indexes sorted within each row.
        if (!factors.hasColumnIndexesSorted()) {
            factors.sortColumnIndexes();
        }
        int[] diag = factorInPlace(n, factors.getColumnIndexes(), factors.getRowPointers(),
                factors.getValues());
        this.lu = factors;
        this.diagonalIndexes = diag;
    }

    /** Fails fast when a solve is requested before a factorization is available. */
    private void ensureFactored() {
        if (lu == null) {
            throw new IllegalStateException("setMatrix must be called before the preconditioner is applied");
        }
    }

    /**
     * Performs ILU(0) elimination on CSR arrays in place.
     *
     * @param n      number of rows
     * @param colind column indexes, sorted within each row
     * @param rowptr row pointers; row {@code r} occupies positions {@code [rowptr[r], rowptr[r+1])}
     * @param values stored values; overwritten with the L (strictly lower) and U factors
     * @return the CSR position of every row's diagonal entry
     * @throws RuntimeException if a diagonal entry is missing or a zero pivot is encountered
     */
    private static int[] factorInPlace(int n, int[] colind, int[] rowptr, float[] values) {
        int[] diag = findDiagonalIndexes(n, colind, rowptr);
        for (int k = 1; k < n; k++) {
            int rowEnd = rowptr[k + 1];
            // Eliminate every stored entry left of row k's diagonal, in increasing column order.
            for (int i = rowptr[k]; i < diag[k]; i++) {
                int pivotRow = colind[i];
                float pivot = values[diag[pivotRow]];
                if (pivot == 0.0f) {
                    throw new RuntimeException("Zero pivot encountered on row " + (pivotRow + 1) + " during ILU process");
                }
                float multiplier = values[i] / pivot;
                values[i] = multiplier;
                // Merge row k (from just after position i) with the U part of pivotRow; both
                // are sorted by column. Only columns already present in row k are updated.
                int l = i + 1;
                for (int j = diag[pivotRow] + 1; j < rowptr[pivotRow + 1]; j++) {
                    int col = colind[j];
                    while (l < rowEnd && colind[l] < col) {
                        l++;
                    }
                    if (l == rowEnd) {
                        // Row k has no entries at or beyond col; the remaining updates would
                        // all be fill-in. Reading colind[l] here would inspect the next row.
                        break;
                    }
                    if (colind[l] == col) {
                        values[l] -= multiplier * values[j];
                    }
                }
            }
        }
        return diag;
    }

    /**
     * Locates the diagonal entry of every row.
     *
     * @param n      number of rows
     * @param colind column indexes, sorted within each row
     * @param rowptr row pointers
     * @return array whose element {@code r} is the CSR position of entry (r, r)
     * @throws RuntimeException if some row has no stored diagonal entry
     */
    private static int[] findDiagonalIndexes(int n, int[] colind, int[] rowptr) {
        int[] diag = new int[n];
        for (int r = 0; r < n; r++) {
            int pos = Sorting.binarySearchFromTo(colind, r, rowptr[r], rowptr[r + 1] - 1);
            if (pos < 0) {
                throw new RuntimeException("Missing diagonal entry on row " + (r + 1));
            }
            diag[r] = pos;
        }
        return diag;
    }

    /** Forward substitution with the unit lower factor, in place: {@code w <- L^{-1} w}. */
    private void lowerUnitSolve(float[] w) {
        int[] colind = lu.getColumnIndexes();
        int[] rowptr = lu.getRowPointers();
        float[] values = lu.getValues();
        for (int r = 0; r < n; r++) {
            float sum = 0.0f;
            for (int p = rowptr[r]; p < diagonalIndexes[r]; p++) {
                sum += values[p] * w[colind[p]];
            }
            w[r] -= sum;
        }
    }

    /** Back substitution with the upper factor, in place: {@code w <- U^{-1} w}. */
    private void upperSolve(float[] w) {
        int[] colind = lu.getColumnIndexes();
        int[] rowptr = lu.getRowPointers();
        float[] values = lu.getValues();
        for (int r = n - 1; r >= 0; r--) {
            float sum = 0.0f;
            for (int p = diagonalIndexes[r] + 1; p < rowptr[r + 1]; p++) {
                sum += values[p] * w[colind[p]];
            }
            w[r] = (w[r] - sum) / values[diagonalIndexes[r]];
        }
    }

    /**
     * Solves with {@code U^T} in place ({@code w <- U^{-T} w}). It works column by column, because
     * row r of U is column r of U^T.
     */
    private void upperTransSolve(float[] w) {
        int[] colind = lu.getColumnIndexes();
        int[] rowptr = lu.getRowPointers();
        float[] values = lu.getValues();
        for (int r = 0; r < n; r++) {
            w[r] /= values[diagonalIndexes[r]];
            float wr = w[r];
            for (int p = diagonalIndexes[r] + 1; p < rowptr[r + 1]; p++) {
                w[colind[p]] -= values[p] * wr;
            }
        }
    }

    /**
     * Solves with {@code L^T} (unit diagonal) in place ({@code w <- L^{-T} w}). It works column by
     * column, from the last row upward.
     */
    private void lowerUnitTransSolve(float[] w) {
        int[] colind = lu.getColumnIndexes();
        int[] rowptr = lu.getRowPointers();
        float[] values = lu.getValues();
        for (int r = n - 1; r >= 0; r--) {
            float wr = w[r];
            for (int p = rowptr[r]; p < diagonalIndexes[r]; p++) {
                w[colind[p]] -= values[p] * wr;
            }
        }
    }
}
