package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.functions.kernel.rvm.ClassificationProblem;
import com.rapidminer.operator.learner.functions.kernel.rvm.ConstructiveRegression;
import com.rapidminer.operator.learner.functions.kernel.rvm.Parameter;
import com.rapidminer.operator.learner.functions.kernel.rvm.RVMClassification;
import com.rapidminer.operator.learner.functions.kernel.rvm.RVMRegression;
import com.rapidminer.operator.learner.functions.kernel.rvm.RegressionProblem;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelCauchy;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelEpanechnikov;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelGaussianCombination;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelLaplace;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelMultiquadric;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelPoly;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelSigmoid;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;
import org.apache.poi.ddf.EscherProperties;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/learner/functions/kernel/RVMLearner.class */
/**
 * Operator that learns a Relevance Vector Machine (RVM).
 *
 * <p>Numerical labels are handled by a regression RVM (plain or constructive). Binominal labels
 * are handled by a two-class classification RVM. Only numerical attributes are supported. The
 * model is built from kernel basis functions centred on the training examples; the kernel family
 * is selected with {@link #PARAMETER_KERNEL_TYPE}.
 */
public class RVMLearner extends AbstractKernelBasedLearner {
    public static final String PARAMETER_RVM_TYPE = "rvm_type";
    public static final String PARAMETER_KERNEL_TYPE = "kernel_type";
    public static final String PARAMETER_MAX_ITERATION = "max_iteration";
    public static final String PARAMETER_MIN_DELTA_LOG_ALPHA = "min_delta_log_alpha";
    public static final String PARAMETER_ALPHA_MAX = "alpha_max";
    public static final String PARAMETER_KERNEL_LENGTHSCALE = "kernel_lengthscale";
    public static final String PARAMETER_KERNEL_DEGREE = "kernel_degree";
    public static final String PARAMETER_KERNEL_BIAS = "kernel_bias";
    public static final String PARAMETER_KERNEL_SIGMA1 = "kernel_sigma1";
    public static final String PARAMETER_KERNEL_SIGMA2 = "kernel_sigma2";
    public static final String PARAMETER_KERNEL_SIGMA3 = "kernel_sigma3";
    public static final String PARAMETER_KERNEL_SHIFT = "kernel_shift";
    public static final String PARAMETER_KERNEL_A = "kernel_a";
    public static final String PARAMETER_KERNEL_B = "kernel_b";
    public static final String[] RVM_TYPES = {"Regression-RVM", "Classification-RVM", "Constructive-Regression-RVM"};
    public static final String[] KERNEL_TYPES = {"rbf", "cauchy", "laplace", "poly", "sigmoid", "Epanechnikov", "gaussian combination", "multiquadric"};

    // Indices into RVM_TYPES (the value returned by getParameterAsInt(PARAMETER_RVM_TYPE)).
    private static final int RVM_TYPE_REGRESSION = 0;
    private static final int RVM_TYPE_CLASSIFICATION = 1;
    private static final int RVM_TYPE_CONSTRUCTIVE_REGRESSION = 2;

    // Indices into KERNEL_TYPES (the value returned by getParameterAsInt(PARAMETER_KERNEL_TYPE)).
    private static final int KERNEL_RBF = 0;
    private static final int KERNEL_CAUCHY = 1;
    private static final int KERNEL_LAPLACE = 2;
    private static final int KERNEL_POLY = 3;
    private static final int KERNEL_SIGMOID = 4;
    private static final int KERNEL_EPANECHNIKOV = 5;
    private static final int KERNEL_GAUSSIAN_COMBINATION = 6;
    private static final int KERNEL_MULTIQUADRIC = 7;

    /**
     * UserError code reported when an RVM solver fails with an {@link ArrayIndexOutOfBoundsException}.
     *
     * <p>NOTE(review): the decompiled source expressed this code through the unrelated Apache POI
     * constant {@code EscherProperties.GROUPSHAPE__BORDERLEFTCOLOR}. That is almost certainly an
     * inlined int which the decompiler matched by value; only its numeric value matters here. TODO
     * replace with the literal RapidMiner error number once confirmed.
     */
    private static final int ERROR_RVM_SOLVER_FAILED = EscherProperties.GROUPSHAPE__BORDERLEFTCOLOR;

    public RVMLearner(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /** Supports numerical attributes together with either a binominal or a numerical label. */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        return operatorCapability == OperatorCapability.NUMERICAL_ATTRIBUTES
                || operatorCapability == OperatorCapability.BINOMINAL_LABEL
                || operatorCapability == OperatorCapability.NUMERICAL_LABEL;
    }

    /**
     * Trains an RVM on the given example set.
     *
     * <p>A nominal label selects the classification path; any other label selects one of the
     * regression paths.
     *
     * @param exampleSet the training data
     * @return an {@link RVMModel} wrapping the learned RVM
     * @throws OperatorException in three cases: the selected RVM type does not fit the label type
     *     (error 207); a nominal label does not have exactly two values (error 114); or the solver
     *     fails with an index error
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int size = exampleSet.size();

        Parameter parameter = new Parameter();
        parameter.min_delta_log_alpha = getParameterAsDouble(PARAMETER_MIN_DELTA_LOG_ALPHA);
        parameter.alpha_max = getParameterAsDouble(PARAMETER_ALPHA_MAX);
        parameter.maxIterations = getParameterAsInt(PARAMETER_MAX_ITERATION);
        // Initial alpha is (1 / number of examples)^2.
        parameter.initAlpha = Math.pow(1.0d / size, 2.0d);
        parameter.initSigma = 0.1d;

        // Every row is assigned in the loop below, so only the outer arrays are allocated here.
        double[][] inputs = new double[size][];
        double[][] targets = new double[size][];
        int row = 0;
        for (Example example : exampleSet) {
            inputs[row] = RVMModel.makeInputVector(example);
            targets[row] = new double[] {example.getLabel()};
            row++;
        }

        log("Creating kernel basis functions [" + KERNEL_TYPES[getParameterAsInt(PARAMETER_KERNEL_TYPE)] + "].");
        // One slot more than there are examples; see createKernels for the layout.
        KernelBasisFunction[] kernels = createKernels(inputs, size + 1);

        int rvmTypeIndex = getParameterAsInt(PARAMETER_RVM_TYPE);
        Attribute label = exampleSet.getAttributes().getLabel();
        com.rapidminer.operator.learner.functions.kernel.rvm.Model rvmModel;
        if (label.isNominal()) {
            rvmModel = learnClassification(label, inputs, targets, kernels, parameter, rvmTypeIndex);
        } else {
            rvmModel = learnRegression(inputs, targets, kernels, parameter, rvmTypeIndex);
        }
        return new RVMModel(exampleSet, rvmModel);
    }

    /**
     * Runs the regression RVM selected by {@code rvmTypeIndex}.
     *
     * @throws UserError 207 if the selected RVM type is not a regression type
     */
    private com.rapidminer.operator.learner.functions.kernel.rvm.Model learnRegression(
            double[][] inputs, double[][] targets, KernelBasisFunction[] kernels, Parameter parameter, int rvmTypeIndex)
            throws OperatorException {
        if (rvmTypeIndex != RVM_TYPE_REGRESSION && rvmTypeIndex != RVM_TYPE_CONSTRUCTIVE_REGRESSION) {
            throw new UserError(this, 207, RVM_TYPES[rvmTypeIndex], PARAMETER_RVM_TYPE, "only one of the regression types can be used for the given regression problem");
        }
        RegressionProblem problem = new RegressionProblem(inputs, targets, kernels);
        try {
            if (rvmTypeIndex == RVM_TYPE_REGRESSION) {
                return new RVMRegression(problem, parameter).learn();
            }
            return new ConstructiveRegression(problem, parameter,
                    getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED),
                    getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED)).learn();
        } catch (ArrayIndexOutOfBoundsException e) {
            // Keep the original failure as the cause so it is not lost for diagnosis.
            throw new UserError(this, e, ERROR_RVM_SOLVER_FAILED);
        }
    }

    /**
     * Runs the two-class classification RVM.
     *
     * <p>The class index of each example is taken from its label value cast to {@code int}.
     *
     * @throws UserError 114 if the label does not have exactly two values
     * @throws UserError 207 if the selected RVM type is not the classification type
     */
    private com.rapidminer.operator.learner.functions.kernel.rvm.Model learnClassification(
            Attribute label, double[][] inputs, double[][] targets, KernelBasisFunction[] kernels, Parameter parameter,
            int rvmTypeIndex) throws OperatorException {
        if (label.getMapping().size() != 2) {
            throw new UserError(this, 114, getName(), label.getName());
        }
        if (rvmTypeIndex != RVM_TYPE_CLASSIFICATION) {
            throw new UserError(this, 207, RVM_TYPES[rvmTypeIndex], PARAMETER_RVM_TYPE, "only Classification-RVM can be used for the given two class classification problem");
        }
        int[] classes = new int[targets.length];
        for (int k = 0; k < targets.length; k++) {
            classes[k] = (int) targets[k][0];
        }
        ClassificationProblem problem = new ClassificationProblem(inputs, classes, kernels);
        try {
            return new RVMClassification(problem, parameter).learn();
        } catch (ArrayIndexOutOfBoundsException e) {
            // Keep the original failure as the cause so it is not lost for diagnosis.
            throw new UserError(this, e, ERROR_RVM_SOLVER_FAILED);
        }
    }

    /**
     * Creates the kernel basis functions for the selected kernel type.
     *
     * <p>Basis function {@code j + 1} is centred on row {@code j} of {@code centres}, for
     * {@code j = 0 .. numKernels - 2}. Index 0 of the returned array is never assigned here and stays
     * {@code null}. NOTE(review): presumably the RVM solvers treat slot 0 as the bias term — TODO
     * confirm against the rvm package.
     *
     * @param centres the training input vectors; must have at least {@code numKernels - 1} rows
     * @param numKernels the length of the returned array
     * @return the basis functions, with slot 0 left {@code null}
     * @throws OperatorException if a kernel parameter cannot be read
     */
    public KernelBasisFunction[] createKernels(double[][] centres, int numKernels) throws OperatorException {
        KernelBasisFunction[] kernels = new KernelBasisFunction[numKernels];
        double lengthScale = getParameterAsDouble(PARAMETER_KERNEL_LENGTHSCALE);
        double bias = getParameterAsDouble(PARAMETER_KERNEL_BIAS);
        double degree = getParameterAsDouble(PARAMETER_KERNEL_DEGREE);
        double a = getParameterAsDouble(PARAMETER_KERNEL_A);
        double b = getParameterAsDouble(PARAMETER_KERNEL_B);
        double sigma1 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA1);
        double sigma2 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA2);
        double sigma3 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA3);
        double shift = getParameterAsDouble(PARAMETER_KERNEL_SHIFT);
        // The kernel type does not change between examples, so read it once.
        int kernelType = getParameterAsInt(PARAMETER_KERNEL_TYPE);

        for (int j = 0; j < numKernels - 1; j++) {
            double[] centre = centres[j];
            KernelBasisFunction basisFunction;
            switch (kernelType) {
                case KERNEL_CAUCHY:
                    basisFunction = new KernelBasisFunction(new KernelCauchy(lengthScale), centre);
                    break;
                case KERNEL_LAPLACE:
                    basisFunction = new KernelBasisFunction(new KernelLaplace(lengthScale), centre);
                    break;
                case KERNEL_POLY:
                    basisFunction = new KernelBasisFunction(new KernelPoly(lengthScale, bias, degree), centre);
                    break;
                case KERNEL_SIGMOID:
                    basisFunction = new KernelBasisFunction(new KernelSigmoid(a, b), centre);
                    break;
                case KERNEL_EPANECHNIKOV:
                    basisFunction = new KernelBasisFunction(new KernelEpanechnikov(sigma1, degree), centre);
                    break;
                case KERNEL_GAUSSIAN_COMBINATION:
                    basisFunction = new KernelBasisFunction(new KernelGaussianCombination(sigma1, sigma2, sigma3), centre);
                    break;
                case KERNEL_MULTIQUADRIC:
                    basisFunction = new KernelBasisFunction(new KernelMultiquadric(sigma1, shift), centre);
                    break;
                case KERNEL_RBF:
                default:
                    // rbf, and any unknown index, falls back to the radial kernel.
                    basisFunction = new KernelBasisFunction(new KernelRadial(lengthScale), centre);
                    break;
            }
            kernels[j + 1] = basisFunction;
        }
        return kernels;
    }

    /** Declares the RVM parameters, their kernel-type dependencies, and the random generator parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();

        ParameterTypeCategory rvmType = new ParameterTypeCategory(PARAMETER_RVM_TYPE, "The type of the RVM that should be learned.", RVM_TYPES, 0);
        rvmType.setExpert(false);
        parameterTypes.add(rvmType);

        ParameterTypeCategory kernelType = new ParameterTypeCategory(PARAMETER_KERNEL_TYPE, "The type of the kernel functions.", KERNEL_TYPES, 0);
        kernelType.setExpert(false);
        parameterTypes.add(kernelType);

        addKernelParameter(parameterTypes, PARAMETER_KERNEL_LENGTHSCALE, "The lengthscale used in the rbf, cauchy, laplace, and poly kernels.", 0.0d, Double.POSITIVE_INFINITY, 3.0d,
                KERNEL_RBF, KERNEL_CAUCHY, KERNEL_LAPLACE, KERNEL_POLY);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_DEGREE, "The degree used in the poly and Epanechnikov kernels.", 0.0d, Double.POSITIVE_INFINITY, 2.0d,
                KERNEL_POLY, KERNEL_EPANECHNIKOV);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_BIAS, "The bias used in the poly kernel.", 0.0d, Double.POSITIVE_INFINITY, 1.0d,
                KERNEL_POLY);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_SIGMA1, "The SVM kernel parameter sigma1 (Epanechnikov, Gaussian Combination, Multiquadric).", 0.0d, Double.POSITIVE_INFINITY, 1.0d,
                KERNEL_EPANECHNIKOV, KERNEL_GAUSSIAN_COMBINATION, KERNEL_MULTIQUADRIC);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_SIGMA2, "The SVM kernel parameter sigma2 (Gaussian Combination).", 0.0d, Double.POSITIVE_INFINITY, 0.0d,
                KERNEL_GAUSSIAN_COMBINATION);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_SIGMA3, "The SVM kernel parameter sigma3 (Gaussian Combination).", 0.0d, Double.POSITIVE_INFINITY, 2.0d,
                KERNEL_GAUSSIAN_COMBINATION);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_SHIFT, "The SVM kernel parameter shift (Multiquadric).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0d,
                KERNEL_MULTIQUADRIC);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_A, "The SVM kernel parameter a (neural).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0d,
                KERNEL_SIGMOID);
        addKernelParameter(parameterTypes, PARAMETER_KERNEL_B, "The SVM kernel parameter b (neural).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0d,
                KERNEL_SIGMOID);

        parameterTypes.add(new ParameterTypeInt(PARAMETER_MAX_ITERATION, "The maximum number of iterations used.", 1, Integer.MAX_VALUE, 100));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_MIN_DELTA_LOG_ALPHA, "Abort iteration if largest log alpha change is smaller than this", 0.0d, Double.POSITIVE_INFINITY, 0.001d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_ALPHA_MAX, "Prune basis function if its alpha is bigger than this", 0.0d, Double.POSITIVE_INFINITY, 1.0E12d));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    /**
     * Adds a non-expert double parameter to {@code types}.
     *
     * <p>The parameter is only relevant when {@link #PARAMETER_KERNEL_TYPE} has one of the given
     * kernel type indices.
     */
    private void addKernelParameter(List<ParameterType> types, String key, String description, double min, double max,
            double defaultValue, int... kernelTypes) {
        ParameterTypeDouble type = new ParameterTypeDouble(key, description, min, max, defaultValue);
        type.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_KERNEL_TYPE, KERNEL_TYPES, false, kernelTypes));
        type.setExpert(false);
        types.add(type);
    }
}
