package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import java.util.List;
import org.encog.neural.activation.ActivationLinear;
import org.encog.neural.activation.ActivationSigmoid;
import org.encog.neural.data.NeuralDataSet;
import org.encog.neural.data.basic.BasicNeuralDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.FeedforwardLayer;
import org.encog.neural.networks.training.backpropagation.Backpropagation;

/* loaded from: input_file:gen_lib/rapidminer.jar:com/rapidminer/operator/learner/functions/neuralnet/SimpleNeuralNetLearner.class */
/**
 * Learner that trains a feed-forward neural network with Encog's {@link Backpropagation} trainer.
 *
 * <p>The network has one input neuron per attribute returned by {@code exampleSet.getAttributes()},
 * {@value #PARAMETER_DEFAULT_NUMBER_OF_HIDDEN_LAYERS} hidden layers of equal size, and a single
 * output neuron. The output neuron uses a sigmoid activation for nominal labels and a linear
 * activation otherwise.
 *
 * <p>Attribute values and numerical label values are min-max scaled using statistics of the
 * training set. Nominal label values are used as-is (their internal numeric value). The scaling
 * bounds are passed on to the resulting {@link SimpleNeuralNetModel}.
 *
 * <p>NOTE(review): {@link #PARAMETER_DEFINE_DIFFERENT_HIDDEN_LAYERS} and
 * {@link #PARAMETER_HIDDEN_LAYER_SIZES} are declared but neither read nor offered by
 * {@link #getParameterTypes()}. Only the default hidden-layer settings take effect.
 *
 * <p>The scaling bounds of the most recent training run are stored in instance fields. A single
 * instance should therefore not learn on several example sets concurrently.
 */
public class SimpleNeuralNetLearner extends AbstractLearner {
    public static final String PARAMETER_DEFINE_DIFFERENT_HIDDEN_LAYERS = "define_different_hidden_layers";
    public static final String PARAMETER_HIDDEN_LAYER_SIZES = "hidden_layer_sizes";
    public static final String PARAMETER_DEFAULT_NUMBER_OF_HIDDEN_LAYERS = "default_number_of_hidden_layers";
    public static final String PARAMETER_DEFAULT_HIDDEN_LAYER_SIZE = "default_hidden_layer_size";
    public static final String PARAMETER_TRAINING_CYCLES = "training_cycles";
    public static final String PARAMETER_LEARNING_RATE = "learning_rate";
    public static final String PARAMETER_MOMENTUM = "momentum";
    public static final String PARAMETER_ERROR_EPSILON = "error_epsilon";

    /** Per-attribute minimum of the last training set, in attribute iteration order. */
    private double[] attributeMin;
    /** Per-attribute maximum of the last training set, in attribute iteration order. */
    private double[] attributeMax;
    /** Label minimum of the last training set (only used for scaling non-nominal labels). */
    private double labelMin;
    /** Label maximum of the last training set (only used for scaling non-nominal labels). */
    private double labelMax;

    public SimpleNeuralNetLearner(OperatorDescription operatorDescription) {
        super(operatorDescription);
    }

    /**
     * Builds, trains and wraps a neural network for the given example set.
     *
     * @param exampleSet training data; a label attribute is required (its type selects the
     *     output activation)
     * @return a {@link SimpleNeuralNetModel} holding the trained network and the scaling bounds
     * @throws OperatorException propagated from reading operator parameters
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        // The network is created and randomly initialized before the training data is prepared.
        BasicNetwork network = getNetwork(exampleSet);
        // getTraining() also fills attributeMin/attributeMax/labelMin/labelMax, which the model
        // below needs. It must therefore run before those fields are read.
        NeuralDataSet training = getTraining(exampleSet);
        BasicNetwork trainedNetwork = trainNetwork(network, training,
                getParameterAsDouble(PARAMETER_LEARNING_RATE),
                getParameterAsDouble(PARAMETER_MOMENTUM),
                getParameterAsDouble(PARAMETER_ERROR_EPSILON),
                getParameterAsInt(PARAMETER_TRAINING_CYCLES));
        return new SimpleNeuralNetModel(exampleSet, trainedNetwork,
                this.attributeMin, this.attributeMax, this.labelMin, this.labelMax);
    }

    /**
     * Creates the network topology and initializes its weights.
     *
     * <p>Hidden layers are all of size {@value #PARAMETER_DEFAULT_HIDDEN_LAYER_SIZE}. When that
     * parameter is zero or negative, {@link #getDefaultLayerSize(ExampleSet)} is used instead.
     *
     * @param exampleSet example set whose attribute count and label type define the topology
     * @return the initialized, untrained network
     * @throws OperatorException propagated from reading operator parameters
     */
    private BasicNetwork getNetwork(ExampleSet exampleSet) throws OperatorException {
        BasicNetwork basicNetwork = new BasicNetwork();
        basicNetwork.addLayer(new FeedforwardLayer(exampleSet.getAttributes().size()));
        log("No hidden layers defined. Using default hidden layers.");
        int hiddenLayerSize = getParameterAsInt(PARAMETER_DEFAULT_HIDDEN_LAYER_SIZE);
        if (hiddenLayerSize <= 0) {
            hiddenLayerSize = getDefaultLayerSize(exampleSet);
        }
        int numberOfHiddenLayers = getParameterAsInt(PARAMETER_DEFAULT_NUMBER_OF_HIDDEN_LAYERS);
        for (int layer = 0; layer < numberOfHiddenLayers; layer++) {
            basicNetwork.addLayer(new FeedforwardLayer(hiddenLayerSize));
        }
        // Single output neuron: sigmoid for nominal labels, linear for numerical labels.
        if (exampleSet.getAttributes().getLabel().isNominal()) {
            basicNetwork.addLayer(new FeedforwardLayer(new ActivationSigmoid(), 1));
        } else {
            basicNetwork.addLayer(new FeedforwardLayer(new ActivationLinear(), 1));
        }
        basicNetwork.reset(RandomGenerator.getRandomGenerator(
                getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED),
                getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED)));
        return basicNetwork;
    }

    /**
     * Returns the fallback hidden-layer size, {@code round(numberOfAttributes / 2) + 1}.
     *
     * <p>NOTE(review): the parameter description promises
     * "(number of attributes + number of classes) / 2", which this formula does not compute.
     * Confirm which is intended.
     */
    private int getDefaultLayerSize(ExampleSet exampleSet) {
        return ((int) Math.round(exampleSet.getAttributes().size() / 2.0d)) + 1;
    }

    /**
     * Converts the example set into an Encog data set and records the scaling bounds.
     *
     * <p>Each attribute value is min-max scaled with the attribute's training minimum and
     * maximum. The label is scaled the same way unless it is nominal, in which case its raw
     * numeric value is used.
     *
     * @param exampleSet training data
     * @return a data set with one input row per example and a single ideal output column
     */
    private NeuralDataSet getTraining(ExampleSet exampleSet) {
        int numberOfAttributes = exampleSet.getAttributes().size();
        double[][] inputs = new double[exampleSet.size()][numberOfAttributes];
        double[][] ideals = new double[exampleSet.size()][1];
        Attribute label = exampleSet.getAttributes().getLabel();
        this.attributeMin = new double[numberOfAttributes];
        this.attributeMax = new double[numberOfAttributes];
        exampleSet.recalculateAllAttributeStatistics();
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            this.attributeMin[index] = exampleSet.getStatistics(attribute, "minimum");
            this.attributeMax[index] = exampleSet.getStatistics(attribute, "maximum");
            index++;
        }
        this.labelMin = exampleSet.getStatistics(label, "minimum");
        this.labelMax = exampleSet.getStatistics(label, "maximum");
        int row = 0;
        for (Example example : exampleSet) {
            int column = 0;
            for (Attribute attribute : exampleSet.getAttributes()) {
                inputs[row][column] = scale(example.getValue(attribute), this.attributeMin[column], this.attributeMax[column]);
                column++;
            }
            if (label.isNominal()) {
                ideals[row][0] = example.getValue(label);
            } else {
                ideals[row][0] = scale(example.getValue(label), this.labelMin, this.labelMax);
            }
            row++;
        }
        return new BasicNeuralDataSet(inputs, ideals);
    }

    /**
     * Maps {@code [min, max]} linearly onto {@code [0, 1]}.
     *
     * <p>When {@code min == max}, division by zero is avoided and the value is only shifted by
     * {@code min}.
     */
    private static double scale(double value, double min, double max) {
        double shifted = value - min;
        return min != max ? shifted / (max - min) : shifted;
    }

    /**
     * Trains the network with backpropagation.
     *
     * <p>Runs at least one and at most {@code maxCycles} iterations. It stops early as soon as the
     * trainer's error is no longer greater than {@code maxError}.
     *
     * @param network initialized network to train
     * @param training scaled training data
     * @param learningRate backpropagation learning rate
     * @param momentum backpropagation momentum
     * @param maxError error threshold for early stopping
     * @param maxCycles maximum number of iterations (the {@value #PARAMETER_TRAINING_CYCLES}
     *     parameter)
     * @return the network held by the trainer after training
     */
    private BasicNetwork trainNetwork(BasicNetwork network, NeuralDataSet training, double learningRate, double momentum, double maxError, int maxCycles) {
        Backpropagation backpropagation = new Backpropagation(network, training, learningRate, momentum);
        // Count completed iterations from zero so that exactly maxCycles iterations may run.
        // A counter starting at 1 would stop after maxCycles - 1 iterations.
        int completedCycles = 0;
        do {
            backpropagation.iteration();
            completedCycles++;
        } while (completedCycles < maxCycles && backpropagation.getError() > maxError);
        return (BasicNetwork) backpropagation.getNetwork();
    }

    @Override
    public Class<? extends PredictionModel> getModelClass() {
        return SimpleNeuralNetModel.class;
    }

    /** Supports numerical attributes together with binominal or numerical labels. */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        return operatorCapability == OperatorCapability.NUMERICAL_ATTRIBUTES
                || operatorCapability == OperatorCapability.BINOMINAL_LABEL
                || operatorCapability == OperatorCapability.NUMERICAL_LABEL;
    }

    /**
     * Adds the network topology, training and random-seed parameters to the inherited ones.
     *
     * <p>NOTE(review): two descriptions below refer to a "hidden_layer_types" list that this
     * operator does not define.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        ParameterTypeInt numberOfHiddenLayers = new ParameterTypeInt(PARAMETER_DEFAULT_NUMBER_OF_HIDDEN_LAYERS, "The number of hidden layers. Only used if no layers are defined by the list hidden_layer_types.", 1, Integer.MAX_VALUE, 1);
        numberOfHiddenLayers.setExpert(false);
        parameterTypes.add(numberOfHiddenLayers);
        parameterTypes.add(new ParameterTypeInt(PARAMETER_DEFAULT_HIDDEN_LAYER_SIZE, "The default size  of hidden layers. Only used if no layers are defined by the list hidden_layer_types. -1 means size (number of attributes + number of classes) / 2", -1, Integer.MAX_VALUE, -1));
        ParameterTypeInt trainingCycles = new ParameterTypeInt(PARAMETER_TRAINING_CYCLES, "The number of training cycles used for the neural network training.", 1, Integer.MAX_VALUE, 500);
        trainingCycles.setExpert(false);
        parameterTypes.add(trainingCycles);
        ParameterTypeDouble learningRate = new ParameterTypeDouble(PARAMETER_LEARNING_RATE, "The learning rate determines by how much we change the weights at each step.", 0.0d, 1.0d, 0.3d);
        learningRate.setExpert(false);
        parameterTypes.add(learningRate);
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_MOMENTUM, "The momentum simply adds a fraction of the previous weight update to the current one (prevent local maxima and smoothes optimization directions).", 0.0d, 1.0d, 0.2d));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_ERROR_EPSILON, "The optimization is stopped if the training error gets below this epsilon value.", 0.0d, Double.POSITIVE_INFINITY, 0.01d));
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }
}
