package org.fastica;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import cern.jet.random.Normal;
import cern.jet.random.engine.MersenneTwister;
import java.util.ArrayList;
import java.util.List;
import org.fastica.EVfilters.EigenValueFilter;
import org.fastica.EVfilters.ToNComponentsEVFilter;
import org.fastica.FastICAConfig;
import org.fastica.contrastFunctions.ContrastFunction;
import org.fastica.listener.FastICAListener;
import org.fastica.math.MatrixUtils;
import org.fastica.math.Scale;
import org.knime.core.node.CanceledExecutionException;

/* loaded from: input_file:lib/fastica.jar:org/fastica/FastICA.class */
/**
 * Static entry points for FastICA (independent component analysis) on Colt matrices.
 *
 * <p>Input data is treated as observations in rows and variables in columns.
 *
 * <p>The pipeline has four steps:
 * <ol>
 *   <li>Optionally reduce the input by PCA to the requested number of components.</li>
 *   <li>Center and whiten the data.</li>
 *   <li>Run either the deflation or the parallel (symmetric) fixed-point algorithm, as selected
 *       by {@link FastICAConfig#getApproach()}.</li>
 *   <li>Project the data through the resulting unmixing matrix.</li>
 * </ol>
 */
public class FastICA {
    /** Shared Colt linear-algebra helper used for all matrix products and norms. */
    private static final Algebra ALGEBRA = new Algebra();
    /** Deflation: the listener is polled every this many inner iterations. */
    private static final int ITER_PER_UPDATE = 200;
    /** Parallel algorithm: the listener is polled every this many iterations. */
    private static final int SYM_ITER_PER_UPDATE = 20;
    /** Convergence tolerance of the iterative symmetric decorrelation. */
    private static final double SYM_DELTA = 1.0E-4d;
    /** Regularizer added to covariance singular values before inverting their square roots. */
    private static final double SVD_EPS = 7.0E-16d;

    /**
     * Runs FastICA with a configuration built from {@code new FastICAConfig(i)}.
     *
     * @param data observations in rows, variables in columns (copied, not modified)
     * @param i forwarded to {@code FastICAConfig(int)}; presumably the number of ICs, TODO confirm
     * @return the estimated sources, one row per observation and one column per component
     * @throws CanceledExecutionException if the configured listener requests cancellation
     */
    public static double[][] fastICA(double[][] data, int i) throws CanceledExecutionException {
        return fastICAwrapper(new DenseDoubleMatrix2D(data), new FastICAConfig(i), null);
    }

    /**
     * Runs FastICA and also returns the (uncentered) unmixing matrix.
     *
     * @param data observations in rows, variables in columns (copied, not modified)
     * @param i forwarded as the first argument of {@code FastICAConfig(int, double, int)}
     * @param d forwarded as the second argument of {@code FastICAConfig(int, double, int)}
     * @param i2 forwarded as the third argument of {@code FastICAConfig(int, double, int)}
     * @return a two-element list: the estimated sources, then the unmixing matrix
     * @throws CanceledExecutionException if the configured listener requests cancellation
     */
    public static List<double[][]> fastICA2(double[][] data, int i, double d, int i2) throws CanceledExecutionException {
        return fastICAwrapper2(new DenseDoubleMatrix2D(data), new FastICAConfig(i, d, i2), null);
    }

    /**
     * Runs FastICA with an explicit configuration.
     *
     * @param data observations in rows, variables in columns (copied, not modified)
     * @param fastICAConfig algorithm settings; may be modified (number of ICs, initial matrix)
     * @param eigenValueFilter currently not used by the implementation
     * @return the estimated sources, one row per observation and one column per component
     * @throws CanceledExecutionException if the configured listener requests cancellation
     */
    public static double[][] fastICA(double[][] data, FastICAConfig fastICAConfig, EigenValueFilter eigenValueFilter) throws CanceledExecutionException {
        return fastICAwrapper(new DenseDoubleMatrix2D(data), fastICAConfig, eigenValueFilter);
    }

    /**
     * Runs FastICA with an explicit configuration and no eigenvalue filter.
     *
     * @param data observations in rows, variables in columns (copied, not modified)
     * @param fastICAConfig algorithm settings; may be modified (number of ICs, initial matrix)
     * @return the estimated sources, one row per observation and one column per component
     * @throws CanceledExecutionException if the configured listener requests cancellation
     */
    public static double[][] fastICA(double[][] data, FastICAConfig fastICAConfig) throws CanceledExecutionException {
        return fastICA(data, fastICAConfig, null);
    }

    /**
     * Full pipeline returning only the estimated sources.
     *
     * <p>The {@code eigenValueFilter} argument is accepted for API compatibility but is not used.
     */
    private static double[][] fastICAwrapper(DoubleMatrix2D input, FastICAConfig config, EigenValueFilter eigenValueFilter) throws CanceledExecutionException {
        DoubleMatrix2D data = matchInputToNumICs(input, config);
        ensureInitialMixingMatrix(config);
        int numSamples = data.rows();
        int numVariables = data.columns();
        int numICs = config.getNumICs();

        MatrixUtils.center(data);
        DoubleMatrix2D x = data.viewDice(); // variables x samples
        DoubleMatrix2D whitening = getWhiteningMatrix(x, numICs, numSamples, numVariables);
        DoubleMatrix2D unmixing = ALGEBRA.mult(algorithm(ALGEBRA.mult(whitening, x), config), whitening);
        return ALGEBRA.mult(unmixing, x).viewDice().toArray();
    }

    /**
     * Full pipeline returning the estimated sources and the unmixing matrix.
     *
     * <p>The {@code eigenValueFilter} argument is accepted for API compatibility but is not used.
     */
    private static List<double[][]> fastICAwrapper2(DoubleMatrix2D input, FastICAConfig config, EigenValueFilter eigenValueFilter) throws CanceledExecutionException {
        DoubleMatrix2D data = matchInputToNumICs(input, config);
        ensureInitialMixingMatrix(config);
        int numSamples = data.rows();
        int numVariables = data.columns();
        int numICs = config.getNumICs();

        DoubleMatrix1D means = MatrixUtils.center(data);
        DoubleMatrix2D x = data.viewDice(); // variables x samples
        DoubleMatrix2D whitening = getWhiteningMatrix(x, numICs, numSamples, numVariables);
        DoubleMatrix2D unmixing = ALGEBRA.mult(algorithm(ALGEBRA.mult(whitening, x), config), whitening);
        DoubleMatrix2D sources = ALGEBRA.mult(unmixing, x).viewDice();
        // Must run after the sources are computed: uncenter is presumably an in-place update of unmixing.
        MatrixUtils.uncenter(unmixing, means);

        List<double[][]> result = new ArrayList<double[][]>(2);
        result.add(sources.toArray());
        result.add(unmixing.toArray());
        return result;
    }

    /**
     * Makes the number of input variables agree with the requested number of ICs.
     *
     * <p>If too many ICs are requested, the config is clamped to the number of columns. If fewer
     * ICs than columns are requested, the data is reduced by PCA to that many variables.
     *
     * @return the data to analyse (either {@code data} itself or its PCA reduction)
     */
    private static DoubleMatrix2D matchInputToNumICs(DoubleMatrix2D data, FastICAConfig config) {
        int numICs = config.getNumICs();
        int numVariables = data.columns();
        if (numICs > numVariables) {
            System.err.println("number of ICs is too large... setting to " + numVariables);
            config.setNumICs(numVariables);
            return data;
        }
        if (numICs < numVariables) {
            System.err.println("reducing input to " + numICs + " variables");
            return PCA.pca(data, new ToNComponentsEVFilter(numICs))[0].viewDice();
        }
        System.err.println("extracting exactly " + numICs + " from " + numICs + " variables");
        return data;
    }

    /**
     * Installs a random square initial mixing matrix if none is configured.
     *
     * @throws IllegalArgumentException if the configured matrix is not numICs x numICs
     */
    private static void ensureInitialMixingMatrix(FastICAConfig config) {
        int numICs = config.getNumICs();
        DoubleMatrix2D initial = config.getInitialMixingMatrix();
        if (initial == null) {
            config.setInitialMixingMatrix(makeRandom(numICs));
        } else if (initial.rows() != numICs || initial.columns() != numICs) {
            throw new IllegalArgumentException("Bad initial weight matrix!");
        }
    }

    /** Returns an n x n matrix of standard-normal samples from a default-constructed MersenneTwister. */
    private static DoubleMatrix2D makeRandom(int n) {
        Normal normal = new Normal(0.0d, 1.0d, new MersenneTwister());
        DoubleMatrix2D random = new DenseDoubleMatrix2D(n, n);
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                random.setQuick(r, c, normal.nextDouble());
            }
        }
        return random;
    }

    /**
     * Builds the whitening matrix {@code diag(1/sqrt(s + eps)) * U^T}, truncated to its first
     * {@code numICs} rows and {@code numVariables} columns.
     *
     * <p>Here {@code U} and {@code s} come from the SVD of the covariance {@code x x^T / numSamples}.
     *
     * @param x centered data with variables in rows and samples in columns
     */
    private static DoubleMatrix2D getWhiteningMatrix(DoubleMatrix2D x, int numICs, int numSamples, int numVariables) {
        DoubleMatrix2D covariance = ALGEBRA.mult(x, x.viewDice()).assign(new Scale(1.0d / numSamples));
        SingularValueDecomposition svd = new SingularValueDecomposition(covariance);
        double[] scales = svd.getSingularValues();
        for (int k = 0; k < scales.length; k++) {
            scales[k] = 1.0d / Math.sqrt(scales[k] + SVD_EPS);
        }
        DoubleMatrix2D whitening = ALGEBRA.mult(DoubleFactory2D.dense.diagonal(new DenseDoubleMatrix1D(scales)), svd.getU().viewDice());
        return whitening.viewPart(0, 0, numICs, numVariables);
    }

    /** Dispatches to the deflation or the parallel (symmetric) algorithm according to the config. */
    private static DoubleMatrix2D algorithm(DoubleMatrix2D whitened, FastICAConfig config) throws CanceledExecutionException {
        if (config.getApproach() == FastICAConfig.Approach.DEFLATION) {
            return deflationAlg(whitened, config);
        }
        return parallelAlg(whitened, config);
    }

    /**
     * Runs the parallel (symmetric) FastICA fixed-point iteration.
     *
     * <p>Each step computes {@code W+ = E{g(Wx) x^T} - diag(E{g'(Wx)}) W} and then symmetrically
     * decorrelates it.
     *
     * @param x whitened data with components in rows and samples in columns
     * @return the estimated unmixing matrix in the whitened space
     */
    private static DoubleMatrix2D parallelAlg(DoubleMatrix2D x, FastICAConfig config) throws CanceledExecutionException {
        int maxIterations = config.getMaxIterations();
        double epsilon = config.getEpsilon();
        ContrastFunction cf = config.getCf();
        FastICAListener listener = config.getListener();
        int numSamples = x.columns();

        DoubleMatrix2D w = symmetricDecorrelate(config.getInitialMixingMatrix());
        int iteration = 0;
        double change = Double.MAX_VALUE;
        while (change > epsilon && iteration < maxIterations) {
            DoubleMatrix2D projections = ALGEBRA.mult(w, x);
            // E{g(Wx) x^T}. The 1.0d forces floating-point division: the former integer 1 / n was 0
            // for n > 1, which wiped out this term entirely.
            DoubleMatrix2D next = ALGEBRA.mult(projections.copy().assign(cf.getFunction()), x.viewDice())
                    .assign(new Scale(1.0d / numSamples));
            // assign() overwrites projections with g'(Wx); the copy above was already consumed.
            DoubleMatrix1D meanDerivative = MatrixUtils.getRowAvgs(projections.assign(cf.getDerivative()));
            MatrixUtils.subtract(next, ALGEBRA.mult(DoubleFactory2D.dense.diagonal(meanDerivative), w));

            DoubleMatrix2D decorrelated = symmetricDecorrelate(next);
            change = delta(decorrelated, w);
            w = decorrelated;
            iteration++;

            if (listener != null && iteration % SYM_ITER_PER_UPDATE == SYM_ITER_PER_UPDATE - 1) {
                double fraction = (iteration * 1.0d) / maxIterations;
                if (!listener.markProgress((int) Math.round(fraction * config.getNumICs()), (int) Math.round(fraction * config.getMaxIterations()))) {
                    throw new CanceledExecutionException();
                }
            }
        }
        return w;
    }

    /**
     * Runs deflationary FastICA, estimating one unmixing row at a time.
     *
     * <p>Each row is Gram-Schmidt decorrelated against the rows already found.
     *
     * @param x whitened data with components in rows and samples in columns
     * @return a numICs x numICs unmixing matrix in the whitened space
     */
    private static DoubleMatrix2D deflationAlg(DoubleMatrix2D x, FastICAConfig config) throws CanceledExecutionException {
        int maxIterations = config.getMaxIterations();
        int numICs = config.getNumICs();
        double epsilon = config.getEpsilon();
        ContrastFunction cf = config.getCf();
        DoubleMatrix2D initial = config.getInitialMixingMatrix();
        FastICAListener listener = config.getListener();

        DoubleMatrix2D w = DoubleFactory2D.dense.make(numICs, numICs, 0.0d);
        for (int ic = 0; ic < numICs; ic++) {
            DoubleMatrix1D wRow = decorrelateVector(initial.viewRow(ic).viewPart(0, numICs).copy(), w, ic);
            if (listener != null && !listener.markProgress(ic, 0)) {
                throw new CanceledExecutionException();
            }
            double change = Double.MAX_VALUE;
            int iteration = 0;
            while (iteration < maxIterations && change > epsilon) {
                DoubleMatrix1D projection = ALGEBRA.mult(x.viewDice(), wRow); // one entry per sample
                // E{x g(w^T x)}
                DoubleMatrix1D next = MatrixUtils.getRowAvgs(
                        MatrixUtils.pieceWiseMult(x, toMultRows(projection.copy().assign(cf.getFunction()), numICs)));
                // E{g'(w^T x)}: assign() overwrites projection in place, and its size is the sample count.
                double meanDerivative = projection.assign(cf.getDerivative()).zSum() / projection.size();
                DoubleMatrix1D scaled = wRow.copy();
                scaled.assign(new Scale(meanDerivative));
                MatrixUtils.subtract(next, scaled);

                DoubleMatrix1D decorrelated = decorrelateVector(next, w, ic);
                change = delta(wRow, decorrelated);
                wRow = decorrelated;
                iteration++;
                if (listener != null && iteration % ITER_PER_UPDATE == ITER_PER_UPDATE - 1 && !listener.markProgress(ic, iteration)) {
                    throw new CanceledExecutionException();
                }
            }
            w.viewRow(ic).assign(wRow);
        }
        return w;
    }

    /** Returns a {@code numRows x row.size()} matrix in which every row equals {@code row}. */
    private static DoubleMatrix2D toMultRows(DoubleMatrix1D row, int numRows) {
        DoubleMatrix2D tiled = new DenseDoubleMatrix2D(numRows, row.size());
        for (int r = 0; r < numRows; r++) {
            tiled.viewRow(r).assign(row);
        }
        return tiled;
    }

    /**
     * Symmetrically orthogonalizes a square matrix with the iteration {@code V <- 1.5 V - 0.5 V V^T V}.
     *
     * <p>The iteration works on the transpose of a copy of the input, after scaling it by
     * {@code 1/sqrt(norm1(V V^T))}. The input matrix is not modified.
     *
     * <p>NOTE(review): there is no iteration cap, so the loop relies on convergence; a
     * rank-deficient input may fail to converge.
     */
    private static DoubleMatrix2D symmetricDecorrelate(DoubleMatrix2D m) {
        // Copy so the caller's matrix (e.g. the configured initial mixing matrix) is not rescaled.
        DoubleMatrix2D v = m.viewDice().copy();
        v.assign(new Scale(1.0d / Math.sqrt(ALGEBRA.norm1(ALGEBRA.mult(v, v.viewDice())))));
        double change = Double.MAX_VALUE;
        while (change > SYM_DELTA) {
            DoubleMatrix2D next = v.copy().assign(new Scale(1.5d));
            DoubleMatrix2D cubic = ALGEBRA.mult(ALGEBRA.mult(v, v.viewDice()), v);
            cubic.assign(new Scale(0.5d));
            MatrixUtils.subtract(next, cubic);
            change = delta(v, next);
            v = next;
        }
        return v.viewDice();
    }

    /**
     * Returns the FastICA convergence measure {@code |1 - |<previous, current>||}.
     *
     * <p>The measure is 0 when both unit vectors point along the same line. It replaces a
     * decompiler artifact ({@code Math.round(r0)}) that referenced an undefined variable.
     *
     * @throws IllegalArgumentException if the vectors differ in size
     */
    private static double delta(DoubleMatrix1D previous, DoubleMatrix1D current) {
        if (previous.size() != current.size()) {
            throw new IllegalArgumentException("vector sizes differ: " + previous.size() + " vs " + current.size());
        }
        return Math.abs(Math.abs(ALGEBRA.mult(previous, current)) - 1.0d);
    }

    /** Returns the largest per-row {@code |1 - |<previous_i, current_i>||} over all rows. */
    private static double delta(DoubleMatrix2D previous, DoubleMatrix2D current) {
        double worst = 0.0d;
        for (int r = 0; r < previous.rows(); r++) {
            double change = Math.abs(Math.abs(ALGEBRA.mult(previous.viewRow(r), current.viewRow(r))) - 1.0d);
            if (change > worst) {
                worst = change;
            }
        }
        return worst;
    }

    /**
     * Gram-Schmidt step: removes from {@code v} its projections onto the first {@code numRows}
     * rows of {@code basis}, then normalizes.
     *
     * <p>{@code v} is updated (presumably in place, via MatrixUtils) and returned.
     */
    private static DoubleMatrix1D decorrelateVector(DoubleMatrix1D v, DoubleMatrix2D basis, int numRows) {
        DoubleMatrix1D projection = DoubleFactory1D.dense.make(v.size(), 0.0d);
        for (int r = 0; r < numRows; r++) {
            DoubleMatrix1D basisRow = basis.viewRow(r);
            // All dot products use the original v because projection is accumulated separately.
            MatrixUtils.incVector(projection, basisRow.copy().assign(new Scale(ALGEBRA.mult(v, basisRow))));
        }
        MatrixUtils.subtract(v, projection);
        MatrixUtils.normalize(v);
        return v;
    }
}
