package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Index;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import opennlp.tools.parser.Parse;
import ws.palladian.helper.io.FileHelper;

/* loaded from: input_file:lib/palladian.jar:edu/stanford/nlp/ie/crf/FactorTable.class */
/**
 * A dense table of log-space factor values for a clique of {@code windowSize} positions, each of
 * which takes one of {@code numClasses} labels.
 *
 * <p>A label assignment {@code [l0, l1, ..., l(w-1)]} is stored at the flat index
 * {@code l0 * numClasses^(w-1) + ... + l(w-1)}: the first position is the most significant
 * "digit" (see {@link #toArray(int)}). Values are natural logs. "Multiplying" in another factor
 * adds entries, "summing out" a position combines entries with log-sum-exp, and a freshly created
 * table holds {@link Double#NEGATIVE_INFINITY} (log 0) everywhere.
 */
public class FactorTable {

    private final int numClasses;
    private final int windowSize;
    private final double[] table;

    /**
     * Creates a table in which every assignment has log value {@code -Infinity}.
     *
     * @param numClasses number of labels each position can take; must be positive
     * @param windowSize number of positions in the clique; must be non-negative
     * @throws IllegalArgumentException if either argument is out of range
     */
    public FactorTable(int numClasses, int windowSize) {
        if (numClasses < 1) {
            throw new IllegalArgumentException("numClasses must be positive but was " + numClasses);
        }
        if (windowSize < 0) {
            throw new IllegalArgumentException("windowSize must be non-negative but was " + windowSize);
        }
        this.numClasses = numClasses;
        this.windowSize = windowSize;
        this.table = new double[SloppyMath.intPow(numClasses, windowSize)];
        Arrays.fill(this.table, Double.NEGATIVE_INFINITY);
    }

    /**
     * Creates a deep copy of {@code other}; the two tables do not share storage.
     *
     * @param other table to copy
     */
    public FactorTable(FactorTable other) {
        this.numClasses = other.numClasses;
        this.windowSize = other.windowSize;
        this.table = other.table.clone();
    }

    /** Returns true if any cell of the table is NaN. */
    public boolean hasNaN() {
        return ArrayMath.hasNaN(table);
    }

    /**
     * Renders every assignment with its normalized probability
     * {@code exp(value - totalMass())}.
     */
    public String toProbString() {
        // Hoisted: calling prob() per cell would recompute the log partition function each time.
        double logZ = totalMass();
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < table.length; i++) {
            sb.append(Arrays.toString(toArray(i)));
            sb.append(": ");
            sb.append(Math.exp(table[i] - logZ));
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(Parse.BRACKET_RCB);
        return sb.toString();
    }

    /** Renders every assignment with its exponentiated (unnormalized, non-log) value. */
    public String toNonLogString() {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < table.length; i++) {
            sb.append(Arrays.toString(toArray(i)));
            sb.append(": ");
            sb.append(Math.exp(table[i]));
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(Parse.BRACKET_RCB);
        return sb.toString();
    }

    /**
     * Renders every assignment with its raw log value, spelling labels via {@code index}.
     *
     * @param index maps label ids to label objects
     * @param <L> label type
     */
    public <L> String toString(Index<L> index) {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < table.length; i++) {
            sb.append(toString(toArray(i), index));
            sb.append(": ");
            sb.append(table[i]);
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(Parse.BRACKET_RCB);
        return sb.toString();
    }

    /** Renders every assignment (as label ids) with its raw log value. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < table.length; i++) {
            sb.append(Arrays.toString(toArray(i)));
            sb.append(": ");
            sb.append(table[i]);
            sb.append(FileHelper.NEWLINE_CHARACTER);
        }
        sb.append(Parse.BRACKET_RCB);
        return sb.toString();
    }

    /** Formats a label-id assignment as a list of the corresponding {@code index} entries. */
    private static <L> String toString(int[] labels, Index<L> index) {
        List<L> named = new ArrayList<>(labels.length);
        for (int label : labels) {
            named.add(index.get(label));
        }
        return named.toString();
    }

    /**
     * Decodes a flat table index into its label assignment (inverse of the internal encoding).
     *
     * @param flatIndex index into the table
     * @return array of length {@code windowSize}; the first element is the most significant digit
     */
    public int[] toArray(int flatIndex) {
        int[] labels = new int[windowSize];
        int rest = flatIndex;
        for (int pos = labels.length - 1; pos >= 0; pos--) {
            labels[pos] = rest % numClasses;
            rest /= numClasses;
        }
        return labels;
    }

    /** Encodes labels (most significant first) into a flat index; no range checks are done. */
    private int indexOf(int[] labels) {
        int idx = 0;
        for (int label : labels) {
            idx = idx * numClasses + label;
        }
        return idx;
    }

    /** Encodes {@code front} followed by {@code last} into a flat index. */
    private int indexOf(int[] front, int last) {
        return indexOf(front) * numClasses + last;
    }

    /** Encodes {@code first} followed by {@code end} into a flat index. */
    private int indexOf(int first, int[] end) {
        int idx = first;
        for (int label : end) {
            idx = idx * numClasses + label;
        }
        return idx;
    }

    /**
     * Returns the flat indices of every assignment whose trailing positions equal {@code end}.
     * The result holds {@code numClasses^(windowSize - end.length)} indices.
     */
    private int[] indicesEnd(int[] end) {
        int idx = indexOf(end);
        int[] indices = new int[SloppyMath.intPow(numClasses, windowSize - end.length)];
        // Each increment of the free prefix advances the flat index by numClasses^end.length.
        int step = SloppyMath.intPow(numClasses, end.length);
        for (int k = 0; k < indices.length; k++) {
            indices[k] = idx;
            idx += step;
        }
        return indices;
    }

    /**
     * Returns the first flat index whose leading positions equal {@code front}. All matching
     * assignments occupy a contiguous run of {@code numClasses^(windowSize - front.length)} cells
     * starting here.
     */
    private int indicesFront(int[] front) {
        return indexOf(front) * SloppyMath.intPow(numClasses, windowSize - front.length);
    }

    /** Returns the number of positions in the clique. */
    public int windowSize() {
        return windowSize;
    }

    /** Returns the number of labels per position. */
    public int numClasses() {
        return numClasses;
    }

    /** Returns the number of cells, {@code numClasses^windowSize}. */
    public int size() {
        return table.length;
    }

    /** Returns the log of the sum of all exponentiated cells (the log partition function). */
    public double totalMass() {
        return ArrayMath.logSum(table);
    }

    /** Returns the raw log value stored for a full assignment. */
    public double unnormalizedLogProb(int[] labels) {
        return getValue(labels);
    }

    /** Returns the normalized log probability of a full assignment. */
    public double logProb(int[] labels) {
        return unnormalizedLogProb(labels) - totalMass();
    }

    /** Returns the normalized probability of a full assignment. */
    public double prob(int[] labels) {
        return Math.exp(unnormalizedLogProb(labels) - totalMass());
    }

    /**
     * Returns log P(last = {@code of} | leading positions = {@code given}).
     *
     * @param given labels of the first {@code windowSize - 1} positions
     * @param of label of the last position
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double conditionalLogProbGivenPrevious(int[] given, int of) {
        if (given.length != windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenPrevious requires given one less than clique size (" + windowSize + ") but was " + Arrays.toString(given));
        }
        int start = indicesFront(given);
        return table[start + of] - ArrayMath.logSum(table, start, start + numClasses);
    }

    /**
     * Returns the normalized log distribution over the last position given the leading ones.
     *
     * @param given labels of the first {@code windowSize - 1} positions
     * @return array of length {@code numClasses}, log-normalized
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double[] conditionalLogProbsGivenPrevious(int[] given) {
        if (given.length != windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbsGivenPrevious requires given one less than clique size (" + windowSize + ") but was " + Arrays.toString(given));
        }
        double[] result = new double[numClasses];
        // All completions of 'given' form one contiguous row of numClasses cells.
        System.arraycopy(table, indicesFront(given), result, 0, numClasses);
        ArrayMath.logNormalize(result);
        return result;
    }

    /**
     * Returns log P(trailing positions = {@code of} | first = {@code given}).
     *
     * @param given label of the first position
     * @param of labels of the remaining {@code windowSize - 1} positions
     * @throws IllegalArgumentException if {@code of.length != windowSize - 1}
     */
    public double conditionalLogProbGivenFirst(int given, int[] of) {
        if (of.length != windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenFirst requires of one less than clique size (" + windowSize + ") but was " + Arrays.toString(of));
        }
        return table[indexOf(given, of)] - unnormalizedLogProbFront(given);
    }

    /**
     * Returns the raw log value of the assignment {@code [given, of...]}. No normalization is
     * applied.
     *
     * @throws IllegalArgumentException if {@code of.length != windowSize - 1}
     */
    public double unnormalizedConditionalLogProbGivenFirst(int given, int[] of) {
        if (of.length != windowSize - 1) {
            throw new IllegalArgumentException("unnormalizedConditionalLogProbGivenFirst requires of one less than clique size (" + windowSize + ") but was " + Arrays.toString(of));
        }
        return table[indexOf(given, of)];
    }

    /**
     * Returns log P(first = {@code of} | trailing positions = {@code given}).
     *
     * @param given labels of the last {@code windowSize - 1} positions
     * @param of label of the first position
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double conditionalLogProbGivenNext(int[] given, int of) {
        if (given.length != windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenNext requires given one less than clique size (" + windowSize + ") but was " + Arrays.toString(given));
        }
        return table[indexOf(of, given)] - unnormalizedLogProbEnd(given);
    }

    /** Returns the log-sum over all assignments whose leading positions equal {@code front}. */
    public double unnormalizedLogProbFront(int[] front) {
        int start = indicesFront(front);
        return ArrayMath.logSum(table, start, start + SloppyMath.intPow(numClasses, windowSize - front.length));
    }

    /** Returns the normalized log marginal of the leading positions {@code front}. */
    public double logProbFront(int[] front) {
        return unnormalizedLogProbFront(front) - totalMass();
    }

    /** Single-label convenience form of {@link #unnormalizedLogProbFront(int[])}. */
    public double unnormalizedLogProbFront(int front) {
        return unnormalizedLogProbFront(new int[]{front});
    }

    /** Returns the normalized log marginal of the first position taking label {@code front}. */
    public double logProbFront(int front) {
        return unnormalizedLogProbFront(front) - totalMass();
    }

    /** Returns the log-sum over all assignments whose trailing positions equal {@code end}. */
    public double unnormalizedLogProbEnd(int[] end) {
        int[] indices = indicesEnd(end);
        double[] values = new double[indices.length];
        for (int k = 0; k < values.length; k++) {
            values[k] = table[indices[k]];
        }
        return ArrayMath.logSum(values);
    }

    /** Returns the normalized log marginal of the trailing positions {@code end}. */
    public double logProbEnd(int[] end) {
        return unnormalizedLogProbEnd(end) - totalMass();
    }

    /** Single-label convenience form of {@link #unnormalizedLogProbEnd(int[])}. */
    public double unnormalizedLogProbEnd(int end) {
        return unnormalizedLogProbEnd(new int[]{end});
    }

    /** Returns the normalized log marginal of the last position taking label {@code end}. */
    public double logProbEnd(int end) {
        return unnormalizedLogProbEnd(end) - totalMass();
    }

    /** Returns the raw value stored at a flat index. */
    public double getValue(int flatIndex) {
        return table[flatIndex];
    }

    /** Returns the raw value stored for a full assignment. */
    public double getValue(int[] labels) {
        return table[indexOf(labels)];
    }

    /** Overwrites the value at a flat index. */
    public void setValue(int flatIndex, double value) {
        table[flatIndex] = value;
    }

    /** Overwrites the value for a full assignment. */
    public void setValue(int[] labels, double value) {
        table[indexOf(labels)] = value;
    }

    /** Adds {@code delta} to the value for a full assignment (a multiply in non-log space). */
    public void incrementValue(int[] labels, double delta) {
        incrementValue(indexOf(labels), delta);
    }

    /** Adds {@code delta} to the value at a flat index (a multiply in non-log space). */
    public void incrementValue(int flatIndex, double delta) {
        table[flatIndex] += delta;
    }

    /** Log-adds {@code logValue} into the cell at a flat index (an addition in non-log space). */
    void logIncrementValue(int flatIndex, double logValue) {
        table[flatIndex] = SloppyMath.logAdd(table[flatIndex], logValue);
    }

    /** Log-adds {@code logValue} into the cell for a full assignment. */
    public void logIncrementValue(int[] labels, double logValue) {
        logIncrementValue(indexOf(labels), logValue);
    }

    /**
     * Adds to each cell the value of {@code other} at this cell's leading
     * {@code other.windowSize()} positions.
     *
     * <p>NOTE(review): this assumes {@code other} has the same {@code numClasses} and a window
     * no larger than this one; neither is checked.
     */
    public void multiplyInFront(FactorTable other) {
        // Integer division by numClasses^(suffix length) drops the trailing positions.
        int suffixSize = SloppyMath.intPow(numClasses, windowSize - other.windowSize());
        for (int i = 0; i < table.length; i++) {
            table[i] += other.getValue(i / suffixSize);
        }
    }

    /**
     * Adds to each cell the value of {@code other} at this cell's trailing
     * {@code other.windowSize()} positions.
     *
     * <p>NOTE(review): this assumes {@code other} has the same {@code numClasses} and a window
     * no larger than this one; neither is checked.
     */
    public void multiplyInEnd(FactorTable other) {
        // Modulo numClasses^(other window) keeps only the trailing positions.
        int suffixSize = SloppyMath.intPow(numClasses, other.windowSize());
        for (int i = 0; i < table.length; i++) {
            table[i] += other.getValue(i % suffixSize);
        }
    }

    /** Returns a new table with the last position marginalized out via log-sum-exp. */
    public FactorTable sumOutEnd() {
        FactorTable result = new FactorTable(numClasses, windowSize - 1);
        int size = result.size();
        for (int i = 0; i < size; i++) {
            // The numClasses completions of prefix i are contiguous.
            result.table[i] = ArrayMath.logSum(table, i * numClasses, (i + 1) * numClasses);
        }
        return result;
    }

    /** Returns a new table with the first position marginalized out via log-sum-exp. */
    public FactorTable sumOutFront() {
        FactorTable result = new FactorTable(numClasses, windowSize - 1);
        int size = result.size();
        for (int i = 0; i < size; i++) {
            // Cells sharing suffix i are spaced 'size' apart, one per value of the first position.
            result.setValue(i, ArrayMath.logSum(table, i, table.length, size));
        }
        return result;
    }

    /**
     * Subtracts {@code other}'s values cell by cell (a divide in non-log space). A cell is left
     * untouched when both operands are {@code -Infinity}, so it does not become NaN.
     */
    public void divideBy(FactorTable other) {
        for (int i = 0; i < table.length; i++) {
            if (table[i] != Double.NEGATIVE_INFINITY || other.table[i] != Double.NEGATIVE_INFINITY) {
                table[i] -= other.table[i];
            }
        }
    }

    /** Small manual demo that prints tables and conditional distributions to stderr. */
    public static void main(String[] args) {
        final int numClasses = 6;
        final int cliqueSize = 3;
        System.err.printf("Creating factor table with %d classes and window (clique) size %d%n", numClasses, cliqueSize);
        FactorTable ft = new FactorTable(numClasses, cliqueSize);
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                for (int k = 0; k < numClasses; k++) {
                    ft.setValue(new int[]{i, j, k}, (i * 4) + (j * 2) + k);
                }
            }
        }
        System.err.println(ft);

        // The table stores log potentials, so Z is the sum of their exponentials.
        double z = 0.0;
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                for (int k = 0; k < numClasses; k++) {
                    z += Math.exp(ft.unnormalizedLogProb(new int[]{i, j, k}));
                }
            }
        }
        System.err.println("Normalization Z = " + z);
        System.err.println(ft.sumOutFront());

        FactorTable ft2 = new FactorTable(numClasses, 2);
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                ft2.setValue(new int[]{i, j}, (i * 6) + j);
            }
        }
        System.err.println(ft2);

        // Each conditional distribution over the last label should sum to 1.
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                int[] given = {i, j};
                double total = 0.0;
                for (int k = 0; k < numClasses; k++) {
                    double p = Math.exp(ft.conditionalLogProbGivenPrevious(given, k));
                    total += p;
                    System.err.println(k + "|" + i + "," + j + " : " + p);
                }
                System.err.println(total);
            }
        }

        System.err.println("conditionalLogProbGivenFirst");
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                int[] of = {i, j};
                double total = 0.0;
                for (int k = 0; k < numClasses; k++) {
                    double v = ft.unnormalizedConditionalLogProbGivenFirst(k, of);
                    total += v;
                    System.err.println(k + "|" + i + "," + j + " : " + v);
                }
                System.err.println(total);
            }
        }

        // Each conditional distribution over the first label should sum to 1.
        System.err.println("conditionalLogProbGivenNext");
        for (int i = 0; i < numClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                int[] given = {i, j};
                double total = 0.0;
                for (int k = 0; k < numClasses; k++) {
                    double p = Math.exp(ft.conditionalLogProbGivenNext(given, k));
                    total += p;
                    System.err.println(i + "," + j + "|" + k + " : " + p);
                }
                System.err.println(total);
            }
        }

        FactorTable ft3 = new FactorTable(2, 3);
        ft3.setValue(new int[]{0, 0, 0}, Math.log(0.25d));
        ft3.setValue(new int[]{0, 0, 1}, Math.log(0.35d));
        ft3.setValue(new int[]{0, 1, 0}, Math.log(0.05d));
        ft3.setValue(new int[]{0, 1, 1}, Math.log(0.07d));
        ft3.setValue(new int[]{1, 0, 0}, Math.log(0.08d));
        ft3.setValue(new int[]{1, 0, 1}, Math.log(0.16d));
        ft3.setValue(new int[]{1, 1, 0}, Math.log(1.0E-50d));
        ft3.setValue(new int[]{1, 1, 1}, Math.log(1.0E-50d));
        System.err.println(ft3.sumOutFront().toNonLogString());
        System.err.println(ft3.sumOutEnd().toNonLogString());
    }
}
