package org.encog.neural.networks.layers;

import org.encog.matrix.Matrix;
import org.encog.matrix.MatrixMath;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.activation.ActivationFunction;
import org.encog.neural.activation.ActivationSigmoid;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.networks.Layer;
import org.encog.neural.persist.EncogPersistedObject;
import org.encog.neural.persist.Persistor;
import org.encog.neural.persist.persistors.FeedforwardLayerPersistor;

/* loaded from: input_file:lib/encog.jar:org/encog/neural/networks/layers/FeedforwardLayer.class */
/**
 * A fully-connected feedforward layer.
 *
 * <p>The layer owns the weight matrix that connects its neurons to the next layer. That matrix has
 * one row per neuron of this layer plus one extra row of threshold (bias) weights, and one column
 * per neuron of the next layer. During {@link #compute(NeuralData)} the threshold row is multiplied
 * by a constant input of {@code 1.0}.
 */
public class FeedforwardLayer extends BasicLayer implements EncogPersistedObject {
    private static final long serialVersionUID = -3698708039331150031L;

    /** Activation applied to the weighted sums this layer writes into the next layer. */
    private ActivationFunction activationFunction;

    /**
     * Creates a layer with an explicit activation function.
     *
     * @param activationFunction activation applied to the sums fed into the next layer
     * @param i number of neurons in this layer
     */
    public FeedforwardLayer(ActivationFunction activationFunction, int i) {
        super(i);
        this.activationFunction = activationFunction;
    }

    /**
     * Creates a layer that uses {@link ActivationSigmoid}.
     *
     * @param i number of neurons in this layer
     */
    public FeedforwardLayer(int i) {
        this(new ActivationSigmoid(), i);
    }

    /**
     * Creates a new layer with the same neuron count and activation function as this one.
     * Weights are not copied, and this method sets no next/previous links on the new layer.
     * The activation function instance is shared with this layer, not copied.
     *
     * @return a structurally equivalent layer
     */
    public FeedforwardLayer cloneStructure() {
        return new FeedforwardLayer(this.activationFunction, getNeuronCount());
    }

    /**
     * Propagates this layer's fire values into the next layer.
     *
     * <p>If {@code neuralData} is non-null, its first {@code getNeuronCount()} values are first
     * copied into this layer's fire values. Each neuron {@code j} of the next layer then receives
     * {@code activation(dot(column j of the weight matrix, [fire..., 1.0]))}, where the trailing
     * {@code 1.0} multiplies the threshold row.
     *
     * @param neuralData input values for this layer, or {@code null} to use the current fire values
     * @return this layer's fire values (not the next layer's); the next layer's values are written
     *     into it via {@code setFire}
     * @throws NeuralNetworkError if this layer has no next layer or no weight matrix
     */
    @Override
    public NeuralData compute(NeuralData neuralData) {
        if (neuralData != null) {
            for (int i = 0; i < getNeuronCount(); i++) {
                setFire(i, neuralData.getData(i));
            }
        }
        // Fail with a descriptive error instead of a bare NullPointerException below.
        if (getNext() == null) {
            throw new NeuralNetworkError("Cannot compute: this layer has no next layer to feed.");
        }
        if (getMatrix() == null) {
            throw new NeuralNetworkError("Cannot compute: this layer has no weight matrix.");
        }
        Matrix inputMatrix = createInputMatrix(getFire());
        for (int col = 0; col < getNext().getNeuronCount(); col++) {
            double sum = MatrixMath.dotProduct(getMatrix().getCol(col), inputMatrix);
            getNext().setFire(col, this.activationFunction.activationFunction(sum));
        }
        return getFire();
    }

    /**
     * Builds a 1 x (n + 1) row matrix holding the given values followed by a constant {@code 1.0},
     * so a dot product with a weight column also adds that column's threshold weight.
     *
     * @param neuralData the values to place in the row
     * @return the row matrix with the trailing bias input
     */
    private Matrix createInputMatrix(NeuralData neuralData) {
        Matrix matrix = new Matrix(1, neuralData.size() + 1);
        for (int i = 0; i < neuralData.size(); i++) {
            matrix.set(0, i, neuralData.getData(i));
        }
        matrix.set(0, neuralData.size(), 1.0d);
        return matrix;
    }

    /**
     * Returns the persistor used to save and load this layer.
     *
     * @return a new {@link FeedforwardLayerPersistor}
     */
    @Override
    public Persistor createPersistor() {
        return new FeedforwardLayerPersistor();
    }

    /**
     * @return the activation function applied to the sums fed into the next layer
     */
    public ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    /**
     * Removes one neuron from this layer.
     *
     * <p>The neuron's row is deleted from this layer's weight matrix (if any), and its column is
     * deleted from the previous layer's weight matrix (if any). When this layer has no weight
     * matrix, its fire values are shrunk directly, so the neuron count stays in sync with the
     * previous layer's matrix. All validation happens before anything is modified.
     *
     * @param i index of the neuron to remove, in {@code [0, getNeuronCount())}
     * @throws NeuralNetworkError if {@code i} is out of range or this is the layer's only neuron
     */
    public void prune(int i) {
        int neuronCount = getNeuronCount();
        // Without this check, i == neuronCount would delete the threshold row instead of a neuron.
        if (i < 0 || i >= neuronCount) {
            throw new NeuralNetworkError("Cannot prune neuron " + i + ": layer has "
                    + neuronCount + " neurons.");
        }
        if (neuronCount < 2) {
            throw new NeuralNetworkError("Cannot prune the only neuron in a layer.");
        }
        if (getMatrix() != null) {
            // setMatrix also resizes the fire values to (rows - 1).
            setMatrix(MatrixMath.deleteRow(getMatrix(), i));
        } else {
            // No outgoing weights (presumably the output layer): shrink the fire values here,
            // otherwise the neuron count would exceed the previous layer's column count.
            setFire(new BasicNeuralData(neuronCount - 1));
        }
        Layer previous = getPrevious();
        if (previous != null && previous.getMatrix() != null) {
            previous.setMatrix(MatrixMath.deleteCol(previous.getMatrix(), i));
        }
    }

    /**
     * @param activationFunction activation applied to the sums fed into the next layer
     */
    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /**
     * Installs a weight matrix and resets this layer's fire values to {@code rows - 1} entries
     * (the last row holds threshold weights).
     *
     * @param matrix the weight matrix; must not be {@code null}
     * @throws NeuralNetworkError if the matrix has fewer than 2 rows
     */
    @Override
    public void setMatrix(Matrix matrix) {
        if (matrix.getRows() < 2) {
            throw new NeuralNetworkError("Weight matrix includes threshold values, and must have at least 2 rows.");
        }
        setFire(new BasicNeuralData(matrix.getRows() - 1));
        super.setMatrix(matrix);
    }

    /**
     * Resizes this layer to {@code i} neurons. Fire values are reset. If a next layer is
     * connected, the existing weight matrix is replaced by a newly allocated one.
     *
     * <p>NOTE(review): unlike {@link #prune(int)}, this does not resize the previous layer's weight
     * matrix; confirm callers rebuild it when changing a connected layer's size.
     *
     * @param i the new neuron count
     */
    @Override
    public void setNeuronCount(int i) {
        setFire(new BasicNeuralData(i));
        if (getNext() != null) {
            // Size from the argument rather than relying on getNeuronCount() tracking setFire.
            setMatrix(new Matrix(i + 1, getNext().getNeuronCount()));
        }
    }

    /**
     * Links this layer to the next one. If this layer has no weight matrix yet and the link is
     * non-null, a newly allocated matrix of size (neurons + 1) x (next layer neurons) is created.
     * An existing matrix is kept unchanged.
     *
     * @param layer the next layer, or {@code null}
     */
    @Override
    public void setNext(Layer layer) {
        super.setNext(layer);
        if (hasMatrix() || getNext() == null) {
            return;
        }
        setMatrix(new Matrix(getNeuronCount() + 1, layer.getNeuronCount()));
    }

    @Override
    public String toString() {
        return "[FeedforwardLayer: Neuron Count=" + getNeuronCount() + "]";
    }
}
