package com.rapidminer.operator.learner.functions.kernel.rvm;

import Jama.Matrix;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.util.SECholeskyDecomposition;
import java.util.Iterator;
import java.util.LinkedList;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/learner/functions/kernel/rvm/RVMRegression.class */
/**
 * Relevance Vector Machine (RVM) learner for regression.
 *
 * <p>The model is a weighted sum over a constant bias basis (design-matrix column 0) and one
 * kernel basis per training example (columns 1..N). Each basis {@code j} carries a precision
 * hyperparameter {@code alpha_j}. The learner alternates between computing the Gaussian weight
 * posterior and re-estimating the hyperparameters. A basis is dropped from the active set as
 * soon as its {@code alpha_j} is no longer below {@code parameter.alpha_max}.
 */
public class RVMRegression extends RVMBase {

    /**
     * Creates a regression RVM.
     *
     * @param regressionProblem training inputs, targets and kernel basis functions
     * @param parameter hyperparameters of the learning procedure
     */
    public RVMRegression(RegressionProblem regressionProblem, Parameter parameter) {
        super(regressionProblem, parameter);
    }

    /**
     * Trains the RVM on the wrapped {@link RegressionProblem}.
     *
     * <p>Notation: {@code Phi} is the design matrix restricted to the currently active bases,
     * {@code t} the targets, and {@code beta} the noise precision (initially
     * {@code 1 / initSigma^2}). Each iteration computes:
     * <ul>
     *   <li>{@code Sigma = (beta * Phi^T Phi + diag(alpha))^-1} and
     *       {@code mu = beta * Sigma * Phi^T t};</li>
     *   <li>{@code gamma_j = 1 - alpha_j * Sigma_jj} and {@code alpha_j <- gamma_j / mu_j^2};</li>
     *   <li>{@code beta <- (N - sum_j gamma_j) / ||t - Phi mu||^2}.</li>
     * </ul>
     *
     * <p>Iteration stops after {@code parameter.maxIterations} passes, or earlier once the
     * largest absolute change of {@code log(alpha_j)} falls below
     * {@code parameter.min_delta_log_alpha}. On such an early stop, the re-estimated alphas of
     * that pass are not stored, {@code beta} is not updated, and the {@code mu} of that pass is
     * returned.
     *
     * @return a model holding the posterior mean weights of the bases active in the last pass
     * @throws IllegalArgumentException if {@code parameter.maxIterations} is less than 1, since
     *     then no posterior would ever be computed
     */
    @Override
    public Model learn() {
        // Without this guard the loop below never runs. Packing the model would then fail with
        // a NullPointerException on the still-unset active set and weights.
        if (this.parameter.maxIterations < 1) {
            throw new IllegalArgumentException(
                    "maxIterations must be at least 1, but was " + this.parameter.maxIterations);
        }

        RegressionProblem regressionProblem = (RegressionProblem) this.problem;
        int numExamples = regressionProblem.getProblemSize();
        int numBases = numExamples + 1; // bias column + one kernel column per example
        KernelBasisFunction[] kernels = regressionProblem.getKernels();

        Matrix phi = new Matrix(
                buildDesignMatrix(regressionProblem.getInputVectors(), kernels, numExamples));
        // NOTE(review): the residual computation below only works if the targets form an N x 1
        // column; only column 0 is used for the posterior mean.
        Matrix targets = new Matrix(regressionProblem.getTargetVectors());
        // Phi^T t depends on neither alpha nor beta, so it is computed once for all bases and
        // row-sliced per iteration.
        Matrix phiTransposedTargets = phi.transpose().times(targets);

        double beta = 1.0d / Math.pow(this.parameter.initSigma, 2.0d);
        double[] initialAlphas = new double[numBases];
        for (int b = 0; b < numBases; b++) {
            initialAlphas[b] = this.parameter.initAlpha;
        }
        Matrix alphas = new Matrix(initialAlphas, numBases); // numBases x 1 column vector

        int[] active = null;
        Matrix mu = null;
        for (int iteration = 1; iteration <= this.parameter.maxIterations; iteration++) {
            active = activeBasisIndices(alphas, this.parameter.alpha_max);
            Matrix phiActive = phi.getMatrix(0, phi.getRowDimension() - 1, active);
            Matrix phiTtActive = phiTransposedTargets.getMatrix(active, 0, 0);
            // getMatrix returns a copy; it is overwritten with the re-estimated alphas below.
            Matrix alphasActive = alphas.getMatrix(active, 0, 0);
            int numActive = active.length;

            // Hessian of the weight posterior: beta * Phi^T Phi + diag(alpha).
            Matrix hessian = phiActive.transpose().times(phiActive);
            hessian.timesEquals(beta);
            for (int k = 0; k < numActive; k++) {
                hessian.set(k, k, hessian.get(k, k) + alphasActive.get(k, 0));
            }

            // Presumably the decomposition factors the Hessian as (PTR * L)(PTR * L)^T. Then
            // Sigma = inverse^T * inverse with inverse = (PTR * L)^-1.
            // TODO confirm against SECholeskyDecomposition.
            SECholeskyDecomposition cholesky = new SECholeskyDecomposition(hessian.getArray());
            Matrix inverse = cholesky.getPTR().times(cholesky.getL()).inverse();
            mu = inverse.transpose().times(inverse.times(phiTtActive)).times(beta);

            double[] gamma = new double[numActive];
            double maxLogAlphaChange = 0.0d;
            for (int k = 0; k < numActive; k++) {
                // Sigma_kk is the squared Euclidean norm of column k of the inverse factor.
                double sigmaKK = 0.0d;
                for (int r = 0; r < numActive; r++) {
                    double entry = inverse.get(r, k);
                    sigmaKK += entry * entry;
                }
                double oldAlpha = alphasActive.get(k, 0);
                gamma[k] = 1.0d - oldAlpha * sigmaKK;
                double muK = mu.get(k, 0);
                double newAlpha = gamma[k] / (muK * muK);
                alphasActive.set(k, 0, newAlpha);
                double change = Math.abs(Math.log(oldAlpha) - Math.log(newAlpha));
                if (change > maxLogAlphaChange) {
                    maxLogAlphaChange = change;
                }
            }
            if (maxLogAlphaChange < this.parameter.min_delta_log_alpha) {
                break;
            }

            Matrix residual = targets.minus(phiActive.times(mu));
            double squaredError = 0.0d;
            for (int r = 0; r < numExamples; r++) {
                double e = residual.get(r, 0);
                squaredError += e * e;
            }
            double gammaSum = 0.0d;
            for (double g : gamma) {
                gammaSum += g;
            }
            beta = (numExamples - gammaSum) / squaredError;

            // Write the re-estimated alphas back into the full-length vector.
            for (int k = 0; k < numActive; k++) {
                alphas.set(active[k], 0, alphasActive.get(k, 0));
            }
        }
        // NOTE(review): if every basis is pruned, the active set is empty. The behaviour of the
        // decomposition on a 0 x 0 Hessian is not visible here.

        return buildModel(active, mu, kernels);
    }

    /**
     * Builds the N x (N + 1) design matrix.
     *
     * <p>Column 0 is the constant 1.0 (bias). Column {@code j >= 1} holds {@code kernels[j]}
     * evaluated at every input vector. {@code kernels[0]} is not used, so {@code kernels} must
     * have at least {@code numExamples + 1} entries.
     */
    private static double[][] buildDesignMatrix(
            double[][] inputs, KernelBasisFunction[] kernels, int numExamples) {
        double[][] design = new double[numExamples][numExamples + 1];
        // Kernel-major evaluation order, as in the original implementation.
        for (int col = 1; col <= numExamples; col++) {
            for (int row = 0; row < numExamples; row++) {
                design[row][col] = kernels[col].eval(inputs[row]);
            }
        }
        for (int row = 0; row < numExamples; row++) {
            design[row][0] = 1.0d;
        }
        return design;
    }

    /** Returns, in ascending order, the indices of all bases whose alpha is below alphaMax. */
    private static int[] activeBasisIndices(Matrix alphas, double alphaMax) {
        int numBases = alphas.getRowDimension();
        int count = 0;
        for (int b = 0; b < numBases; b++) {
            if (alphas.get(b, 0) < alphaMax) {
                count++;
            }
        }
        int[] active = new int[count];
        int next = 0;
        for (int b = 0; b < numBases; b++) {
            if (alphas.get(b, 0) < alphaMax) {
                active[next++] = b;
            }
        }
        return active;
    }

    /**
     * Packs the posterior mean weights of the active bases into a {@link Model}.
     *
     * <p>Active index 0 is the bias column. For it, a placeholder {@link KernelRadial} basis is
     * stored and the model's bias flag is set. Every other index {@code j} keeps
     * {@code kernels[j]}.
     */
    private static Model buildModel(int[] active, Matrix mu, KernelBasisFunction[] kernels) {
        double[] weights = new double[active.length];
        KernelBasisFunction[] basis = new KernelBasisFunction[active.length];
        boolean hasBias = false;
        for (int k = 0; k < active.length; k++) {
            weights[k] = mu.get(k, 0);
            if (active[k] == 0) {
                hasBias = true;
                basis[k] = new KernelBasisFunction(new KernelRadial());
            } else {
                basis[k] = kernels[active[k]];
            }
        }
        // NOTE(review): the meaning of the last Model flag is not visible here; kept as before.
        return new Model(weights, basis, hasBias, true);
    }

    /** Returns the display name of this learner. */
    @Override
    public String toString() {
        return "Regression-RVM";
    }
}
