package com.rapidminer.operator.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.CapabilityCheck;
import com.rapidminer.operator.learner.CapabilityProvider;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.PortPairExtender;
import com.rapidminer.operator.ports.metadata.CapabilityPrecondition;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;

/**
 * Operator chain that produces cross-validated predictions for every example of its input.
 *
 * <p>The input example set is split into {@value #PARAMETER_NUMBER_OF_VALIDATIONS} subsets (or
 * into one subset per example when {@value #PARAMETER_LEAVE_ONE_OUT} is set). In each iteration:
 * <ul>
 *   <li>the "Training" subprocess receives all subsets except the current one and must deliver
 *       an object at its "model" sink;</li>
 *   <li>the "Model Application" subprocess receives that object, everything passed through the
 *       "through" ports, and the held-out subset, and must deliver the labelled subset.</li>
 * </ul>
 * The predicted label (and, for nominal labels, the per-class confidences) of every held-out
 * example is copied into a clone of the input, which is delivered at the "labelled data" output.
 */
public class XVPrediction extends OperatorChain implements CapabilityProvider {
    public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations";
    public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";

    /**
     * Index into {@code SplittedExampleSet.SAMPLING_NAMES} used as the default sampling type and
     * treated as incompatible with numerical labels in {@link #supportsCapability}.
     * NOTE(review): presumably the stratified-sampling entry — confirm against SplittedExampleSet.
     */
    private static final int STRATIFIED_SAMPLING_INDEX = 2;

    /** Zero-based index of the fold currently being processed; exposed as the "iteration" value. */
    private int iteration;
    private final InputPort exampleSetInput;
    private final OutputPort trainingProcessExampleSource;
    private final InputPort trainingProcessModelSink;
    private final PortPairExtender throughExtender;
    private final OutputPort applyProcessModelSource;
    private final OutputPort applyProcessExampleSource;
    private final InputPort applyProcessExampleInnerSink;
    private final OutputPort exampleSetOutput;

    public XVPrediction(OperatorDescription operatorDescription) {
        super(operatorDescription, "Training", "Model Application");
        this.exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class);
        this.trainingProcessExampleSource = getSubprocess(0).getInnerSources().createPort("training");
        this.trainingProcessModelSink = getSubprocess(0).getInnerSinks().createPort("model");
        this.throughExtender = new PortPairExtender("through", getSubprocess(0).getInnerSinks(), getSubprocess(1).getInnerSources());
        this.applyProcessModelSource = getSubprocess(1).getInnerSources().createPort("model");
        this.applyProcessExampleSource = getSubprocess(1).getInnerSources().createPort("unlabelled data");
        this.applyProcessExampleInnerSink = getSubprocess(1).getInnerSinks().createPort("labelled data");
        this.exampleSetOutput = getOutputPorts().createPort("labelled data");
        this.exampleSetInput.addPrecondition(new CapabilityPrecondition(this, this.exampleSetInput));
        this.throughExtender.start();

        // Training subprocess sees the input meta data with the expected training-set size.
        getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.trainingProcessExampleSource, SetRelation.EQUAL) {
            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData exampleSetMetaData) throws UndefinedParameterError {
                try {
                    exampleSetMetaData.setNumberOfExamples(getTrainingSetSize(exampleSetMetaData.getNumberOfExamples()));
                } catch (UndefinedParameterError ignored) {
                    // Best effort during meta data propagation: if the fold parameters are not
                    // set yet, keep the incoming example count unchanged.
                }
                return super.modifyExampleSet(exampleSetMetaData);
            }
        });
        // Application subprocess sees the input meta data with the expected held-out size.
        getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.applyProcessExampleSource, SetRelation.EQUAL) {
            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData exampleSetMetaData) throws UndefinedParameterError {
                try {
                    exampleSetMetaData.setNumberOfExamples(getTestSetSize(exampleSetMetaData.getNumberOfExamples()));
                } catch (UndefinedParameterError ignored) {
                    // Best effort during meta data propagation: if the fold parameters are not
                    // set yet, keep the incoming example count unchanged.
                }
                return super.modifyExampleSet(exampleSetMetaData);
            }
        });
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addRule(new PassThroughRule(this.trainingProcessModelSink, this.applyProcessModelSource, false));
        getTransformer().addRule(this.throughExtender.makePassThroughRule());
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(1)));
        getTransformer().addPassThroughRule(this.applyProcessExampleInnerSink, this.exampleSetOutput);
        addValue(new ValueDouble("iteration", "The number of the current iteration.") {
            @Override
            public double getDoubleValue() {
                return iteration;
            }
        });
    }

    /**
     * Runs the cross-validation loop and delivers the input clone annotated with predictions.
     *
     * @throws OperatorException if a subprocess fails, or if the "Model Application" subprocess
     *     delivers a labelled set whose size differs from the held-out subset or which carries
     *     no predicted label
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.exampleSetInput.getData(ExampleSet.class);
        new CapabilityCheck(this, false).checkLearnerCapabilities(this, exampleSet);

        int numberOfFolds = getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)
                ? exampleSet.size()
                : getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS);
        log("Starting " + numberOfFolds + "-fold cross validation prediction");

        // The result is a clone of the input that receives a new predicted-label attribute.
        ExampleSet resultSet = (ExampleSet) exampleSet.clone();
        Attribute resultPrediction = PredictionModel.createPredictedLabel(resultSet, exampleSet.getAttributes().getLabel());
        // Confidences are only copied for nominal labels; the class values are taken from the
        // predicted-label attribute's mapping at creation time.
        List<String> labelValues = resultPrediction.isNominal() ? resultPrediction.getMapping().getValues() : null;

        SplittedExampleSet folds = new SplittedExampleSet(exampleSet, numberOfFolds,
                getParameterAsInt(PARAMETER_SAMPLING_TYPE),
                getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED),
                getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED),
                getCompatibilityLevel().isAtMost(SplittedExampleSet.VERSION_SAMPLING_CHANGED));

        for (this.iteration = 0; this.iteration < numberOfFolds; this.iteration++) {
            folds.selectAllSubsetsBut(this.iteration);
            this.trainingProcessExampleSource.deliver(folds);
            getSubprocess(0).execute();

            folds.selectSingleSubset(this.iteration);
            this.applyProcessExampleSource.deliver(folds);
            this.throughExtender.passDataThrough();
            this.applyProcessModelSource.deliver(this.trainingProcessModelSink.getData(IOObject.class));
            getSubprocess(1).execute();

            ExampleSet labelledFold = (ExampleSet) this.applyProcessExampleInnerSink.getData(ExampleSet.class);
            copyPredictions(labelledFold, folds, resultSet, resultPrediction, labelValues);
            inApplyLoop();
        }
        this.exampleSetOutput.deliver(resultSet);
    }

    /**
     * Copies the predictions of the labelled held-out subset into the matching rows of the result.
     *
     * <p>Rows are matched by position: the i-th example of {@code labelledFold} is assumed to
     * correspond to the i-th example of the currently selected subset of {@code folds}, whose
     * row in the original input is {@code folds.getActualParentIndex(i)}.
     * NOTE(review): a subprocess that reorders examples cannot be detected here.
     *
     * @param labelledFold the example set delivered by the "Model Application" subprocess
     * @param folds the split input with the held-out subset currently selected
     * @param resultSet the input clone that collects all predictions
     * @param resultPrediction the predicted-label attribute of {@code resultSet}
     * @param labelValues the nominal class values whose confidences are copied, or {@code null}
     *     for non-nominal labels
     * @throws OperatorException if {@code labelledFold} has a different size than the held-out
     *     subset or, when non-empty, has no predicted label
     */
    private void copyPredictions(ExampleSet labelledFold, SplittedExampleSet folds, ExampleSet resultSet,
            Attribute resultPrediction, List<String> labelValues) throws OperatorException {
        int foldSize = folds.size();
        if (labelledFold.size() != foldSize) {
            throw new OperatorException("Cross validation prediction: the 'Model Application' subprocess delivered "
                    + labelledFold.size() + " examples in iteration " + this.iteration
                    + ", but the held-out subset contains " + foldSize + " examples.");
        }
        if (foldSize > 0 && labelledFold.getAttributes().getPredictedLabel() == null) {
            throw new OperatorException("Cross validation prediction: the example set delivered by the "
                    + "'Model Application' subprocess in iteration " + this.iteration + " has no predicted label.");
        }
        for (int i = 0; i < foldSize; i++) {
            Example source = labelledFold.getExample(i);
            Example target = resultSet.getExample(folds.getActualParentIndex(i));
            // Copies the raw value; for nominal labels this assumes the inner predicted label
            // uses the same value mapping as resultPrediction — TODO confirm.
            target.setValue(resultPrediction, source.getPredictedLabel());
            if (labelValues != null) {
                for (String value : labelValues) {
                    target.setConfidence(value, source.getConfidence(value));
                }
            }
        }
    }

    /**
     * Estimates the held-out subset size for meta data propagation.
     *
     * @param mDInteger the number of input examples
     * @return 1 for leave-one-out, otherwise {@code mDInteger / number_of_validations}
     * @throws UndefinedParameterError if a required parameter is not set
     */
    protected MDInteger getTestSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        // NOTE(review): `multiply2` looks like a decompiler-renamed MDInteger method — confirm.
        return getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)
                ? new MDInteger(1)
                : mDInteger.multiply2(1.0d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS));
    }

    /**
     * Estimates the training subset size for meta data propagation.
     *
     * @param mDInteger the number of input examples
     * @return {@code mDInteger - 1} for leave-one-out, otherwise
     *     {@code mDInteger * (1 - 1 / number_of_validations)}
     * @throws UndefinedParameterError if a required parameter is not set
     */
    protected MDInteger getTrainingSetSize(MDInteger mDInteger) throws UndefinedParameterError {
        return getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)
                ? mDInteger.add((Integer) (-1))
                : mDInteger.multiply2(1.0d - (1.0d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS)));
    }

    /**
     * Adds leave-one-out, fold count (hidden when leave-one-out is enabled), sampling type and the
     * random generator parameters to the inherited parameter types.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored.", false, false));
        ParameterTypeInt numberOfValidations = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10, false);
        numberOfValidations.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false));
        parameterTypes.add(numberOfValidations);
        parameterTypes.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation.", SplittedExampleSet.SAMPLING_NAMES, STRATIFIED_SAMPLING_INDEX, false));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }

    /**
     * Reports the capabilities of this chain: a label is required, and numerical labels are
     * rejected while the sampling type is {@link #STRATIFIED_SAMPLING_INDEX} (or cannot be read).
     */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        switch (operatorCapability) {
            case NO_LABEL:
                return false;
            case NUMERICAL_LABEL:
                try {
                    return getParameterAsInt(PARAMETER_SAMPLING_TYPE) != STRATIFIED_SAMPLING_INDEX;
                } catch (UndefinedParameterError e) {
                    return false;
                }
            default:
                return true;
        }
    }

    /** Declares the version at which the splitting/sampling behaviour changed. */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        return new OperatorVersion[]{SplittedExampleSet.VERSION_SAMPLING_CHANGED};
    }
}
