package com.rapidminer.operator.visualization;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.CapabilityProvider;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.InputPortExtender;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPortExtender;
import com.rapidminer.operator.ports.metadata.CapabilityPrecondition;
import com.rapidminer.operator.ports.metadata.GenerateNewMDRule;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.operator.ports.metadata.PredictionModelMetaData;
import com.rapidminer.operator.ports.metadata.SimplePrecondition;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.math.ROCBias;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/visualization/ROCBasedComparisonOperator.class */
public class ROCBasedComparisonOperator extends OperatorChain implements CapabilityProvider {
    public static final String PARAMETER_NUMBER_OF_FOLDS = "number_of_folds";
    public static final String PARAMETER_SPLIT_RATIO = "split_ratio";
    public static final String PARAMETER_SAMPLING_TYPE = "sampling_type";
    public static final String PARAMETER_USE_EXAMPLE_WEIGHTS = "use_example_weights";
    private final InputPort exampleSetInput;
    private final OutputPort exampleSetOutput;
    private final OutputPort rocComparisonOutput;
    private final OutputPortExtender trainingSetExtender;
    private final InputPortExtender modelExtender;

    public ROCBasedComparisonOperator(OperatorDescription operatorDescription) {
        super(operatorDescription, "Model Generation");
        this.exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class);
        this.exampleSetOutput = getOutputPorts().createPort("exampleSet");
        this.rocComparisonOutput = getOutputPorts().createPort("rocComparison");
        this.trainingSetExtender = new OutputPortExtender("train", getSubprocess(0).getInnerSources());
        this.modelExtender = new InputPortExtender("model", getSubprocess(0).getInnerSinks()) { // from class: com.rapidminer.operator.visualization.ROCBasedComparisonOperator.1
            @Override // com.rapidminer.operator.ports.InputPortExtender
            public Precondition makePrecondition(InputPort inputPort) {
                return new SimplePrecondition(inputPort, new PredictionModelMetaData(PredictionModel.class), false);
            }
        };
        this.trainingSetExtender.start();
        this.modelExtender.start();
        this.exampleSetInput.addPrecondition(new CapabilityPrecondition(this, this.exampleSetInput));
        getTransformer().addRule(this.trainingSetExtender.makePassThroughRule(this.exampleSetInput));
        getTransformer().addPassThroughRule(this.exampleSetInput, this.exampleSetOutput);
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addRule(new GenerateNewMDRule(this.rocComparisonOutput, (Class<? extends IOObject>) ROCComparison.class));
    }

    @Override // com.rapidminer.operator.OperatorChain, com.rapidminer.operator.Operator
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) this.exampleSetInput.getData(ExampleSet.class);
        if (exampleSet.getAttributes().getLabel() == null) {
            throw new UserError(this, 105);
        }
        if (!exampleSet.getAttributes().getLabel().isNominal()) {
            throw new UserError(this, 101, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
        if (exampleSet.getAttributes().getLabel().getMapping().getValues().size() != 2) {
            throw new UserError(this, 114, "ROC Comparison", exampleSet.getAttributes().getLabel());
        }
        HashMap hashMap = new HashMap();
        int parameterAsInt = getParameterAsInt(PARAMETER_NUMBER_OF_FOLDS);
        ExampleSet exampleSet2 = (ExampleSet) exampleSet.clone();
        if (parameterAsInt < 0) {
            SplittedExampleSet splittedExampleSet = new SplittedExampleSet(exampleSet2, getParameterAsDouble("split_ratio"), getParameterAsInt("sampling_type"), getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED), getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED), getCompatibilityLevel().isAtMost(SplittedExampleSet.VERSION_SAMPLING_CHANGED));
            splittedExampleSet.selectSingleSubset(0);
            this.trainingSetExtender.deliverToAll(splittedExampleSet, false);
            getSubprocess(0).execute();
            List<Model> data = this.modelExtender.getData(Model.class, true);
            splittedExampleSet.selectSingleSubset(1);
            for (Model model : data) {
                ExampleSet apply = model.apply(splittedExampleSet);
                if (apply.getAttributes().getPredictedLabel() == null) {
                    throw new UserError(this, 107);
                }
                ROCData createROCData = new ROCDataGenerator(1.0d, 1.0d).createROCData(apply, getParameterAsBoolean("use_example_weights"), ROCBias.getROCBiasParameter(this));
                LinkedList linkedList = new LinkedList();
                linkedList.add(createROCData);
                hashMap.put(model.getSource(), linkedList);
                PredictionModel.removePredictedLabel(apply);
            }
        } else {
            SplittedExampleSet splittedExampleSet2 = new SplittedExampleSet(exampleSet2, parameterAsInt, getParameterAsInt("sampling_type"), getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED), getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED), getCompatibilityLevel().isAtMost(SplittedExampleSet.VERSION_SAMPLING_CHANGED));
            PredictionModel.removePredictedLabel(splittedExampleSet2);
            for (int i = 0; i < parameterAsInt; i++) {
                splittedExampleSet2.selectAllSubsetsBut(i);
                this.trainingSetExtender.deliverToAll(splittedExampleSet2, false);
                getSubprocess(0).execute();
                for (Model model2 : this.modelExtender.getData(Model.class, true)) {
                    splittedExampleSet2.selectSingleSubset(i);
                    ExampleSet apply2 = model2.apply(splittedExampleSet2);
                    if (apply2.getAttributes().getPredictedLabel() == null) {
                        throw new UserError(this, 107);
                    }
                    ROCData createROCData2 = new ROCDataGenerator(1.0d, 1.0d).createROCData(apply2, getParameterAsBoolean("use_example_weights"), ROCBias.getROCBiasParameter(this));
                    List list = (List) hashMap.get(model2.getSource());
                    if (list == null) {
                        list = new LinkedList();
                        hashMap.put(model2.getSource(), list);
                    }
                    list.add(createROCData2);
                    PredictionModel.removePredictedLabel(apply2);
                }
                inApplyLoop();
            }
        }
        this.exampleSetOutput.deliver(exampleSet);
        this.rocComparisonOutput.deliver(new ROCComparison(hashMap));
    }

    @Override // com.rapidminer.operator.Operator, com.rapidminer.parameter.ParameterHandler
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeInt parameterTypeInt = new ParameterTypeInt(PARAMETER_NUMBER_OF_FOLDS, "The number of folds used for a cross validation evaluation (-1: use simple split ratio).", -1, Integer.MAX_VALUE, 10);
        parameterTypeInt.setExpert(false);
        parameterTypes.add(parameterTypeInt);
        parameterTypes.add(new ParameterTypeDouble("split_ratio", "Relative size of the training set", 0.0d, 1.0d, 0.7d));
        parameterTypes.add(new ParameterTypeCategory("sampling_type", "Defines the sampling type of the cross validation (linear = consecutive subsets, shuffled = random subsets, stratified = random subsets with class distribution kept constant)", SplittedExampleSet.SAMPLING_NAMES, 2));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        parameterTypes.add(new ParameterTypeBoolean("use_example_weights", "Indicates if example weights should be regarded (use weight 1 for each example otherwise).", true));
        parameterTypes.add(ROCBias.makeParameterType());
        return parameterTypes;
    }

    @Override // com.rapidminer.operator.learner.CapabilityProvider
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        switch (operatorCapability) {
            case NO_LABEL:
                return false;
            case NUMERICAL_LABEL:
                try {
                    return getParameterAsInt("sampling_type") != 2;
                } catch (UndefinedParameterError e) {
                    return false;
                }
            default:
                return true;
        }
    }

    @Override // com.rapidminer.operator.Operator
    public OperatorVersion[] getIncompatibleVersionChanges() {
        return new OperatorVersion[]{SplittedExampleSet.VERSION_SAMPLING_CHANGED};
    }
}
