package com.rapidminer.operator.validation;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.learner.CapabilityProvider;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.performance.PerformanceCriterion;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.PortPairExtender;
import com.rapidminer.operator.ports.metadata.CapabilityPrecondition;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.MDInteger;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.AverageVector;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/validation/ValidationChain.class */
/**
 * Base class for validation operators built from two subprocesses: "Training" (learns a model
 * from the data delivered on its "training" source) and "Testing" (receives the learned model and
 * a "test set" and produces averagable results such as performance vectors).
 *
 * <p>Subclasses implement {@link #estimatePerformance(ExampleSet)} (typically by calling
 * {@link #learn(ExampleSet)} and {@link #evaluate(ExampleSet)}) and describe the expected training
 * and test set sizes for meta data propagation. After each execution the first
 * {@link PerformanceVector} among the averagable results is published as loggable values
 * ("performance", "variance", "deviation", "performance1".."performance3").
 */
public abstract class ValidationChain extends OperatorChain implements CapabilityProvider {

    /** Name of the deprecated parameter that requests an additional model on the complete data. */
    public static final String PARAMETER_CREATE_COMPLETE_MODEL = "create_complete_model";

    /** Outer input port receiving the example set to validate on. */
    protected final InputPort trainingSetInput;

    /** Inner source of the training subprocess that receives the training data. */
    protected final OutputPort trainingProcessExampleSetOutput;

    /** Inner sink of the training subprocess from which the learned model is read. */
    private final InputPort trainingProcessModelInput;

    /** Forwards additional objects from the training subprocess to the testing subprocess. */
    private final PortPairExtender throughExtender;

    /** Inner source of the testing subprocess that receives the learned model. */
    private final OutputPort applyProcessModelOutput;

    /** Inner source of the testing subprocess that receives the test data. */
    private final OutputPort applyProcessExampleSetOutput;

    /** Averagable results of the testing subprocess, paired with outer output ports. */
    private final PortPairExtender applyProcessPerformancePortExtender;

    /** Outer output delivering the model produced by {@link #learnFinalModel(ExampleSet)}. */
    private final OutputPort modelOutput;

    /** Outer output that passes the input example set through. */
    private final OutputPort exampleSetOutput;

    // Values exposed for logging via addValue(...); NaN means "no result available".
    private double lastMainPerformance = Double.NaN;
    private double lastMainVariance = Double.NaN;
    private double lastMainDeviation = Double.NaN;
    private double lastFirstPerformance = Double.NaN;
    private double lastSecondPerformance = Double.NaN;
    private double lastThirdPerformance = Double.NaN;

    /**
     * Creates the chain with its "Training" and "Testing" subprocesses, wires all ports, registers
     * the meta data transformation rules and publishes the loggable performance values.
     *
     * @param operatorDescription the description of this operator
     */
    public ValidationChain(OperatorDescription operatorDescription) {
        super(operatorDescription, "Training", "Testing");
        this.trainingSetInput = getInputPorts().createPort("training", ExampleSet.class);
        this.trainingProcessExampleSetOutput = getSubprocess(0).getInnerSources().createPort("training");
        this.trainingProcessModelInput = getSubprocess(0).getInnerSinks().createPort("model", Model.class);
        this.throughExtender = new PortPairExtender("through", getSubprocess(0).getInnerSinks(), getSubprocess(1).getInnerSources());
        this.applyProcessModelOutput = getSubprocess(1).getInnerSources().createPort("model");
        this.applyProcessExampleSetOutput = getSubprocess(1).getInnerSources().createPort("test set");
        this.applyProcessPerformancePortExtender = new PortPairExtender("averagable", getSubprocess(1).getInnerSinks(), getOutputPorts(), new MetaData(AverageVector.class));
        this.modelOutput = getOutputPorts().createPort("model");
        this.exampleSetOutput = getOutputPorts().createPort("training");

        this.throughExtender.start();
        // NOTE(review): overridable method called from the constructor; subclass overrides run
        // before the subclass's own fields are initialised.
        this.trainingSetInput.addPrecondition(getCapabilityPrecondition());

        // The first averagable port is expected to carry a performance vector.
        this.applyProcessPerformancePortExtender.ensureMinimumNumberOfPorts(1);
        InputPort firstAveragableInput = this.applyProcessPerformancePortExtender.getManagedPairs().iterator().next().getInputPort();
        firstAveragableInput.addPrecondition(new SimplePrecondition(firstAveragableInput, new MetaData(PerformanceVector.class)));
        this.applyProcessPerformancePortExtender.start();

        // Training data meta data: same attributes, example count adjusted by the subclass.
        getTransformer().addRule(new ExampleSetPassThroughRule(this.trainingSetInput, this.trainingProcessExampleSetOutput, SetRelation.EQUAL) {
            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData exampleSetMetaData) throws UndefinedParameterError {
                try {
                    exampleSetMetaData.setNumberOfExamples(ValidationChain.this.getTrainingSetSize(exampleSetMetaData.getNumberOfExamples()));
                } catch (UndefinedParameterError ignored) {
                    // Best effort: without defined parameters the input's example count is kept.
                }
                return super.modifyExampleSet(exampleSetMetaData);
            }
        });
        // Test data meta data: same attributes, example count adjusted by the subclass.
        getTransformer().addRule(new ExampleSetPassThroughRule(this.trainingSetInput, this.applyProcessExampleSetOutput, SetRelation.EQUAL) {
            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData exampleSetMetaData) throws UndefinedParameterError {
                try {
                    exampleSetMetaData.setNumberOfExamples(ValidationChain.this.getTestSetSize(exampleSetMetaData.getNumberOfExamples()));
                } catch (UndefinedParameterError ignored) {
                    // Best effort: without defined parameters the input's example count is kept.
                }
                return super.modifyExampleSet(exampleSetMetaData);
            }
        });
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addRule(new PassThroughRule(this.trainingProcessModelInput, this.applyProcessModelOutput, false));
        getTransformer().addRule(new PassThroughRule(this.trainingProcessModelInput, this.modelOutput, false));
        getTransformer().addRule(this.throughExtender.makePassThroughRule());
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(1)));
        getTransformer().addRule(this.applyProcessPerformancePortExtender.makePassThroughRule());
        getTransformer().addPassThroughRule(this.trainingSetInput, this.exampleSetOutput);

        addValue(new ValueDouble("performance", "The last performance average (main criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastMainPerformance;
            }
        });
        addValue(new ValueDouble("variance", "The variance of the last performance (main criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastMainVariance;
            }
        });
        addValue(new ValueDouble("deviation", "The standard deviation of the last performance (main criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastMainDeviation;
            }
        });
        addValue(new ValueDouble("performance1", "The last performance average (first criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastFirstPerformance;
            }
        });
        addValue(new ValueDouble("performance2", "The last performance average (second criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastSecondPerformance;
            }
        });
        addValue(new ValueDouble("performance3", "The last performance average (third criterion).") {
            @Override
            public double getDoubleValue() {
                return ValidationChain.this.lastThirdPerformance;
            }
        });
    }

    /**
     * Returns the precondition attached to the training input port. It checks the example set
     * against this operator's capabilities.
     *
     * @return a capability precondition on {@link #trainingSetInput}
     */
    protected Precondition getCapabilityPrecondition() {
        return new CapabilityPrecondition(this, this.trainingSetInput);
    }

    /**
     * Computes the expected number of examples delivered to the training subprocess.
     *
     * @param mDInteger the number of examples of the input example set (meta data)
     * @return the expected training set size
     * @throws UndefinedParameterError if a required parameter is not set
     */
    protected abstract MDInteger getTrainingSetSize(MDInteger mDInteger) throws UndefinedParameterError;

    /**
     * Computes the expected number of examples delivered to the testing subprocess.
     *
     * @param mDInteger the number of examples of the input example set (meta data)
     * @return the expected test set size
     * @throws UndefinedParameterError if a required parameter is not set
     */
    protected abstract MDInteger getTestSetSize(MDInteger mDInteger) throws UndefinedParameterError;

    /**
     * Decides whether an output port should be connected automatically. The model output depends
     * on {@link #PARAMETER_CREATE_COMPLETE_MODEL} and the example set output on
     * "keep_example_set". Other ports are handled by the superclass.
     */
    @Override
    public boolean shouldAutoConnect(OutputPort outputPort) {
        if (outputPort == this.modelOutput) {
            return getParameterAsBoolean(PARAMETER_CREATE_COMPLETE_MODEL);
        }
        if (outputPort == this.exampleSetOutput) {
            // NOTE(review): "keep_example_set" is not declared in this class's parameter types;
            // presumably provided elsewhere in the hierarchy — TODO confirm.
            return getParameterAsBoolean("keep_example_set");
        }
        return super.shouldAutoConnect(outputPort);
    }

    /**
     * Estimates the performance on the given example set. Called by {@link #doWork()} before the
     * optional final model is learned.
     *
     * @param exampleSet the example set received on the training input port
     * @throws OperatorException if the estimation fails
     */
    public abstract void estimatePerformance(ExampleSet exampleSet) throws OperatorException;

    /** Executes the training subprocess. */
    protected void executeLearner() throws OperatorException {
        getSubprocess(0).execute();
    }

    /** Executes the testing subprocess. */
    protected void executeEvaluator() throws OperatorException {
        getSubprocess(1).execute();
    }

    /**
     * Stores the values of the given performance vector for logging. All values are first reset to
     * NaN, so passing {@code null} (or a vector lacking some criteria) leaves the missing values NaN.
     *
     * @param performanceVector the performance to publish, may be {@code null}
     */
    private void setResult(PerformanceVector performanceVector) {
        this.lastMainPerformance = Double.NaN;
        this.lastMainVariance = Double.NaN;
        this.lastMainDeviation = Double.NaN;
        this.lastFirstPerformance = Double.NaN;
        this.lastSecondPerformance = Double.NaN;
        this.lastThirdPerformance = Double.NaN;
        if (performanceVector == null) {
            return;
        }

        // Fall back to the first criterion when no main criterion is defined.
        PerformanceCriterion mainCriterion = performanceVector.getMainCriterion();
        if (mainCriterion == null && performanceVector.size() > 0) {
            mainCriterion = performanceVector.getCriterion(0);
        }
        if (mainCriterion != null) {
            this.lastMainPerformance = mainCriterion.getAverage();
            this.lastMainVariance = mainCriterion.getVariance();
            this.lastMainDeviation = mainCriterion.getStandardDeviation();
        }

        this.lastFirstPerformance = averageOfCriterion(performanceVector, 0);
        this.lastSecondPerformance = averageOfCriterion(performanceVector, 1);
        this.lastThirdPerformance = averageOfCriterion(performanceVector, 2);
    }

    /**
     * Returns the average of the criterion at {@code index}.
     *
     * @return the criterion's average, or NaN if the vector has no criterion at that index
     */
    private static double averageOfCriterion(PerformanceVector performanceVector, int index) {
        if (performanceVector.size() <= index) {
            return Double.NaN;
        }
        PerformanceCriterion criterion = performanceVector.getCriterion(index);
        return criterion == null ? Double.NaN : criterion.getAverage();
    }

    /**
     * Runs the validation. Performance is estimated first; if the model output is connected a
     * final model is learned and delivered. The input example set is then passed through, and the
     * first performance vector among the averagable results is published as loggable values.
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.trainingSetInput.getData(ExampleSet.class);
        estimatePerformance(exampleSet);
        if (this.modelOutput.isConnected()) {
            learnFinalModel(exampleSet);
            this.modelOutput.deliver(this.trainingProcessModelInput.getData(IOObject.class));
        }
        this.exampleSetOutput.deliver(exampleSet);

        PerformanceVector performance = findFirstPerformanceVector();
        // A null argument resets all loggable values to NaN, so no stale values from a
        // previous execution survive.
        setResult(performance);
        if (performance == null) {
            getLogger().warning("No performance vector found among averagable results. Performance will not be loggable.");
        }
    }

    /**
     * Looks up the first performance vector among the averagable outputs.
     *
     * @return the first {@link PerformanceVector}, or {@code null} if there is none
     */
    private PerformanceVector findFirstPerformanceVector() throws OperatorException {
        for (Object result : this.applyProcessPerformancePortExtender.getOutputData(IOObject.class)) {
            if (result instanceof PerformanceVector) {
                return (PerformanceVector) result;
            }
        }
        return null;
    }

    /**
     * Learns the model delivered on the model output. By default the model is learned on the given
     * example set.
     */
    protected void learnFinalModel(ExampleSet exampleSet) throws OperatorException {
        learn(exampleSet);
    }

    /**
     * Delivers the example set to the training subprocess and executes it.
     *
     * <p>Public although the decompiler reports it was originally protected; kept public so that
     * existing callers keep compiling.
     */
    public final void learn(ExampleSet exampleSet) throws OperatorException {
        this.trainingProcessExampleSetOutput.deliver(exampleSet);
        executeLearner();
    }

    /**
     * Applies the last learned model in the testing subprocess. The test set, the model and the
     * "through" objects are delivered and the testing subprocess is executed. The averagable
     * results are then passed to {@code Tools.buildAverages}.
     *
     * <p>If the test set has a predicted label afterwards that was not present before (or has a
     * different table index), that predicted label is removed again.
     *
     * <p>Public although the decompiler reports it was originally protected; kept public so that
     * existing callers keep compiling.
     */
    public final void evaluate(ExampleSet exampleSet) throws OperatorException {
        Attribute predictedLabelBefore = exampleSet.getAttributes().getPredictedLabel();
        this.applyProcessExampleSetOutput.deliver(exampleSet);
        this.applyProcessModelOutput.deliver(this.trainingProcessModelInput.getData(IOObject.class));
        this.throughExtender.passDataThrough();
        executeEvaluator();
        Tools.buildAverages(this.applyProcessPerformancePortExtender);

        Attribute predictedLabelAfter = exampleSet.getAttributes().getPredictedLabel();
        boolean predictionAddedByEvaluation = predictedLabelAfter != null
                && (predictedLabelBefore == null
                        || predictedLabelBefore.getTableIndex() != predictedLabelAfter.getTableIndex());
        if (predictionAddedByEvaluation) {
            PredictionModel.removePredictedLabel(exampleSet);
        }
    }

    /** Adds the deprecated {@link #PARAMETER_CREATE_COMPLETE_MODEL} flag (default {@code false}). */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeBoolean createCompleteModel = new ParameterTypeBoolean(PARAMETER_CREATE_COMPLETE_MODEL, "Indicates if a model of the complete data set should be additionally build after estimation.", false);
        createCompleteModel.setDeprecated();
        createCompleteModel.setExpert(false);
        parameterTypes.add(createCompleteModel);
        return parameterTypes;
    }
}
