package com.rapidminer.operator.validation;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.ProcessStoppedException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.CapabilityPrecondition;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.quickfix.ParameterSettingQuickFix;
import com.rapidminer.operator.ports.quickfix.QuickFix;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/validation/XValidation.class */
/**
 * Cross-validation operator built on {@link ValidationChain}.
 *
 * <p>The input example set is partitioned into a number of subsets. For every subset the operator
 * selects all other subsets and calls {@code learn(...)}, then selects the held-out subset and
 * calls {@code evaluate(...)}. The number of subsets is either the {@code number_of_validations}
 * parameter or, when {@code leave_one_out} is enabled, the number of examples.
 */
public class XValidation extends ValidationChain {
    public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_AVERAGE_PERFORMANCES_ONLY = "average_performances_only";

    // Category indices of the sampling_type parameter. The quick fix below labels index 1 as
    // shuffled sampling; index 2 is presumably "stratified", following the order given in the
    // parameter description (linear, shuffled, stratified) -- confirm against
    // SplittedExampleSet.SAMPLING_NAMES.
    private static final int SHUFFLED_SAMPLING_INDEX = 1;
    private static final int STRATIFIED_SAMPLING_INDEX = 2;

    /** Index of the fold currently being processed; exposed as the logged value "iteration". */
    private int iteration;

    /**
     * Creates the operator and registers the "iteration" value, which reports the current fold
     * index.
     *
     * @param operatorDescription description handed to the superclass constructor
     */
    public XValidation(OperatorDescription operatorDescription) {
        super(operatorDescription);
        ValueDouble iterationValue = new ValueDouble("iteration", "The number of the current iteration.") {
            @Override
            public double getDoubleValue() {
                return XValidation.this.iteration;
            }
        };
        addValue(iterationValue);
    }

    /**
     * Returns a capability precondition on the training set input whose quick-fix list for
     * "regression given, classification supported" is prefixed with a fix that sets
     * {@code sampling_type} to "1".
     */
    @Override
    protected Precondition getCapabilityPrecondition() {
        return new CapabilityPrecondition(this, this.trainingSetInput) {
            @Override
            public List<QuickFix> getFixesForRegressionWhenClassificationSupported(AttributeMetaData attributeMetaData) {
                List<QuickFix> fixes = super.getFixesForRegressionWhenClassificationSupported(attributeMetaData);
                // "1" is the sampling_type category index, passed as a string; the i18n key names it
                // shuffled sampling. The fix is put first in the list.
                QuickFix switchToShuffled = new ParameterSettingQuickFix(XValidation.this, PARAMETER_SAMPLING_TYPE, "1", "switch_to_shuffled_sampling", new Object[0]);
                fixes.add(0, switchToShuffled);
                return fixes;
            }
        };
    }

    /**
     * Runs the cross validation: partitions {@code exampleSet} and performs one iteration per
     * partition, keeping {@link #iteration} in sync with the fold being processed.
     *
     * @param exampleSet the data to validate on
     * @throws OperatorException if a parameter lookup or one of the iterations fails
     */
    @Override
    public void estimatePerformance(ExampleSet exampleSet) throws OperatorException {
        final int folds;
        if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            // Leave-one-out: one fold per example.
            folds = exampleSet.size();
        } else {
            folds = getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        }
        getLogger().fine("Starting " + folds + "-fold cross validation");

        // Parameters are read in the same order in which they are passed to the constructor.
        int samplingType = getParameterAsInt(PARAMETER_SAMPLING_TYPE);
        boolean useLocalSeed = getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED);
        int localSeed = getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED);
        boolean atMostSamplingChangeVersion = getCompatibilityLevel().isAtMost(SplittedExampleSet.VERSION_SAMPLING_CHANGED);
        SplittedExampleSet partitions = new SplittedExampleSet(exampleSet, folds, samplingType, useLocalSeed, localSeed, atMostSamplingChangeVersion);

        // The field itself is the loop counter so the "iteration" value always shows the current fold.
        for (this.iteration = 0; this.iteration < folds; this.iteration++) {
            performIteration(partitions, this.iteration);
        }
    }

    /**
     * Performs a single fold: learns on every subset except {@code i}, evaluates on subset
     * {@code i}, then calls {@code inApplyLoop()}.
     *
     * @param splittedExampleSet the partitioned example set; its subset selection is modified
     * @param i index of the subset that is held out for evaluation
     * @throws OperatorException if learning or evaluation fails
     * @throws ProcessStoppedException declared by the original signature
     */
    protected void performIteration(SplittedExampleSet splittedExampleSet, int i) throws OperatorException, ProcessStoppedException {
        // Training phase: everything but the held-out subset.
        splittedExampleSet.selectAllSubsetsBut(i);
        learn(splittedExampleSet);

        // Test phase: only the held-out subset.
        splittedExampleSet.selectSingleSubset(i);
        evaluate(splittedExampleSet);

        inApplyLoop();
    }

    /**
     * Meta-data estimate of the test set size: 1 for leave-one-out, otherwise the input size
     * scaled by {@code 1 / number_of_validations}.
     *
     * @param mDInteger meta-data size of the input example set
     * @return the estimated test set size
     * @throws UndefinedParameterError if a required parameter is not set
     */
    @Override
    protected MDInteger getTestSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            return new MDInteger(1);
        }
        double testFraction = 1.0d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS);
        return mDInteger.multiply2(testFraction);
    }

    /**
     * Meta-data estimate of the training set size: input size minus one for leave-one-out,
     * otherwise the input size scaled by {@code 1 - 1 / number_of_validations}.
     *
     * @param mDInteger meta-data size of the input example set
     * @return the estimated training set size
     * @throws UndefinedParameterError if a required parameter is not set
     */
    @Override
    protected MDInteger getTrainingSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) {
            return mDInteger.add((Integer) (-1));
        }
        double trainingFraction = 1.0d - (1.0d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS));
        return mDInteger.multiply2(trainingFraction);
    }

    /**
     * Extends the inherited parameter list with the cross-validation parameters. The number of
     * validations, the sampling type and the random generator parameters each carry a dependency
     * condition on {@code leave_one_out} being {@code false}; the random generator parameters
     * additionally depend on {@code sampling_type} being one of the categories 1 or 2.
     *
     * @return the inherited parameter types followed by the ones added here
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        types.add(new ParameterTypeBoolean(PARAMETER_AVERAGE_PERFORMANCES_ONLY, "Indicates if only performance vectors should be averaged or all types of averagable result vectors", true));
        types.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored", false, false));

        ParameterTypeInt numberOfValidations = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10);
        numberOfValidations.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        numberOfValidations.setExpert(false);
        types.add(numberOfValidations);

        ParameterTypeCategory samplingType = new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, STRATIFIED_SAMPLING_INDEX);
        samplingType.setExpert(false);
        samplingType.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        types.add(samplingType);

        for (ParameterType randomParameter : RandomGenerator.getRandomGeneratorParameters(this)) {
            randomParameter.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
            randomParameter.registerDependencyCondition(new EqualTypeCondition(this, PARAMETER_SAMPLING_TYPE, SplittedExampleSet.SAMPLING_NAMES, false, SHUFFLED_SAMPLING_INDEX, STRATIFIED_SAMPLING_INDEX));
            types.add(randomParameter);
        }
        return types;
    }

    /**
     * Reports {@link SplittedExampleSet#VERSION_SAMPLING_CHANGED} as the only version at which
     * this operator's behaviour changed incompatibly.
     */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        OperatorVersion[] changes = {SplittedExampleSet.VERSION_SAMPLING_CHANGED};
        return changes;
    }

    /**
     * Tells whether this operator supports the given capability. {@code NO_LABEL} is never
     * supported; {@code NUMERICAL_LABEL} is supported unless {@code sampling_type} is set to
     * category 2 (or cannot be read); everything else is supported.
     *
     * @param operatorCapability the capability to check
     * @return {@code true} if the capability is supported
     */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        boolean supported;
        switch (operatorCapability) {
            case NO_LABEL:
                supported = false;
                break;
            case NUMERICAL_LABEL:
                try {
                    supported = getParameterAsInt(PARAMETER_SAMPLING_TYPE) != STRATIFIED_SAMPLING_INDEX;
                } catch (UndefinedParameterError e) {
                    // Sampling type unknown: report the capability as unsupported.
                    supported = false;
                }
                break;
            default:
                supported = true;
                break;
        }
        return supported;
    }
}
