/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.DerivedFieldMetaInfo;
import weka.core.pmml.FieldMetaInfo;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.NormContinuous;
import weka.core.pmml.TargetMetaInfo;

public class NeuralNetwork
extends PMMLClassifier {
    private static final long serialVersionUID = -4545904813133921249L;

    /** Whether this network performs classification (the default) or regression. */
    protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;

    /** Network-wide activation function; used by layers that do not declare their own. */
    protected ActivationFunction m_activationFunction = ActivationFunction.ARCTAN;

    /** Network-wide normalization method; used by layers that do not declare their own. */
    protected Normalization m_normalizationMethod = Normalization.NONE;

    /** Network-wide threshold; used by layers that do not declare their own. */
    protected double m_threshold = 0.0;

    /** Network-wide width (radial basis neurons); NaN when not specified. */
    protected double m_width = Double.NaN;

    /** Network-wide altitude (radial basis neurons); used by layers that do not declare their own. */
    protected double m_altitude = 1.0;

    /** Number of NeuralInput elements found in the model. */
    protected int m_numberOfInputs = 0;

    /** Number of NeuralLayer elements found in the model. */
    protected int m_numberOfLayers = 0;

    /** The input neurons, in document order. */
    protected NeuralInput[] m_inputs = null;

    /** Input neuron ID -> value for the instance currently being predicted (cleared per prediction). */
    protected HashMap<String, Double> m_inputMap = new HashMap<String, Double>();

    /** The layers of the network, in document order. */
    protected NeuralLayer[] m_layers = null;

    /** Maps the final layer's neuron outputs to predictions. */
    protected NeuralOutputs m_outputs = null;

    /**
     * Builds the network from a PMML {@code NeuralNetwork} element.
     *
     * <p>First reads the network-wide attributes (function name, activation
     * function, normalization method, threshold, width, altitude). Then parses the
     * {@code NeuralInput}, {@code NeuralLayer} and {@code NeuralOutputs} elements.
     *
     * @param model the {@code NeuralNetwork} element
     * @param dataDictionary the data dictionary, passed to the superclass
     * @param miningSchema the mining schema, passed to the superclass
     * @throws Exception if no activation function is given, if the activation
     *           function or normalization method is not recognised, if there is
     *           not exactly one {@code NeuralOutputs} element, or if an input,
     *           layer or output cannot be parsed
     */
    public NeuralNetwork(Element model, Instances dataDictionary, MiningSchema miningSchema) throws Exception {
        super(dataDictionary, miningSchema);

        if (model.getAttribute("functionName").equals("regression")) {
            this.m_functionType = MiningFunction.REGRESSION;
        }

        String act = model.getAttribute("activationFunction");
        if (act == null || act.length() == 0) {
            throw new Exception("[NeuralNetwork] no activation functon defined");
        }
        ActivationFunction activation = null;
        for (ActivationFunction a : ActivationFunction.values()) {
            // Case-insensitive so that spec spellings such as "Gauss" are accepted.
            if (a.toString().equalsIgnoreCase(act)) {
                activation = a;
                break;
            }
        }
        if (activation == null) {
            // Fail loudly: silently keeping the ARCTAN default would produce wrong predictions.
            throw new Exception("[NeuralNetwork] unsupported activation function: " + act);
        }
        this.m_activationFunction = activation;

        String norm = model.getAttribute("normalizationMethod");
        if (norm != null && norm.length() > 0) {
            Normalization normalization = null;
            for (Normalization n : Normalization.values()) {
                if (n.toString().equalsIgnoreCase(norm)) {
                    normalization = n;
                    break;
                }
            }
            if (normalization == null) {
                throw new Exception("[NeuralNetwork] unsupported normalization method: " + norm);
            }
            this.m_normalizationMethod = normalization;
        }

        String thresh = model.getAttribute("threshold");
        if (thresh != null && thresh.length() > 0) {
            this.m_threshold = Double.parseDouble(thresh);
        }
        String width = model.getAttribute("width");
        if (width != null && width.length() > 0) {
            this.m_width = Double.parseDouble(width);
        }
        String alt = model.getAttribute("altitude");
        if (alt != null && alt.length() > 0) {
            this.m_altitude = Double.parseDouble(alt);
        }

        NodeList inputL = model.getElementsByTagName("NeuralInput");
        this.m_numberOfInputs = inputL.getLength();
        this.m_inputs = new NeuralInput[this.m_numberOfInputs];
        for (int i = 0; i < this.m_numberOfInputs; ++i) {
            Node inputN = inputL.item(i);
            if (inputN.getNodeType() == Node.ELEMENT_NODE) {
                this.m_inputs[i] = new NeuralInput((Element) inputN, this.m_miningSchema);
            }
        }

        // Layers read the network-wide defaults set above, so they must be built afterwards.
        NodeList layerL = model.getElementsByTagName("NeuralLayer");
        this.m_numberOfLayers = layerL.getLength();
        this.m_layers = new NeuralLayer[this.m_numberOfLayers];
        for (int i = 0; i < this.m_numberOfLayers; ++i) {
            Node layerN = layerL.item(i);
            if (layerN.getNodeType() == Node.ELEMENT_NODE) {
                this.m_layers[i] = new NeuralLayer((Element) layerN);
            }
        }

        NodeList outputL = model.getElementsByTagName("NeuralOutputs");
        if (outputL.getLength() != 1) {
            throw new Exception("[NeuralNetwork] Should be just one NeuralOutputs element defined!");
        }
        this.m_outputs = new NeuralOutputs((Element) outputL.item(0), this.m_miningSchema);
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the version-control keyword
     */
    @Override
    public String getRevision() {
        String revisionTag = "$Revision: 8034 $";
        return RevisionUtils.extract(revisionTag);
    }

    /**
     * Computes the prediction for an instance by propagating it through the network.
     *
     * <p>For a numeric class the result is a one-element array holding the
     * predicted value. Otherwise it has one entry per class value.
     *
     * <p>If any non-class field is missing:
     * <ul>
     * <li>With target meta data, the default value (numeric class) or the prior
     * probabilities are returned.</li>
     * <li>Without it, a warning is logged and NaN (numeric class) or all-zero
     * probabilities are returned.</li>
     * </ul>
     *
     * @param inst the instance to predict
     * @return the predicted distribution or value
     * @throws Exception if the instance cannot be mapped onto the mining schema
     *           or the network cannot be evaluated
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (!this.m_initialized) {
            this.mapToMiningSchema(inst.dataset());
        }
        Instances miningSchemaI = this.m_miningSchema.getFieldsAsInstances();
        Attribute classAtt = miningSchemaI.classAttribute();
        double[] preds = classAtt.isNumeric() ? new double[1] : new double[classAtt.numValues()];
        double[] incoming = this.m_fieldsMap.instanceToSchema(inst, this.m_miningSchema);

        // The network cannot be evaluated if any non-class field is missing.
        int classIndex = miningSchemaI.classIndex();
        boolean hasMissing = false;
        for (int i = 0; i < incoming.length && !hasMissing; ++i) {
            hasMissing = i != classIndex && Double.isNaN(incoming[i]);
        }
        if (hasMissing) {
            if (!this.m_miningSchema.hasTargetMetaData()) {
                String message = "[NeuralNetwork] WARNING: Instance to predict has missing value(s) but there is no missing value handling meta data and no prior probabilities/default value to fall back to. No prediction will be made (" + (classAtt.isNominal() || classAtt.isString() ? "zero probabilities output)." : "NaN output).");
                if (this.m_log == null) {
                    System.err.println(message);
                } else {
                    this.m_log.logMessage(message);
                }
                if (classAtt.isNumeric()) {
                    preds[0] = Utils.missingValue();
                }
                return preds;
            }
            TargetMetaInfo targetData = this.m_miningSchema.getTargetMetaData();
            if (classAtt.isNumeric()) {
                preds[0] = targetData.getDefaultValue();
            } else {
                for (int i = 0; i < classAtt.numValues(); ++i) {
                    preds[i] = targetData.getPriorProbability(classAtt.value(i));
                }
            }
            return preds;
        }

        this.m_inputMap.clear();
        for (NeuralInput input : this.m_inputs) {
            this.m_inputMap.put(input.getID(), input.getValue(incoming));
        }

        // Chain layers starting from the input values. If no layers are declared, the
        // inputs go straight to the output mapping instead of failing on m_layers[0].
        HashMap<String, Double> layerOut = this.m_inputMap;
        for (NeuralLayer layer : this.m_layers) {
            layerOut = layer.computeOutput(layerOut);
        }
        this.m_outputs.getOuput(layerOut, preds);
        return preds;
    }

    /**
     * Describes the model as text. The description covers the PMML version, the
     * creating application (when known), the mining schema, every input, every
     * layer (numbered from 1) and the outputs.
     *
     * @return a textual description of this network
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder();
        description.append("PMML version ").append(this.getPMMLVersion());
        String creator = this.getCreatorApplication();
        if (!creator.equals("?")) {
            description.append("\nApplication: ").append(creator);
        }
        description.append("\nPMML Model: Neural network").append("\n\n");
        description.append(this.m_miningSchema);
        description.append("Inputs:\n");
        for (NeuralInput input : this.m_inputs) {
            description.append(input).append("\n");
        }
        int layerNumber = 1;
        for (NeuralLayer layer : this.m_layers) {
            description.append("Layer: ").append(layerNumber++).append("\n");
            description.append(layer).append("\n");
        }
        description.append("Outputs:\n").append(this.m_outputs);
        return description.toString();
    }

    /** The kinds of mining function this network can be configured for. */
    enum MiningFunction {
        /** Predicts a class; the network's initial setting. */
        CLASSIFICATION,
        /** Predicts a numeric value; selected by functionName="regression". */
        REGRESSION
    }

    /**
     * Activation functions for PMML neural networks. Each constant's
     * {@link #toString()} is the {@code activationFunction} attribute value that
     * selects it.
     */
    enum ActivationFunction {
        /** 1 if z exceeds the threshold, otherwise 0. */
        THRESHOLD("threshold"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return z > threshold ? 1.0 : 0.0;
            }
        },
        /** 1 / (1 + e^-z). */
        LOGISTIC("logistic"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return 1.0 / (1.0 + Math.exp(-z));
            }
        },
        /** Hyperbolic tangent of z. */
        TANH("tanh"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                // Math.tanh saturates cleanly to +/-1. Evaluating
                // (e^z - e^-z) / (e^z + e^-z) directly overflows to
                // Infinity / Infinity = NaN once |z| exceeds about 710.
                return Math.tanh(z);
            }
        },
        /** z unchanged. */
        IDENTITY("identity"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return z;
            }
        },
        /** e^z. */
        EXPONENTIAL("exponential"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return Math.exp(z);
            }
        },
        /** 1 / z. */
        RECIPROCAL("reciprocal"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return 1.0 / z;
            }
        },
        /** z squared. */
        SQUARE("square"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return z * z;
            }
        },
        /** e^-(z^2); the PMML attribute value is "Gauss". */
        GAUSS("Gauss"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return Math.exp(-(z * z));
            }
        },
        /** sin(z). */
        SINE("sine"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return Math.sin(z);
            }
        },
        /** cos(z). */
        COSINE("cosine"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return Math.cos(z);
            }
        },
        /** z / (1 + |z|); the PMML attribute value is "Elliott". */
        ELLICOT("Elliott"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return z / (1.0 + Math.abs(z));
            }
        },
        /** 2 * atan(z) / pi. */
        ARCTAN("arctan"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return 2.0 * Math.atan(z) / Math.PI;
            }
        },
        /** exp(fanIn * ln(altitude) - z), where z is the neuron's scaled squared distance. */
        RADIALBASIS("radialBasis"){

            @Override
            double eval(double z, double threshold, double altitude, double fanIn) {
                return Math.exp(fanIn * Math.log(altitude) - z);
            }
        };

        /** The PMML attribute value naming this function. */
        private final String m_stringVal;

        /**
         * Applies the activation function.
         *
         * @param z the neuron's combined input
         * @param threshold the threshold (used by THRESHOLD only)
         * @param altitude the altitude (used by RADIALBASIS only)
         * @param fanIn the number of incoming connections (used by RADIALBASIS only)
         * @return the neuron's output
         */
        abstract double eval(double z, double threshold, double altitude, double fanIn);

        private ActivationFunction(String name) {
            this.m_stringVal = name;
        }

        @Override
        public String toString() {
            return this.m_stringVal;
        }
    }

    /** Output normalization methods a layer can apply, named as in PMML. */
    enum Normalization {
        NONE("none"),
        SIMPLEMAX("simplemax"),
        SOFTMAX("softmax");

        /** The {@code normalizationMethod} attribute value selecting this method. */
        private final String pmmlName;

        Normalization(String pmmlName) {
            this.pmmlName = pmmlName;
        }

        @Override
        public String toString() {
            return pmmlName;
        }
    }

    /**
     * An input neuron: a single PMML {@code DerivedField} whose value is computed
     * from the incoming field values. The network publishes that value under this
     * input's ID for the first layer to read.
     */
    static class NeuralInput
    implements Serializable {
        private static final long serialVersionUID = -1902233762824835563L;
        /** Computes this input's value from the incoming field values. */
        private DerivedFieldMetaInfo m_field;
        /** The {@code id} attribute of the NeuralInput element. */
        private String m_ID = null;

        /** Returns the ID under which this input's value is published. */
        private String getID() {
            return m_ID;
        }

        /**
         * Parses a {@code NeuralInput} element.
         *
         * @param input the NeuralInput element
         * @param miningSchema supplies the field definitions and transformation dictionary
         * @throws Exception if the element does not contain exactly one DerivedField,
         *           or the derived field cannot be parsed
         */
        protected NeuralInput(Element input, MiningSchema miningSchema) throws Exception {
            m_ID = input.getAttribute("id");
            NodeList derivedFields = input.getElementsByTagName("DerivedField");
            if (derivedFields.getLength() != 1) {
                throw new Exception("[NeuralInput] expecting just one derived field!");
            }
            Element derivedFieldE = (Element) derivedFields.item(0);
            Instances schemaFields = miningSchema.getFieldsAsInstances();
            int numFields = schemaFields.numAttributes();
            ArrayList<Attribute> fieldDefs = new ArrayList<Attribute>(numFields);
            for (int j = 0; j < numFields; j++) {
                fieldDefs.add(schemaFields.attribute(j));
            }
            m_field = new DerivedFieldMetaInfo(derivedFieldE, fieldDefs, miningSchema.getTransformationDictionary());
        }

        /**
         * Computes this input's value.
         *
         * @param incoming the incoming field values
         * @return the derived value
         * @throws Exception if the derived field cannot be evaluated
         */
        protected double getValue(double[] incoming) throws Exception {
            return m_field.getDerivedValue(incoming);
        }

        @Override
        public String toString() {
            return "Nueral input (" + getID() + ")\n" + m_field;
        }
    }

    /**
     * One layer of neurons. Several layer-level attributes fall back to the
     * network-wide values when the {@code NeuralLayer} element does not set them:
     * activation function, threshold, width, altitude and normalization method.
     */
    class NeuralLayer
    implements Serializable {
        private static final long serialVersionUID = -8386042001675763922L;
        private int m_numNeurons = 0;
        private ActivationFunction m_layerActivationFunction = null;
        private double m_layerThreshold = Double.NaN;
        private double m_layerWidth = Double.NaN;
        private double m_layerAltitude = Double.NaN;
        private Normalization m_layerNormalization = null;
        private Neuron[] m_layerNeurons = null;
        /** Reused map of neuron ID -> output; cleared on every computeOutput call. */
        private HashMap<String, Double> m_layerOutput = new HashMap<String, Double>();

        /**
         * Parses a {@code NeuralLayer} element and its {@code Neuron} children.
         *
         * @param layerE the NeuralLayer element
         * @throws IllegalArgumentException if the layer names an activation function
         *           or normalization method that is not recognised. Leaving the field
         *           null would otherwise surface later during prediction, as a
         *           NullPointerException or a division by a zero normalizing sum.
         */
        protected NeuralLayer(Element layerE) {
            String activationFunction = layerE.getAttribute("activationFunction");
            this.m_layerActivationFunction = isSpecified(activationFunction)
                ? parseActivationFunction(activationFunction)
                : NeuralNetwork.this.m_activationFunction;

            String threshold = layerE.getAttribute("threshold");
            this.m_layerThreshold = isSpecified(threshold) ? Double.parseDouble(threshold) : NeuralNetwork.this.m_threshold;
            String width = layerE.getAttribute("width");
            this.m_layerWidth = isSpecified(width) ? Double.parseDouble(width) : NeuralNetwork.this.m_width;
            String altitude = layerE.getAttribute("altitude");
            this.m_layerAltitude = isSpecified(altitude) ? Double.parseDouble(altitude) : NeuralNetwork.this.m_altitude;

            String normMethod = layerE.getAttribute("normalizationMethod");
            this.m_layerNormalization = isSpecified(normMethod)
                ? parseNormalization(normMethod)
                : NeuralNetwork.this.m_normalizationMethod;

            NodeList neuronL = layerE.getElementsByTagName("Neuron");
            this.m_numNeurons = neuronL.getLength();
            this.m_layerNeurons = new Neuron[this.m_numNeurons];
            for (int i = 0; i < this.m_numNeurons; ++i) {
                Node neuronN = neuronL.item(i);
                if (neuronN.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_layerNeurons[i] = new Neuron((Element) neuronN, this);
                }
            }
        }

        /** Returns true when an attribute value is present and non-empty. */
        private boolean isSpecified(String value) {
            return value != null && value.length() > 0;
        }

        /** Maps an activation function name (case-insensitively) to its constant. */
        private ActivationFunction parseActivationFunction(String name) {
            for (ActivationFunction a : ActivationFunction.values()) {
                if (a.toString().equalsIgnoreCase(name)) {
                    return a;
                }
            }
            throw new IllegalArgumentException("[NeuralLayer] unsupported activation function: " + name);
        }

        /** Maps a normalization method name (case-insensitively) to its constant. */
        private Normalization parseNormalization(String name) {
            for (Normalization n : Normalization.values()) {
                if (n.toString().equalsIgnoreCase(name)) {
                    return n;
                }
            }
            throw new IllegalArgumentException("[NeuralLayer] unsupported normalization method: " + name);
        }

        protected ActivationFunction getActivationFunction() {
            return this.m_layerActivationFunction;
        }

        protected double getThreshold() {
            return this.m_layerThreshold;
        }

        protected double getWidth() {
            return this.m_layerWidth;
        }

        protected double getAltitude() {
            return this.m_layerAltitude;
        }

        protected Normalization getNormalization() {
            return this.m_layerNormalization;
        }

        /**
         * Evaluates every neuron in the layer and applies the layer's normalization.
         *
         * @param incoming the previous layer's (or the inputs') ID -> value map
         * @return this layer's neuron ID -> output map; the same map instance is
         *         reused and overwritten on the next call
         * @throws Exception if a neuron cannot find one of its connections
         */
        protected HashMap<String, Double> computeOutput(HashMap<String, Double> incoming) throws Exception {
            this.m_layerOutput.clear();
            int n = this.m_layerNeurons.length;
            double[] outputs = new double[n];
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < n; ++i) {
                outputs[i] = this.m_layerNeurons[i].getValue(incoming);
                max = Math.max(max, outputs[i]);
            }

            if (this.m_layerNormalization == Normalization.SOFTMAX) {
                // Softmax is invariant to subtracting a constant. Shifting by the maximum
                // stops exp() overflowing to Infinity, which produced Infinity/Infinity = NaN.
                double normSum = 0.0;
                for (int i = 0; i < n; ++i) {
                    outputs[i] = Math.exp(outputs[i] - max);
                    normSum += outputs[i];
                }
                for (int i = 0; i < n; ++i) {
                    outputs[i] /= normSum;
                }
            } else if (this.m_layerNormalization == Normalization.SIMPLEMAX) {
                double normSum = 0.0;
                for (int i = 0; i < n; ++i) {
                    normSum += outputs[i];
                }
                for (int i = 0; i < n; ++i) {
                    outputs[i] /= normSum;
                }
            }

            for (int i = 0; i < n; ++i) {
                this.m_layerOutput.put(this.m_layerNeurons[i].getID(), outputs[i]);
            }
            return this.m_layerOutput;
        }

        @Override
        public String toString() {
            StringBuffer temp = new StringBuffer();
            temp.append("activation: " + this.getActivationFunction() + "\n");
            if (!Double.isNaN(this.getThreshold())) {
                temp.append("threshold: " + this.getThreshold() + "\n");
            }
            if (!Double.isNaN(this.getWidth())) {
                temp.append("width: " + this.getWidth() + "\n");
            }
            if (!Double.isNaN(this.getAltitude())) {
                temp.append("altitude: " + this.getAltitude() + "\n");
            }
            temp.append("normalization: " + this.m_layerNormalization + "\n");
            for (int i = 0; i < this.m_numNeurons; ++i) {
                temp.append(this.m_layerNeurons[i] + "\n");
            }
            return temp.toString();
        }
    }

    /**
     * Maps output-layer neurons to predictions. A numeric class uses a single
     * neuron whose value is un-normalized through a {@code NormContinuous}. A
     * nominal class uses one neuron per class value, tied to that value through
     * a {@code NormDiscrete}.
     */
    static class NeuralOutputs
    implements Serializable {
        private static final long serialVersionUID = -233611113950482952L;
        /** IDs of the output neurons, in NeuralOutput document order. */
        private String[] m_outputNeurons = null;
        /** For a nominal class: the class-value index each output neuron predicts. */
        private int[] m_categoricalIndexes = null;
        private Attribute m_classAttribute = null;
        /** For a numeric class: maps the neuron output back to the target's scale. */
        private NormContinuous m_regressionMapping = null;

        /**
         * Parses a {@code NeuralOutputs} element.
         *
         * @param outputs the NeuralOutputs element
         * @param miningSchema supplies the class attribute
         * @throws Exception if the number of NeuralOutput elements is wrong (one
         *           for a numeric class, one per value for a nominal class), a
         *           required NormContinuous/NormDiscrete is missing or duplicated,
         *           or a NormDiscrete names an unknown class value
         */
        protected NeuralOutputs(Element outputs, MiningSchema miningSchema) throws Exception {
            this.m_classAttribute = miningSchema.getMiningSchemaAsInstances().classAttribute();
            int vals = this.m_classAttribute.isNumeric() ? 1 : this.m_classAttribute.numValues();
            this.m_outputNeurons = new String[vals];
            this.m_categoricalIndexes = new int[vals];
            NodeList outputL = outputs.getElementsByTagName("NeuralOutput");
            if (outputL.getLength() != this.m_outputNeurons.length) {
                throw new Exception("[NeuralOutputs] the number of neural outputs does not match the number expected!");
            }
            for (int i = 0; i < outputL.getLength(); ++i) {
                Node outputN = outputL.item(i);
                if (outputN.getNodeType() != Node.ELEMENT_NODE) {
                    continue;
                }
                Element outputE = (Element) outputN;
                this.m_outputNeurons[i] = outputE.getAttribute("outputNeuron");
                if (this.m_classAttribute.isNumeric()) {
                    NodeList contL = outputE.getElementsByTagName("NormContinuous");
                    if (contL.getLength() != 1) {
                        throw new Exception("[NeuralOutputs] Should be exactly one norm continuous element for numeric class!");
                    }
                    Element normContE = (Element) contL.item(0);
                    // NormContinuous needs field definitions; a single numeric attribute
                    // named after the referenced field is enough for the inverse mapping.
                    ArrayList<Attribute> dummyFieldDefs = new ArrayList<Attribute>();
                    dummyFieldDefs.add(new Attribute(normContE.getAttribute("field")));
                    this.m_regressionMapping = new NormContinuous(normContE, FieldMetaInfo.Optype.CONTINUOUS, dummyFieldDefs);
                    // A numeric class has exactly one output (checked above).
                    break;
                }
                NodeList discL = outputE.getElementsByTagName("NormDiscrete");
                if (discL.getLength() != 1) {
                    throw new Exception("[NeuralOutputs] Should be only one norm discrete element per derived field/neural output for a nominal class!");
                }
                String attValue = ((Element) discL.item(0)).getAttribute("value");
                int index = this.m_classAttribute.indexOfValue(attValue);
                if (index < 0) {
                    throw new Exception("[NeuralOutputs] Can't find specified target value " + attValue + " in class attribute " + this.m_classAttribute.name());
                }
                this.m_categoricalIndexes[i] = index;
            }
        }

        /**
         * Fills {@code preds} from the output neurons' values.
         *
         * <p>For a numeric class, the single value is un-normalized via the
         * NormContinuous mapping. For a nominal class, the values are shifted so
         * the minimum is zero (when any is negative) and then normalized to sum to
         * 1. If every value is then zero, a uniform distribution is produced.
         *
         * @param incoming the final layer's neuron ID -> output map
         * @param preds the array to fill; its length must equal the number of outputs
         * @throws Exception if {@code preds} has the wrong length or an output
         *           neuron is missing from {@code incoming}
         */
        protected void getOuput(HashMap<String, Double> incoming, double[] preds) throws Exception {
            if (preds.length != this.m_outputNeurons.length) {
                throw new Exception("[NeuralOutputs] Incorrect number of predictions requested: " + preds.length + " requested, " + this.m_outputNeurons.length + " expected");
            }
            for (int i = 0; i < this.m_outputNeurons.length; ++i) {
                Double neuronOut = incoming.get(this.m_outputNeurons[i]);
                if (neuronOut == null) {
                    throw new Exception("[NeuralOutputs] Unable to find output neuron " + this.m_outputNeurons[i] + " in the incoming HashMap!!");
                }
                if (this.m_classAttribute.isNumeric()) {
                    preds[0] = neuronOut;
                    preds[0] = this.m_regressionMapping.getResultInverse(preds);
                } else {
                    preds[this.m_categoricalIndexes[i]] = neuronOut;
                }
            }
            if (this.m_classAttribute.isNominal()) {
                double min = preds[Utils.minIndex(preds)];
                if (min < 0.0) {
                    for (int i = 0; i < preds.length; ++i) {
                        preds[i] -= min;
                    }
                }
                double sum = 0.0;
                for (double p : preds) {
                    sum += p;
                }
                if (sum == 0.0) {
                    // All outputs were equal (all zero, or all the same negative value
                    // shifted to zero). Utils.normalize throws on a zero sum, so report
                    // that every class is equally likely instead.
                    for (int i = 0; i < preds.length; ++i) {
                        preds[i] = 1.0 / preds.length;
                    }
                } else {
                    Utils.normalize(preds);
                }
            }
        }

        @Override
        public String toString() {
            StringBuffer temp = new StringBuffer();
            for (int i = 0; i < this.m_outputNeurons.length; ++i) {
                temp.append("Output neuron (" + this.m_outputNeurons[i] + ")\n");
                temp.append("mapping:\n");
                if (this.m_classAttribute.isNumeric()) {
                    temp.append(this.m_regressionMapping + "\n");
                    continue;
                }
                temp.append(this.m_classAttribute.name() + " = " + this.m_classAttribute.value(this.m_categoricalIndexes[i]) + "\n");
            }
            return temp.toString();
        }
    }

    /**
     * A single neuron. It combines its bias with its weighted incoming connections
     * and passes the result through its layer's activation function. For
     * radialBasis layers, the squared distances between inputs and weights are
     * summed instead of the weighted inputs.
     */
    static class Neuron
    implements Serializable {
        private static final long serialVersionUID = -3817434025682603443L;
        private String m_ID = null;
        /** The layer supplying activation function, threshold and default width/altitude. */
        private NeuralLayer m_layer;
        private double m_bias = 0.0;
        /** Neuron-specific width; NaN means "use the layer's width". */
        private double m_neuronWidth = Double.NaN;
        /** Neuron-specific altitude; NaN means "use the layer's altitude". */
        private double m_neuronAltitude = Double.NaN;
        /** IDs of the neurons/inputs this neuron reads from ({@code Con@from}). */
        private String[] m_connectionIDs = null;
        /** Connection weights, parallel to {@link #m_connectionIDs}. */
        private double[] m_weights = null;

        /** Returns true when an attribute value is present and non-empty. */
        private static boolean isPresent(String value) {
            return value != null && value.length() > 0;
        }

        /**
         * Parses a {@code Neuron} element and its {@code Con} children.
         *
         * @param neuronE the Neuron element
         * @param layer the layer this neuron belongs to
         */
        protected Neuron(Element neuronE, NeuralLayer layer) {
            m_layer = layer;
            m_ID = neuronE.getAttribute("id");

            String biasAtt = neuronE.getAttribute("bias");
            if (isPresent(biasAtt)) {
                m_bias = Double.parseDouble(biasAtt);
            }
            String widthAtt = neuronE.getAttribute("width");
            if (isPresent(widthAtt)) {
                m_neuronWidth = Double.parseDouble(widthAtt);
            }
            String altitudeAtt = neuronE.getAttribute("altitude");
            if (isPresent(altitudeAtt)) {
                m_neuronAltitude = Double.parseDouble(altitudeAtt);
            }

            NodeList connections = neuronE.getElementsByTagName("Con");
            int numCons = connections.getLength();
            m_connectionIDs = new String[numCons];
            m_weights = new double[numCons];
            for (int c = 0; c < numCons; c++) {
                Node conN = connections.item(c);
                if (conN.getNodeType() == Node.ELEMENT_NODE) {
                    Element conE = (Element) conN;
                    m_connectionIDs[c] = conE.getAttribute("from");
                    m_weights[c] = Double.parseDouble(conE.getAttribute("weight"));
                }
            }
        }

        protected String getID() {
            return m_ID;
        }

        /**
         * Computes this neuron's output.
         *
         * @param incoming the previous layer's (or the inputs') ID -> value map
         * @return the activation of this neuron
         * @throws Exception if a connected ID is absent from {@code incoming}
         */
        protected double getValue(HashMap<String, Double> incoming) throws Exception {
            double width = Double.isNaN(m_neuronWidth) ? m_layer.getWidth() : m_neuronWidth;
            boolean radial = m_layer.getActivationFunction() == ActivationFunction.RADIALBASIS;
            double z = m_bias;
            for (int c = 0; c < m_connectionIDs.length; c++) {
                Double inVal = incoming.get(m_connectionIDs[c]);
                if (inVal == null) {
                    throw new Exception("[Neuron] unable to find connection " + m_connectionIDs[c] + " in input Map!");
                }
                z += radial ? Math.pow(inVal - m_weights[c], 2.0) : inVal * m_weights[c];
            }
            if (radial) {
                z /= 2.0 * (width * width);
            }
            double altitude = Double.isNaN(m_neuronAltitude) ? m_layer.getAltitude() : m_neuronAltitude;
            return m_layer.getActivationFunction().eval(z, m_layer.getThreshold(), altitude, m_connectionIDs.length);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Nueron (").append(m_ID).append(") [bias:").append(m_bias);
            if (!Double.isNaN(m_neuronWidth)) {
                sb.append(" width:").append(m_neuronWidth);
            }
            if (!Double.isNaN(m_neuronAltitude)) {
                sb.append(" altitude:").append(m_neuronAltitude);
            }
            sb.append("]\n").append("  con. (ID:weight): ");
            int last = m_connectionIDs.length - 1;
            for (int c = 0; c <= last; c++) {
                sb.append(m_connectionIDs[c]).append(":").append(Utils.doubleToString(m_weights[c], 2));
                // Wrap after every tenth connection and after the final one.
                boolean wrap = (c + 1) % 10 == 0 || c == last;
                sb.append(wrap ? "\n                    " : ", ");
            }
            return sb.toString();
        }
    }
}

