package com.rapidminer.operator.learner.functions.kernel.rvm;

import Jama.Matrix;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.util.SECholeskyDecomposition;
import java.util.Iterator;
import java.util.LinkedList;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/learner/functions/kernel/rvm/RVMClassification.class */
/**
 * Relevance-vector-machine learner for classification problems.
 *
 * <p>{@link #learn()} alternates between two steps until the per-basis
 * hyper-parameters ("alphas") stop changing or the iteration budget is used up:
 * <ol>
 *   <li>fit the weights of the currently active basis functions with an
 *       iterative, line-searched update, and</li>
 *   <li>re-estimate the alpha of every active basis function from those weights.</li>
 * </ol>
 * Basis functions whose alpha reaches {@code parameter.alpha_max} are left out
 * of subsequent iterations and of the returned model.
 */
public class RVMClassification extends RVMBase {

    /** The line search stops halving once the step size is no longer above this value (2^-8). */
    private static final double MIN_STEP_SIZE = Math.pow(2.0d, -8.0d);

    /** Upper bound on weight updates performed per outer iteration. */
    private static final int MAX_WEIGHT_UPDATES = 25;

    /** Weight updates stop once (gradient norm / number of weights) drops below this value. */
    private static final double GRADIENT_STOP = 1.0E-6d;

    /**
     * Creates a learner for the given problem.
     *
     * @param classificationProblem the problem to learn; passed on to {@link RVMBase}
     * @param parameter             the learning parameters; passed on to {@link RVMBase}
     */
    public RVMClassification(ClassificationProblem classificationProblem, Parameter parameter) {
        super(classificationProblem, parameter);
    }

    /**
     * Trains the model on the problem handed to the constructor.
     *
     * @return a model made of the weights and kernels of all basis functions that
     *         were still active in the last executed iteration
     * @throws IllegalStateException if {@code parameter.maxIterations} is smaller than 1,
     *                               because then no iteration runs and there is nothing to return
     */
    @Override
    public Model learn() {
        ClassificationProblem classificationProblem = (ClassificationProblem) this.problem;
        int numExamples = classificationProblem.getProblemSize();
        // One basis function per example plus a constant column at index 0.
        int numBases = numExamples + 1;
        double[][] inputVectors = classificationProblem.getInputVectors();
        KernelBasisFunction[] kernels = classificationProblem.getKernels();
        int[] targets = classificationProblem.getTargetVectors();

        // Design matrix: row = example, column = basis function.
        double[][] phi = new double[numExamples][numBases];
        for (int basis = 1; basis < numBases; basis++) {
            for (int row = 0; row < numExamples; row++) {
                phi[row][basis] = kernels[basis].eval(inputVectors[row]);
            }
        }
        for (int row = 0; row < numExamples; row++) {
            phi[row][0] = 1.0d;
        }

        Matrix designMatrix = new Matrix(phi);
        Matrix alphas = new Matrix(numBases, 1, this.parameter.initAlpha);
        // NOTE(review): this vector is never written back to, so every outer
        // iteration starts its weight fit from all zeros.
        Matrix initialWeights = new Matrix(numBases, 1, 0.0d);

        Matrix weights = null;
        Matrix inverseFactor = null;
        int[] active = null;

        for (int iteration = 1; iteration <= this.parameter.maxIterations; iteration++) {
            // Restrict everything to the basis functions whose alpha is still below the cap.
            active = selectActiveBases(alphas);
            Matrix phiActive = designMatrix.getMatrix(0, designMatrix.getRowDimension() - 1, active);
            Matrix alphaActive = alphas.getMatrix(active, 0, 0);
            weights = initialWeights.getMatrix(active, 0, 0);

            int activeCount = alphaActive.getRowDimension();
            Matrix alphaDiagonal = new Matrix(activeCount, activeCount, 0.0d);
            for (int k = 0; k < activeCount; k++) {
                alphaDiagonal.set(k, k, alphaActive.get(k, 0));
            }

            Matrix predictions = applySigmoid(phiActive.times(weights));
            // NOTE(review): this baseline is not refreshed after an accepted step, so every
            // later step in this outer iteration is compared against the starting error.
            double baselineError = penalisedError(predictions, targets, alphaActive, weights, numExamples);

            for (int update = 0; update < MAX_WEIGHT_UPDATES; update++) {
                // Second-derivative matrix: Phi^T * diag(y * (1 - y)) * Phi + diag(alpha).
                int rows = phiActive.getRowDimension();
                Matrix variance = new Matrix(rows, rows, 0.0d);
                for (int r = 0; r < rows; r++) {
                    double y = predictions.get(r, 0);
                    variance.set(r, r, y * (1.0d - y));
                }
                Matrix hessian = phiActive.transpose().times(variance).times(phiActive);
                hessian.plusEquals(alphaDiagonal);

                // Gradient: Phi^T * (t - y) - alpha .* w.
                Matrix residual = new Matrix(predictions.getRowDimension(), 1, 0.0d);
                for (int r = 0; r < predictions.getRowDimension(); r++) {
                    residual.set(r, 0, targets[r] - predictions.get(r, 0));
                }
                Matrix alphaTimesWeights = (Matrix) alphaActive.clone();
                for (int k = 0; k < alphaTimesWeights.getRowDimension(); k++) {
                    alphaTimesWeights.set(k, 0, alphaTimesWeights.get(k, 0) * weights.get(k, 0));
                }
                Matrix gradient = phiActive.transpose().times(residual).minus(alphaTimesWeights);

                // Inverse of the factor delivered by the decomposition; it is also needed
                // after this loop for the alpha re-estimation.
                SECholeskyDecomposition decomposition = new SECholeskyDecomposition(hessian.getArray());
                inverseFactor = decomposition.getPTR().times(decomposition.getL()).inverse();

                if (update >= 2 && gradient.normF() / weights.getRowDimension() < GRADIENT_STOP) {
                    break;
                }

                Matrix direction = inverseFactor.transpose().times(inverseFactor.times(gradient));

                // Line search: halve the step until the penalised error does not exceed the
                // baseline. The loop is bounded by MIN_STEP_SIZE; if no step qualifies (also
                // when the error is NaN) the weights are kept unchanged instead of looping forever.
                // NOTE(review): in that case `predictions` still holds the outputs of the last
                // rejected candidate, as it did before this fix.
                for (double step = 1.0d; step > MIN_STEP_SIZE; step /= 2.0d) {
                    Matrix candidate = weights.plus(direction.times(step));
                    predictions = applySigmoid(phiActive.times(candidate));
                    double candidateError =
                            penalisedError(predictions, targets, alphaActive, candidate, numExamples);
                    if (candidateError <= baselineError) {
                        weights = candidate;
                        break;
                    }
                }
            }

            // Re-estimate each active alpha as (1 - alpha * s) / w^2, where s is the squared
            // norm of the matching column of the inverse factor, and track the largest
            // change in log(alpha).
            int factorSize = inverseFactor.getRowDimension();
            double maxLogAlphaChange = 0.0d;
            for (int k = 0; k < factorSize; k++) {
                double columnSquaredNorm = 0.0d;
                for (int r = 0; r < factorSize; r++) {
                    columnSquaredNorm += inverseFactor.get(r, k) * inverseFactor.get(r, k);
                }
                double oldAlpha = alphaActive.get(k, 0);
                double gamma = 1.0d - (oldAlpha * columnSquaredNorm);
                double newAlpha = gamma / (weights.get(k, 0) * weights.get(k, 0));
                alphaActive.set(k, 0, newAlpha);

                double change = Math.abs(Math.log(oldAlpha) - Math.log(newAlpha));
                if (change > maxLogAlphaChange) {
                    maxLogAlphaChange = change;
                }
            }

            if (maxLogAlphaChange < this.parameter.min_delta_log_alpha) {
                // Converged: the new alphas are deliberately not written back.
                break;
            }
            for (int k = 0; k < alphaActive.getRowDimension(); k++) {
                alphas.set(active[k], 0, alphaActive.get(k, 0));
            }
        }

        if (active == null) {
            throw new IllegalStateException(
                    "RVM classification needs at least one iteration, but maxIterations is "
                            + this.parameter.maxIterations);
        }

        double[] finalWeights = new double[active.length];
        KernelBasisFunction[] finalKernels = new KernelBasisFunction[active.length];
        boolean constantBasisKept = false;
        for (int k = 0; k < active.length; k++) {
            finalWeights[k] = weights.get(k, 0);
            if (active[k] == 0) {
                // Index 0 is the constant column, which has no kernel of its own.
                constantBasisKept = true;
                finalKernels[k] = new KernelBasisFunction(new KernelRadial());
            } else {
                finalKernels[k] = kernels[active[k]];
            }
        }
        return new Model(finalWeights, finalKernels, constantBasisKept, false);
    }

    /**
     * Collects, in ascending order, the indices of all basis functions whose
     * alpha is still below {@code parameter.alpha_max}.
     */
    private int[] selectActiveBases(Matrix alphas) {
        int total = alphas.getRowDimension();
        int[] buffer = new int[total];
        int count = 0;
        for (int index = 0; index < total; index++) {
            if (alphas.get(index, 0) < this.parameter.alpha_max) {
                buffer[count++] = index;
            }
        }
        int[] selected = new int[count];
        System.arraycopy(buffer, 0, selected, 0, count);
        return selected;
    }

    /** Replaces every entry of the given column vector by its sigmoid and returns the same matrix. */
    private Matrix applySigmoid(Matrix column) {
        for (int r = 0; r < column.getRowDimension(); r++) {
            column.set(r, 0, sigmoid(column.get(r, 0)));
        }
        return column;
    }

    /**
     * Computes (negative log-likelihood of the targets + half the alpha-weighted
     * squared weights) divided by the problem size.
     */
    private static double penalisedError(Matrix predictions, int[] targets, Matrix alphas,
            Matrix weights, int problemSize) {
        double dataTerm = 0.0d;
        for (int r = 0; r < targets.length; r++) {
            double y = predictions.get(r, 0);
            dataTerm -= (targets[r] == 1) ? Math.log(y) : Math.log(1.0d - y);
        }
        double penalty = 0.0d;
        for (int k = 0; k < alphas.getRowDimension(); k++) {
            penalty += alphas.get(k, 0) * weights.get(k, 0) * weights.get(k, 0);
        }
        return (dataTerm + (penalty / 2.0d)) / problemSize;
    }

    /**
     * Logistic sigmoid.
     *
     * @param d the input value
     * @return {@code 1 / (1 + exp(-d))}
     */
    public double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /** Returns the fixed display name of this learner. */
    @Override
    public String toString() {
        return "Classification-RVM";
    }
}
