package com.rapidminer.operator.learner.functions.linear;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.example.Tools;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.annotation.ResourceConsumptionEstimator;
import com.rapidminer.operator.features.weighting.FeatureWeighting;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.LinearRegressionModel;
import com.rapidminer.operator.learner.functions.linear.LinearRegressionMethod;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.OperatorResourceConsumptionHandler;
import com.rapidminer.tools.math.FDistribution;
import com.rapidminer.tools.math.MathFunctions;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/learner/functions/linear/LinearRegression.class */
public class LinearRegression extends AbstractLearner {
    /** Parameter key: index of the feature selection method (an index into the keys of {@link #SELECTION_METHODS}). */
    public static final String PARAMETER_FEATURE_SELECTION = "feature_selection";

    /** Parameter key: whether colinear features are removed before the final regression. */
    public static final String PARAMETER_ELIMINATE_COLINEAR_FEATURES = "eliminate_colinear_features";

    /** Parameter key: whether an intercept value is calculated. */
    public static final String PARAMETER_USE_BIAS = "use_bias";

    /** Parameter key: minimum tolerance for the removal of colinear features. */
    public static final String PARAMETER_MIN_TOLERANCE = "min_tolerance";

    /** Parameter key: ridge parameter used for ridge regression. */
    public static final String PARAMETER_RIDGE = "ridge";

    /**
     * Registry of feature selection methods keyed by display name. Insertion order matters:
     * the {@code feature_selection} parameter is an index into the key order of this map.
     */
    public static final Map<String, Class<? extends LinearRegressionMethod>> SELECTION_METHODS = new LinkedHashMap<String, Class<? extends LinearRegressionMethod>>();

    // The three constants below appear to correspond to the positions of "none", "M5 prime"
    // and "greedy" in SELECTION_METHODS (see the static initializer) - confirm before relying on them.
    public static final int NO_SELECTION = 0;
    public static final int M5_PRIME = 1;
    public static final int GREEDY = 2;

    /** Output port on which the attribute weights derived from the coefficients are delivered. */
    private final OutputPort weightOutput;

    public LinearRegression(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.weightOutput = getOutputPorts().createPort(FeatureWeighting.PARAMETER_WEIGHTS);
        getTransformer().addGenerationRule(this.weightOutput, AttributeWeights.class);
    }

    /**
     * Fits a linear regression model on {@code exampleSet}.
     *
     * <p>Processing order:
     * <ol>
     * <li>A nominal label with exactly two values is temporarily replaced by a numeric attribute
     * {@code regression_label} (0 for the negative class, 1 otherwise). The original label is
     * restored after the selection method has run, or when an exception aborts learning.</li>
     * <li>Every numerical regular attribute with non-zero weighted standard deviation becomes a
     * candidate regressor.</li>
     * <li>If {@code eliminate_colinear_features} is set, the candidate with the smallest tolerance
     * below {@code min_tolerance} is dropped repeatedly until none is left below the threshold.</li>
     * <li>The configured feature selection method is applied to the regression on the remaining
     * candidates.</li>
     * <li>Standard errors, standardized coefficients, tolerances, t-statistics and p-values are
     * computed for the selected coefficients; if the matrix inversion fails an approximation
     * based on the correlation of the full regression is used instead.</li>
     * <li>If the weights port is connected, the coefficients are delivered as attribute weights.</li>
     * </ol>
     *
     * @param exampleSet the training data; it is modified temporarily when the label is binominal
     * @return the resulting {@link LinearRegressionModel}
     * @throws OperatorException if a parameter cannot be read, or as {@link UserError} 904 if the
     *         selected feature selection method is unknown or cannot be instantiated
     */
    @Override // com.rapidminer.operator.learner.Learner
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Attribute originalLabel = exampleSet.getAttributes().getLabel();
        Attribute workingLabel = originalLabel;
        String negativeClass = null;
        String positiveClass = null;
        Tools.onlyNonMissingValues(exampleSet, "Linear Regression");
        boolean useBias = getParameterAsBoolean(PARAMETER_USE_BIAS);
        boolean eliminateColinear = getParameterAsBoolean(PARAMETER_ELIMINATE_COLINEAR_FEATURES);
        double ridge = getParameterAsDouble(PARAMETER_RIDGE);
        double minTolerance = getParameterAsDouble(PARAMETER_MIN_TOLERANCE);

        // True while the temporary numeric label is installed and still has to be undone.
        boolean restorePending = false;
        try {
            if (originalLabel.isNominal() && originalLabel.getMapping().size() == 2) {
                negativeClass = originalLabel.getMapping().getNegativeString();
                positiveClass = originalLabel.getMapping().getPositiveString();
                int negativeIndex = originalLabel.getMapping().getNegativeIndex();
                // Value type 4 is kept from the original code (presumably a real-valued type - confirm).
                workingLabel = AttributeFactory.createAttribute("regression_label", 4);
                exampleSet.getExampleTable().addAttribute(workingLabel);
                for (Example example : exampleSet) {
                    example.setValue(workingLabel, example.getValue(originalLabel) == negativeIndex ? 0.0d : 1.0d);
                }
                exampleSet.getAttributes().setLabel(workingLabel);
                restorePending = true;
            }

            // Collect names, candidate flags, weighted means and standard deviations per attribute.
            int attributeCount = exampleSet.getAttributes().size();
            boolean[] isUsed = new boolean[attributeCount];
            String[] attributeNames = new String[attributeCount];
            double[] means = new double[attributeCount];
            double[] standardDeviations = new double[attributeCount];
            exampleSet.recalculateAllAttributeStatistics();
            int position = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                attributeNames[position] = attribute.getName();
                isUsed[position] = attribute.isNumerical();
                if (isUsed[position]) {
                    means[position] = exampleSet.getStatistics(attribute, Statistics.AVERAGE_WEIGHTED);
                    standardDeviations[position] = Math.sqrt(exampleSet.getStatistics(attribute, Statistics.VARIANCE_WEIGHTED));
                    if (standardDeviations[position] == 0.0d) {
                        // Constant attributes carry no information for the regression.
                        isUsed[position] = false;
                    }
                }
                position++;
            }
            double labelMean = exampleSet.getStatistics(workingLabel, Statistics.AVERAGE_WEIGHTED);
            double labelStandardDeviation = Math.sqrt(exampleSet.getStatistics(workingLabel, Statistics.VARIANCE_WEIGHTED));
            int numberOfExamples = exampleSet.size();

            // Counted before colinear elimination, exactly as in the original implementation.
            int numberOfUsedAttributes = 1;
            for (boolean used : isUsed) {
                if (used) {
                    numberOfUsedAttributes++;
                }
            }

            if (eliminateColinear) {
                boolean removedOne = true;
                while (removedOne) {
                    int worstIndex = -1;
                    double worstTolerance = 1.0d;
                    for (int i = 0; i < isUsed.length; i++) {
                        if (isUsed[i]) {
                            double tolerance = getTolerance(exampleSet, isUsed, i, ridge, useBias);
                            if (tolerance < minTolerance && tolerance <= worstTolerance) {
                                worstTolerance = tolerance;
                                worstIndex = i;
                            }
                        }
                    }
                    removedOne = worstIndex >= 0;
                    if (removedOne) {
                        isUsed[worstIndex] = false;
                    }
                }
            }
            // Only the regression on the final candidate set is consumed, so it is fitted once here
            // instead of once up front and again after every elimination round.
            double[] fullCoefficients = performRegression(exampleSet, isUsed, means, labelMean, ridge);
            double fullSquaredError = getSquaredError(exampleSet, isUsed, fullCoefficients, useBias);

            // Resolve the selection method; an out-of-range index is reported as a user error
            // instead of surfacing as an ArrayIndexOutOfBoundsException.
            String[] methodNames = SELECTION_METHODS.keySet().toArray(new String[SELECTION_METHODS.size()]);
            int methodIndex = getParameterAsInt(PARAMETER_FEATURE_SELECTION);
            if (methodIndex < 0 || methodIndex >= methodNames.length) {
                throw new UserError(this, 904, PARAMETER_FEATURE_SELECTION, "unknown method");
            }
            Class<? extends LinearRegressionMethod> methodClass = SELECTION_METHODS.get(methodNames[methodIndex]);
            if (methodClass == null) {
                throw new UserError(this, 904, PARAMETER_FEATURE_SELECTION, "unknown method");
            }

            LinearRegressionMethod.LinearRegressionResult result = methodClass.newInstance().applyMethod(this, useBias, ridge, exampleSet, isUsed, numberOfExamples, numberOfUsedAttributes, means, labelMean, standardDeviations, labelStandardDeviation, fullCoefficients, fullSquaredError);

            if (restorePending) {
                restorePending = false;
                restoreOriginalLabel(exampleSet, workingLabel, originalLabel);
            }

            FDistribution fDistribution = new FDistribution(1, exampleSet.size() - result.coefficients.length);
            int coefficientCount = result.coefficients.length;
            int biasIndex = coefficientCount - 1;
            double[] standardErrors = new double[coefficientCount];
            double[] standardizedCoefficients = new double[coefficientCount];
            double[] tolerances = new double[coefficientCount];
            double[] tStatistics = new double[coefficientCount];
            double[] pValues = new double[coefficientCount];
            double generalVariance = result.error / (exampleSet.size() - 1);

            int selectedCount = 0;
            for (boolean used : result.isUsedAttribute) {
                if (used) {
                    selectedCount++;
                }
            }

            // Design matrix: column 0 is all ones; row 0 is an additional all-ones row that precedes
            // the one-row-per-example data (layout kept from the original implementation).
            double[][] design = new double[exampleSet.size() + 1][selectedCount + 1];
            for (int column = 0; column < design[0].length; column++) {
                design[0][column] = 1.0d;
            }
            for (int row = 0; row < exampleSet.size() + 1; row++) {
                design[row][0] = 1.0d;
            }
            int designRow = 1;
            for (Example example : exampleSet) {
                int attributeIndex = 0;
                int designColumn = 1;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    if (result.isUsedAttribute[attributeIndex]) {
                        design[designRow][designColumn] = example.getValue(attribute);
                        designColumn++;
                    }
                    attributeIndex++;
                }
                designRow++;
            }

            Matrix designMatrix = new Matrix(design);
            Matrix inverse = null;
            try {
                inverse = designMatrix.transpose().times(designMatrix).inverse();
                int slot = 0;
                for (int i = 0; i < result.isUsedAttribute.length; i++) {
                    if (result.isUsedAttribute[i]) {
                        tolerances[slot] = getTolerance(exampleSet, result.isUsedAttribute, i, ridge, useBias);
                        standardErrors[slot] = Math.sqrt(generalVariance * inverse.get(slot + 1, slot + 1));
                        // NOTE(review): scales by stddev / mean of the attribute, as in the original code;
                        // a standardized coefficient would normally divide by the label stddev - confirm.
                        standardizedCoefficients[slot] = (result.coefficients[slot] * standardDeviations[i]) / means[i];
                        storeSignificance(result.coefficients[slot], standardErrors[slot], fDistribution, tStatistics, pValues, slot);
                        slot++;
                    }
                }
            } catch (Throwable inversionFailure) {
                // Deliberately broad, as in the original: any failure of the exact computation
                // falls back to an approximation based on the correlation of the full regression.
                double correlation = getCorrelation(exampleSet, isUsed, fullCoefficients, useBias);
                double rSquared = Math.min(correlation * correlation, 1.0d);
                int slot = 0;
                for (int i = 0; i < result.isUsedAttribute.length; i++) {
                    if (result.isUsedAttribute[i]) {
                        double tolerance = getTolerance(exampleSet, result.isUsedAttribute, i, ridge, useBias);
                        standardErrors[slot] = (Math.sqrt((1.0d - rSquared) / (tolerance * ((exampleSet.size() - exampleSet.getAttributes().size()) - 1.0d))) * labelStandardDeviation) / standardDeviations[i];
                        tolerances[slot] = tolerance;
                        standardizedCoefficients[slot] = (result.coefficients[slot] * standardDeviations[i]) / means[i];
                        storeSignificance(result.coefficients[slot], standardErrors[slot], fDistribution, tStatistics, pValues, slot);
                        slot++;
                    }
                }
            }

            // The last slot of every statistics array belongs to the intercept.
            if (inverse == null) {
                standardErrors[biasIndex] = Double.POSITIVE_INFINITY;
            } else {
                standardErrors[biasIndex] = Math.sqrt(generalVariance * inverse.get(0, 0));
            }
            tolerances[biasIndex] = Double.NaN;
            standardizedCoefficients[biasIndex] = Double.NaN;
            storeSignificance(result.coefficients[biasIndex], standardErrors[biasIndex], fDistribution, tStatistics, pValues, biasIndex);

            if (this.weightOutput.isConnected()) {
                AttributeWeights attributeWeights = new AttributeWeights(exampleSet);
                int slot = 0;
                for (int i = 0; i < attributeNames.length; i++) {
                    // Use the mask returned by the selection method: result.coefficients is aligned
                    // with result.isUsedAttribute, not necessarily with the pre-selection mask.
                    if (result.isUsedAttribute[i]) {
                        attributeWeights.setWeight(attributeNames[i], result.coefficients[slot]);
                        slot++;
                    } else {
                        attributeWeights.setWeight(attributeNames[i], 0.0d);
                    }
                }
                this.weightOutput.deliver(attributeWeights);
            }
            return new LinearRegressionModel(exampleSet, result.isUsedAttribute, result.coefficients, standardErrors, standardizedCoefficients, tolerances, tStatistics, pValues, useBias, negativeClass, positiveClass);
        } catch (IllegalAccessException e) {
            throw new UserError(this, 904, PARAMETER_FEATURE_SELECTION, e.getMessage());
        } catch (InstantiationException e) {
            throw new UserError(this, 904, PARAMETER_FEATURE_SELECTION, e.getMessage());
        } finally {
            if (restorePending) {
                // Learning was aborted before the regular restore point: undo the label swap so the
                // caller's example set is not left with the temporary attribute as its label.
                restoreOriginalLabel(exampleSet, workingLabel, originalLabel);
            }
        }
    }

    /** Removes the temporary numeric label from the example set and re-installs the original label. */
    private static void restoreOriginalLabel(ExampleSet exampleSet, Attribute temporaryLabel, Attribute originalLabel) {
        exampleSet.getAttributes().remove(temporaryLabel);
        exampleSet.getExampleTable().removeAttribute(temporaryLabel);
        exampleSet.getAttributes().setLabel(originalLabel);
    }

    /**
     * Writes the t-statistic and p-value of one coefficient into {@code tStatistics[index]} and
     * {@code pValues[index]}; a zero standard error yields (0, 1) for a zero coefficient and
     * (+infinity, 0) otherwise.
     */
    private static void storeSignificance(double coefficient, double standardError, FDistribution fDistribution, double[] tStatistics, double[] pValues, int index) {
        if (!com.rapidminer.tools.Tools.isZero(standardError)) {
            tStatistics[index] = coefficient / standardError;
            double cumulative = fDistribution.getProbabilityForValue(tStatistics[index] * tStatistics[index]);
            pValues[index] = cumulative < 0.0d ? 1.0d : Math.max(0.0d, 1.0d - cumulative);
        } else if (com.rapidminer.tools.Tools.isZero(coefficient)) {
            tStatistics[index] = 0.0d;
            pValues[index] = 1.0d;
        } else {
            tStatistics[index] = Double.POSITIVE_INFINITY;
            pValues[index] = 0.0d;
        }
    }

    /**
     * Computes {@code 1 - r^2} for the attribute at position {@code i}, where {@code r} is the
     * correlation between that attribute's values and the predictions obtained from
     * {@link #performRegression(ExampleSet, Attribute[], Attribute, double)} fitted on all other
     * attributes flagged in {@code zArr}.
     *
     * @param exampleSet the data to evaluate on
     * @param zArr flags, in attribute iteration order, marking the attributes currently in use
     * @param i position of the attribute under test within the attribute iteration order
     * @param d ridge value passed on to the regression
     * @param z whether the last regression coefficient (the intercept) is added to each prediction
     * @return one minus the squared correlation between observed and predicted values
     * @throws UndefinedParameterError if thrown by the underlying regression
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double getTolerance(ExampleSet exampleSet, boolean[] zArr, int i, double d, boolean z) throws UndefinedParameterError {
        List<Attribute> otherAttributes = new LinkedList<Attribute>();
        Attribute testedAttribute = null;
        int position = 0;
        for (Attribute candidate : exampleSet.getAttributes()) {
            if (zArr[position]) {
                if (position == i) {
                    testedAttribute = candidate;
                } else {
                    otherAttributes.add(candidate);
                }
            }
            position++;
        }
        Attribute[] regressors = otherAttributes.toArray(new Attribute[otherAttributes.size()]);
        double[] coefficients = performRegression(exampleSet, regressors, testedAttribute, d);

        double[] observed = new double[exampleSet.size()];
        double[] predicted = new double[exampleSet.size()];
        int row = 0;
        for (Example example : exampleSet) {
            observed[row] = example.getValue(testedAttribute);
            double estimate = 0.0d;
            for (int k = 0; k < regressors.length; k++) {
                estimate += coefficients[k] * example.getValue(regressors[k]);
            }
            if (z) {
                // The intercept is stored in the last coefficient slot.
                estimate += coefficients[coefficients.length - 1];
            }
            predicted[row] = estimate;
            row++;
        }
        double r = MathFunctions.correlation(observed, predicted);
        return 1.0d - (r * r);
    }

    /**
     * Sums the squared differences between the regression prediction and the label over all
     * examples.
     *
     * @param exampleSet the examples to evaluate
     * @param zArr flags, in attribute iteration order, marking the attributes used by the model
     * @param dArr coefficients of the used attributes, followed by the intercept
     * @param z whether the intercept is added to each prediction
     * @return the sum of squared residuals
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double getSquaredError(ExampleSet exampleSet, boolean[] zArr, double[] dArr, boolean z) {
        double sumOfSquares = 0.0d;
        Iterator<Example> examples = exampleSet.iterator();
        while (examples.hasNext()) {
            Example current = examples.next();
            double residual = regressionPrediction(current, zArr, dArr, z) - current.getLabel();
            sumOfSquares += residual * residual;
        }
        return sumOfSquares;
    }

    /**
     * Computes the correlation between the label values and the regression predictions.
     *
     * @param exampleSet the examples to evaluate
     * @param zArr flags, in attribute iteration order, marking the attributes used by the model
     * @param dArr coefficients of the used attributes, followed by the intercept
     * @param z whether the intercept is added to each prediction
     * @return the value of {@link MathFunctions#correlation} for labels versus predictions
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double getCorrelation(ExampleSet exampleSet, boolean[] zArr, double[] dArr, boolean z) {
        int exampleCount = exampleSet.size();
        double[] labels = new double[exampleCount];
        double[] predictions = new double[exampleCount];
        int row = 0;
        Iterator<Example> examples = exampleSet.iterator();
        while (examples.hasNext()) {
            Example current = examples.next();
            labels[row] = current.getLabel();
            predictions[row] = regressionPrediction(current, zArr, dArr, z);
            row++;
        }
        return MathFunctions.correlation(labels, predictions);
    }

    /**
     * Evaluates the linear model for one example: the sum of coefficient times value over all
     * used attributes, plus the intercept if requested.
     *
     * @param example the example to predict
     * @param zArr flags, in attribute iteration order, marking the attributes used by the model
     * @param dArr coefficients of the used attributes in iteration order; the intercept is read
     *        from the slot directly after the last used coefficient
     * @param z whether the intercept is added
     * @return the predicted value
     */
    private double regressionPrediction(Example example, boolean[] zArr, double[] dArr, boolean z) {
        double prediction = 0.0d;
        int coefficientSlot = 0;
        int attributePosition = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (zArr[attributePosition]) {
                prediction += dArr[coefficientSlot] * example.getValue(attribute);
                coefficientSlot++;
            }
            attributePosition++;
        }
        return z ? prediction + dArr[coefficientSlot] : prediction;
    }

    /**
     * Fits the label against the attributes flagged in {@code zArr}, using attribute values
     * centered by the supplied means, and derives the intercept from the label mean.
     *
     * @param exampleSet the training examples; the weight attribute, if present, supplies the
     *        per-example weights (otherwise every example has weight 1)
     * @param zArr flags, in attribute iteration order, marking the attributes to use
     * @param dArr mean of every attribute, indexed like {@code zArr}
     * @param d mean of the label, used as the starting value of the intercept
     * @param d2 ridge value passed on to the underlying solver
     * @return the coefficients of the used attributes in iteration order, followed by the
     *         intercept in the last slot
     * @throws UndefinedParameterError declared by the original signature
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double[] performRegression(ExampleSet exampleSet, boolean[] zArr, double[] dArr, double d, double d2) throws UndefinedParameterError {
        int selectedCount = 0;
        for (int k = 0; k < zArr.length; k++) {
            if (zArr[k]) {
                selectedCount++;
            }
        }

        double[] coefficients = new double[selectedCount + 1];
        if (selectedCount > 0) {
            Matrix independent = new Matrix(exampleSet.size(), selectedCount);
            Matrix dependent = new Matrix(exampleSet.size(), 1);
            double[] exampleWeights = new double[exampleSet.size()];
            Attribute weightAttribute = exampleSet.getAttributes().getWeight();
            int row = 0;
            for (Example example : exampleSet) {
                dependent.set(row, 0, example.getLabel());
                int column = 0;
                int attributePosition = 0;
                for (Attribute attribute : exampleSet.getAttributes()) {
                    if (zArr[attributePosition]) {
                        // Center each used attribute by its mean.
                        independent.set(row, column, example.getValue(attribute) - dArr[attributePosition]);
                        column++;
                    }
                    attributePosition++;
                }
                exampleWeights[row] = (weightAttribute == null) ? 1.0d : example.getValue(weightAttribute);
                row++;
            }
            System.arraycopy(com.rapidminer.tools.math.LinearRegression.performRegression(independent, dependent, exampleWeights, d2), 0, coefficients, 0, selectedCount);
        }

        // Intercept = label mean minus the sum of coefficient times attribute mean.
        double intercept = d;
        int slot = 0;
        for (int k = 0; k < zArr.length; k++) {
            if (zArr[k]) {
                intercept = intercept - (coefficients[slot] * dArr[k]);
                slot++;
            }
        }
        coefficients[selectedCount] = intercept;
        return coefficients;
    }

    /**
     * Fits {@code attribute} against the attributes in {@code attributeArr}, using values
     * centered by each regressor's {@code "average"} statistic.
     *
     * <p>The dependent vector is filled from the values of {@code attribute}. The previous code
     * read {@code example.getLabel()} instead, which disagreed with the intercept (derived from
     * the mean of {@code attribute}) and with {@code getTolerance}, which compares the predictions
     * against the values of {@code attribute}. Callers that pass the label attribute see no
     * change.
     *
     * @param exampleSet the training examples; the weight attribute, if present, supplies the
     *        per-example weights (otherwise every example has weight 1)
     * @param attributeArr the regressors
     * @param attribute the attribute to be explained by the regressors
     * @param d ridge value passed on to the underlying solver
     * @return one coefficient per regressor, followed by the intercept in the last slot
     * @throws UndefinedParameterError declared by the original signature
     */
    double[] performRegression(ExampleSet exampleSet, Attribute[] attributeArr, Attribute attribute, double d) throws UndefinedParameterError {
        int regressorCount = attributeArr.length;

        // The averages do not change per example, so look them up once.
        double[] regressorMeans = new double[regressorCount];
        for (int column = 0; column < regressorCount; column++) {
            regressorMeans[column] = exampleSet.getStatistics(attributeArr[column], "average");
        }

        double[] coefficients = new double[regressorCount + 1];
        if (regressorCount > 0) {
            Matrix independent = new Matrix(exampleSet.size(), regressorCount);
            Matrix dependent = new Matrix(exampleSet.size(), 1);
            double[] exampleWeights = new double[exampleSet.size()];
            Attribute weightAttribute = exampleSet.getAttributes().getWeight();
            int row = 0;
            for (Example example : exampleSet) {
                dependent.set(row, 0, example.getValue(attribute));
                for (int column = 0; column < regressorCount; column++) {
                    independent.set(row, column, example.getValue(attributeArr[column]) - regressorMeans[column]);
                }
                exampleWeights[row] = (weightAttribute != null) ? example.getValue(weightAttribute) : 1.0d;
                row++;
            }
            System.arraycopy(com.rapidminer.tools.math.LinearRegression.performRegression(independent, dependent, exampleWeights, d), 0, coefficients, 0, regressorCount);
        }

        // Intercept = mean of the target minus the sum of coefficient times regressor mean.
        double intercept = exampleSet.getStatistics(attribute, "average");
        for (int column = 0; column < regressorCount; column++) {
            intercept = intercept - (coefficients[column] * regressorMeans[column]);
        }
        coefficients[regressorCount] = intercept;
        return coefficients;
    }

    /**
     * Reports the model class of this learner.
     *
     * @return {@code LinearRegressionModel.class}, the type instantiated at the end of {@code learn}
     */
    @Override // com.rapidminer.operator.learner.AbstractLearner
    public Class<? extends PredictionModel> getModelClass() {
        return LinearRegressionModel.class;
    }

    /**
     * Tells whether this learner supports the given capability. Supported are numerical
     * attributes, numerical labels, binominal labels and weighted examples.
     *
     * @param operatorCapability the capability to check
     * @return {@code true} for one of the four supported capabilities, {@code false} otherwise
     */
    @Override // com.rapidminer.operator.learner.CapabilityProvider
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        if (operatorCapability.equals(OperatorCapability.NUMERICAL_ATTRIBUTES)) {
            return true;
        }
        if (operatorCapability.equals(OperatorCapability.NUMERICAL_LABEL)) {
            return true;
        }
        if (operatorCapability.equals(OperatorCapability.BINOMINAL_LABEL)) {
            return true;
        }
        return operatorCapability == OperatorCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Builds the parameter list of this operator: the inherited parameters, the feature
     * selection category (default index 1), the parameters of every registered selection method
     * (each shown only while its method is selected), and the colinearity, tolerance, bias and
     * ridge parameters.
     *
     * @return the inherited parameter list extended by this operator's parameters
     */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        String[] methodNames = SELECTION_METHODS.keySet().toArray(new String[SELECTION_METHODS.size()]);
        types.add(new ParameterTypeCategory(PARAMETER_FEATURE_SELECTION, "The feature selection method used during regression.", methodNames, 1));

        // Iteration order of the map equals the order of methodNames, so methodIndex is the
        // category index under which the parameters of the current method become visible.
        int methodIndex = 0;
        for (Class<? extends LinearRegressionMethod> methodClass : SELECTION_METHODS.values()) {
            try {
                for (ParameterType methodParameter : methodClass.newInstance().getParameterTypes()) {
                    types.add(methodParameter);
                    methodParameter.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_FEATURE_SELECTION, methodNames, true, methodIndex));
                }
            } catch (IllegalAccessException ignored) {
                // A method that cannot be instantiated contributes no parameters.
            } catch (InstantiationException ignored) {
                // A method that cannot be instantiated contributes no parameters.
            }
            methodIndex++;
        }

        types.add(new ParameterTypeBoolean(PARAMETER_ELIMINATE_COLINEAR_FEATURES, "Indicates if the algorithm should try to delete colinear features during the regression.", true));

        ParameterTypeDouble minTolerance = new ParameterTypeDouble(PARAMETER_MIN_TOLERANCE, "The minimum tolerance for the removal of colinear features.", 0.0d, 1.0d, 0.05d);
        minTolerance.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_ELIMINATE_COLINEAR_FEATURES, true, true));
        types.add(minTolerance);

        types.add(new ParameterTypeBoolean("use_bias", "Indicates if an intercept value should be calculated.", true));
        types.add(new ParameterTypeDouble("ridge", "The ridge parameter used for ridge regression. A value of zero switches to ordinary least squares estimate.", 0.0d, Double.POSITIVE_INFINITY, 1.0E-8d));
        return types;
    }

    /**
     * Obtains the resource consumption estimator for this operator by delegating to
     * {@link OperatorResourceConsumptionHandler}, passing the example set input port,
     * {@code LinearRegression.class} and {@code null} as the third argument.
     *
     * @return whatever the handler returns for this operator class
     */
    @Override // com.rapidminer.operator.Operator, com.rapidminer.operator.annotation.ResourceConsumer
    public ResourceConsumptionEstimator getResourceConsumptionEstimator() {
        return OperatorResourceConsumptionHandler.getResourceConsumptionEstimator(getExampleSetInputPort(), LinearRegression.class, null);
    }

    // Registers the available feature selection methods. The insertion order is significant:
    // the feature_selection parameter is resolved by position in the key order of this map, and
    // getParameterTypes() uses index 1 ("M5 prime") as the default.
    static {
        SELECTION_METHODS.put("none", PlainLinearRegressionMethod.class);
        SELECTION_METHODS.put("M5 prime", M5PLinearRegressionMethod.class);
        SELECTION_METHODS.put("greedy", GreedyLinearRegressionMethod.class);
        SELECTION_METHODS.put("T-Test", TTestLinearRegressionMethod.class);
        SELECTION_METHODS.put("Iterative T-Test", IterativeTTestLinearRegressionMethod.class);
    }
}
