package com.aliasi.lm;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.corpus.TextHandler;
import com.aliasi.io.BitInput;
import com.aliasi.io.BitOutput;
import com.aliasi.lm.LanguageModel;
import com.aliasi.stats.Model;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Arrays;
import com.aliasi.util.Math;
import com.aliasi.util.Strings;
import com.sun.media.jai.util.ImageUtil;
import java.io.Externalizable;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.LinkedList;
import org.jdesktop.swingx.JXLabel;
import ws.palladian.helper.io.FileHelper;

/* loaded from: input_file:lib/palladian.jar:com/aliasi/lm/NGramProcessLM.class */
public class NGramProcessLM implements TextHandler, Model<CharSequence>, LanguageModel.Process, LanguageModel.Conditional, LanguageModel.Dynamic, ObjectHandler<CharSequence>, Serializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = -2865886217715962249L;
    /** Counter of character sequences that every estimate in this class reads from. */
    private final TrieCharSeqCounter mTrieCharSeqCounter;
    /** Maximum n-gram length; taken from the counter's {@code mMaxLength} in the constructor. */
    private final int mMaxNGram;
    /** Factor used when computing interpolation weights; set through {@code setLambdaFactor}. */
    private double mLambdaFactor;
    /** Number of characters used for the uniform estimate; set through {@code setNumChars}. */
    private int mNumChars;
    /** Uniform estimate, {@code 1.0 / mNumChars}; recomputed by {@code setNumChars}. */
    private double mUniformEstimate;
    /** Base-2 logarithm of {@code mUniformEstimate}; recomputed by {@code setNumChars}. */
    private double mLog2UniformEstimate;
    /* loaded from: input_file:lib/palladian.jar:com/aliasi/lm/NGramProcessLM$Externalizer.class */
    static class Externalizer extends AbstractExternalizable {
        static final long serialVersionUID = -3623859317152451545L;
        final NGramProcessLM mLM;

        public Externalizer() {
            this(null);
        }

        public Externalizer(NGramProcessLM nGramProcessLM) {
            this.mLM = nGramProcessLM;
        }

        @Override // com.aliasi.util.AbstractExternalizable
        public Object read(ObjectInput objectInput) throws IOException {
            return new CompiledNGramProcessLM(objectInput);
        }

        @Override // com.aliasi.util.AbstractExternalizable, java.io.Externalizable
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            objectOutput.writeInt(this.mLM.mMaxNGram);
            objectOutput.writeFloat((float) this.mLM.mLog2UniformEstimate);
            long uniqueSequenceCount = this.mLM.mTrieCharSeqCounter.uniqueSequenceCount();
            if (uniqueSequenceCount > 2147483647L) {
                throw new IllegalArgumentException("Maximum number of compiled nodes is Integer.MAX_VALUE = 2147483647 Found number of nodes=" + uniqueSequenceCount);
            }
            objectOutput.writeInt((int) uniqueSequenceCount);
            int lastInternalNodeIndex = this.mLM.lastInternalNodeIndex();
            objectOutput.writeInt(lastInternalNodeIndex);
            objectOutput.writeChar(ImageUtil.USHORT_MASK);
            objectOutput.writeFloat((float) this.mLM.mLog2UniformEstimate);
            double lambda = 1.0d - this.mLM.lambda(this.mLM.mTrieCharSeqCounter.mRootNode);
            objectOutput.writeFloat(Double.isNaN(lambda) ? 0.0f : (float) Math.log2(lambda));
            objectOutput.writeInt(1);
            char[] observedCharacters = this.mLM.mTrieCharSeqCounter.observedCharacters();
            LinkedList linkedList = new LinkedList();
            for (char c : observedCharacters) {
                linkedList.add(new char[]{c});
            }
            int i = 1;
            while (!linkedList.isEmpty()) {
                char[] cArr = (char[]) linkedList.removeFirst();
                objectOutput.writeChar(cArr[cArr.length - 1]);
                objectOutput.writeFloat((float) this.mLM.log2ConditionalEstimate(cArr, 0, cArr.length));
                if (i <= lastInternalNodeIndex) {
                    objectOutput.writeFloat((float) Math.log2(1.0d - this.mLM.lambda(cArr, 0, cArr.length)));
                    objectOutput.writeInt(i + linkedList.size() + 1);
                }
                for (char c2 : this.mLM.mTrieCharSeqCounter.charactersFollowing(cArr, 0, cArr.length)) {
                    linkedList.add(Arrays.concatenate(cArr, c2));
                }
                i++;
            }
        }
    }

    /* loaded from: input_file:lib/palladian.jar:com/aliasi/lm/NGramProcessLM$Serializer.class */
    /**
     * Serialization stand-in for the enclosing model: {@code writeReplace} on
     * the model substitutes an instance of this class, and {@code readResolve}
     * here hands back the model that was read.
     */
    static class Serializer implements Externalizable {
        static final long serialVersionUID = -7101238964823109652L;
        NGramProcessLM mLM;

        /** No-argument constructor required for {@link Externalizable}. */
        public Serializer() {
        }

        /**
         * @param nGramProcessLM the model to write out
         */
        public Serializer(NGramProcessLM nGramProcessLM) {
            mLM = nGramProcessLM;
        }

        /**
         * Writes the model through {@code NGramProcessLM.writeTo(OutputStream)}.
         * The output object is cast to {@link OutputStream}, so it must be one.
         */
        @Override // java.io.Externalizable
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            final OutputStream rawOut = (OutputStream) objectOutput;
            mLM.writeTo(rawOut);
        }

        /**
         * Reads the model through {@code NGramProcessLM.readFrom(InputStream)}.
         * The input object is cast to {@link InputStream}, so it must be one.
         */
        @Override // java.io.Externalizable
        public void readExternal(ObjectInput objectInput) throws IOException, ClassNotFoundException {
            final InputStream rawIn = (InputStream) objectInput;
            mLM = NGramProcessLM.readFrom(rawIn);
        }

        /**
         * @return the model held by this instance
         */
        public Object readResolve() {
            return mLM;
        }
    }

    /**
     * Creates a model with {@code Character.MAX_VALUE} (65535) as the number
     * of characters, delegating to the two-argument constructor.
     *
     * @param i n-gram length forwarded down the constructor chain
     */
    public NGramProcessLM(int i) {
        // Same value as before, without depending on an unrelated imaging-library constant.
        this(i, Character.MAX_VALUE);
    }

    /**
     * Creates a model whose lambda factor equals the n-gram length argument.
     *
     * @param i  n-gram length; also used, widened to {@code double}, as the lambda factor
     * @param i2 number of characters
     */
    public NGramProcessLM(int i, int i2) {
        this(i, i2, (double) i);
    }

    /**
     * Creates a model backed by a fresh {@code TrieCharSeqCounter} built from
     * {@code i}.
     *
     * @param i  value passed to the {@code TrieCharSeqCounter} constructor
     * @param i2 number of characters
     * @param d  lambda factor
     */
    public NGramProcessLM(int i, int i2, double d) {
        this(i2, d, new TrieCharSeqCounter(i));
    }

    /**
     * Creates a model over an existing counter. The maximum n-gram length is
     * taken from the counter's {@code mMaxLength}.
     *
     * @param i                  number of characters, validated by {@code setNumChars}
     * @param d                  lambda factor, validated by {@code setLambdaFactor}
     * @param trieCharSeqCounter counter that backs the model
     * @throws IllegalArgumentException if either setter rejects its argument
     */
    public NGramProcessLM(int i, double d, TrieCharSeqCounter trieCharSeqCounter) {
        mTrieCharSeqCounter = trieCharSeqCounter;
        mMaxNGram = trieCharSeqCounter.mMaxLength;
        setLambdaFactor(d);
        setNumChars(i);
    }

    /**
     * Writes this model to the stream through a {@code BitOutput} wrapper and
     * flushes the wrapper afterwards.
     *
     * @param outputStream destination stream
     * @throws IOException if writing or flushing fails
     */
    public void writeTo(OutputStream outputStream) throws IOException {
        final BitOutput bits = new BitOutput(outputStream);
        this.writeTo(bits);
        bits.flush();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Writes, in order: the maximum n-gram, the number of characters, the
     * lambda factor scaled by one million and truncated to {@code int}, and
     * then the counter itself.
     *
     * @param bitOutput destination of the encoded model
     * @throws IOException if writing fails
     */
    public void writeTo(BitOutput bitOutput) throws IOException {
        final int scaledLambdaFactor = (int) (mLambdaFactor * 1000000.0d);
        bitOutput.writeDelta(mMaxNGram);
        bitOutput.writeDelta(mNumChars);
        bitOutput.writeDelta(scaledLambdaFactor);
        final BitTrieWriter trieWriter = new BitTrieWriter(bitOutput);
        TrieCharSeqCounter.writeCounter(mTrieCharSeqCounter, trieWriter, mMaxNGram);
    }

    /**
     * Reads a model from the stream by wrapping it in a {@code BitInput}.
     *
     * @param inputStream source stream
     * @return the model that was read
     * @throws IOException if reading fails
     */
    public static NGramProcessLM readFrom(InputStream inputStream) throws IOException {
        final BitInput bits = new BitInput(inputStream);
        return readFrom(bits);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Reads a model in the layout produced by {@code writeTo(BitOutput)}:
     * maximum n-gram, number of characters, lambda factor scaled by one
     * million, then the counter.
     *
     * <p>The three header values are read into locals first, in exactly the
     * order they were written. Reading them inline as constructor arguments
     * assigned them to the wrong parameters and pulled the last one from the
     * stream only after the trie reader had been constructed.
     *
     * @param bitInput source of the encoded model
     * @return the model that was read
     * @throws IOException if reading fails
     */
    public static NGramProcessLM readFrom(BitInput bitInput) throws IOException {
        final int maxNGram = (int) bitInput.readDelta();
        final int numChars = (int) bitInput.readDelta();
        final double lambdaFactor = bitInput.readDelta() / 1000000.0d;
        final BitTrieReader trieReader = new BitTrieReader(bitInput);
        final TrieCharSeqCounter counter = TrieCharSeqCounter.readCounter(trieReader, maxNGram);
        return new NGramProcessLM(numChars, lambdaFactor, counter);
    }

    /**
     * Returns the same value as {@code log2Estimate(CharSequence)}.
     *
     * @param charSequence sequence to score
     * @return base-2 log estimate of the sequence
     */
    @Override // com.aliasi.stats.Model
    public double log2Prob(CharSequence charSequence) {
        final double log2 = log2Estimate(charSequence);
        return log2;
    }

    /**
     * Returns two raised to the power of {@code log2Estimate(charSequence)}.
     *
     * @param charSequence sequence to score
     * @return the estimate on the linear scale
     */
    @Override // com.aliasi.stats.Model
    public double prob(CharSequence charSequence) {
        // This file imports com.aliasi.util.Math, which shadows java.lang.Math for the
        // simple name "Math"; qualify the call so it is bound to the JDK's pow.
        return java.lang.Math.pow(2.0d, log2Estimate(charSequence));
    }

    /**
     * Converts the sequence to a character array and scores the whole array.
     *
     * @param charSequence sequence to score
     * @return base-2 log estimate of the sequence
     */
    @Override // com.aliasi.lm.LanguageModel
    public final double log2Estimate(CharSequence charSequence) {
        final char[] chars = Strings.toCharArray(charSequence);
        final int length = chars.length;
        return log2Estimate(chars, 0, length);
    }

    /**
     * Sums the conditional log estimates of every prefix of the slice: for
     * each end position from {@code i + 1} through {@code i2}, adds
     * {@code log2ConditionalEstimate(cArr, i, end)}.
     *
     * @param cArr underlying characters
     * @param i    start index of the slice
     * @param i2   end index of the slice
     * @return sum of the conditional log estimates; {@code 0.0} for an empty slice
     */
    @Override // com.aliasi.lm.LanguageModel
    public final double log2Estimate(char[] cArr, int i, int i2) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        double total = 0.0d;
        int prefixEnd = i;
        while (prefixEnd < i2) {
            prefixEnd++;
            total += log2ConditionalEstimate(cArr, i, prefixEnd);
        }
        return total;
    }

    /**
     * Trains on the sequence with a count of one.
     *
     * @param charSequence training text
     */
    @Override // com.aliasi.lm.LanguageModel.Dynamic
    public void train(CharSequence charSequence) {
        final int singleCount = 1;
        train(charSequence, singleCount);
    }

    /**
     * Converts the sequence to a character array and trains on all of it.
     *
     * @param charSequence training text
     * @param i            count passed on to the array-based overload
     */
    @Override // com.aliasi.lm.LanguageModel.Dynamic
    public void train(CharSequence charSequence, int i) {
        final char[] chars = Strings.toCharArray(charSequence);
        final int length = chars.length;
        train(chars, 0, length, i);
    }

    /**
     * Trains on the slice with a count of one.
     *
     * @param cArr underlying characters
     * @param i    start index of the slice
     * @param i2   end index of the slice
     */
    @Override // com.aliasi.lm.LanguageModel.Dynamic
    public void train(char[] cArr, int i, int i2) {
        final int singleCount = 1;
        this.train(cArr, i, i2, singleCount);
    }

    /**
     * Checks the slice bounds and then has the counter increment the
     * substrings of the slice by the given count.
     *
     * @param cArr underlying characters
     * @param i    start index of the slice
     * @param i2   end index of the slice
     * @param i3   count passed to {@code incrementSubstrings}
     */
    @Override // com.aliasi.lm.LanguageModel.Dynamic
    public void train(char[] cArr, int i, int i2, int i3) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        final TrieCharSeqCounter counter = mTrieCharSeqCounter;
        counter.incrementSubstrings(cArr, i, i2, i3);
    }

    /**
     * Trains on the characters from index {@code i} up to {@code i + i2}; the
     * third argument is a length, not an end index.
     *
     * @param cArr underlying characters
     * @param i    start index
     * @param i2   number of characters
     */
    @Override // com.aliasi.corpus.TextHandler
    @Deprecated
    public void handle(char[] cArr, int i, int i2) {
        final int end = i + i2;
        train(cArr, i, end);
    }

    /**
     * Trains on the sequence; equivalent to {@code train(charSequence)}.
     *
     * @param charSequence training text
     */
    @Override // com.aliasi.corpus.ObjectHandler
    public void handle(CharSequence charSequence) {
        this.train(charSequence);
    }

    /**
     * Increments the counter's substrings for the slice {@code [i, i2)} and
     * then decrements them for the slice {@code [i, i3)}. Does nothing when
     * {@code i3 == i2}.
     *
     * @param cArr underlying characters
     * @param i    start index
     * @param i2   end index of the full slice
     * @param i3   end index of the conditioning slice
     * @throws IllegalArgumentException if {@code i3 > i2}
     */
    public void trainConditional(char[] cArr, int i, int i2, int i3) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        Strings.checkArgsStartEnd(cArr, i, i3);
        if (i3 > i2) {
            final String msg = "Conditional end must be < end. Found condEnd=" + i3 + " end=" + i2;
            throw new IllegalArgumentException(msg);
        }
        if (i3 != i2) {
            final TrieCharSeqCounter counter = mTrieCharSeqCounter;
            counter.incrementSubstrings(cArr, i, i2);
            counter.decrementSubstrings(cArr, i, i3);
        }
    }

    /**
     * @return the array returned by the counter's {@code observedCharacters()}
     */
    @Override // com.aliasi.lm.LanguageModel.Conditional
    public char[] observedCharacters() {
        final char[] observed = mTrieCharSeqCounter.observedCharacters();
        return observed;
    }

    /**
     * Writes an {@code Externalizer} wrapping this model to the output.
     *
     * @param objectOutput destination of the compiled model
     * @throws IOException if writing fails
     */
    @Override // com.aliasi.util.Compilable
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        final Externalizer externalizer = new Externalizer(this);
        objectOutput.writeObject(externalizer);
    }

    /**
     * Conditional estimate using this model's own maximum n-gram and lambda
     * factor.
     *
     * @param charSequence context followed by the character being estimated
     * @return base-2 log of the conditional estimate
     */
    @Override // com.aliasi.lm.LanguageModel.Conditional
    public double log2ConditionalEstimate(CharSequence charSequence) {
        final int maxNGram = mMaxNGram;
        final double lambdaFactor = mLambdaFactor;
        return log2ConditionalEstimate(charSequence, maxNGram, lambdaFactor);
    }

    /**
     * Conditional estimate over a slice using this model's own maximum n-gram
     * and lambda factor.
     *
     * @param cArr underlying characters
     * @param i    start index of the slice
     * @param i2   end index of the slice
     * @return base-2 log of the conditional estimate
     */
    @Override // com.aliasi.lm.LanguageModel.Conditional
    public double log2ConditionalEstimate(char[] cArr, int i, int i2) {
        final int maxNGram = mMaxNGram;
        final double lambdaFactor = mLambdaFactor;
        return log2ConditionalEstimate(cArr, i, i2, maxNGram, lambdaFactor);
    }

    /**
     * @return the counter backing this model (the same instance, not a copy)
     */
    public TrieCharSeqCounter substringCounter() {
        return mTrieCharSeqCounter;
    }

    /**
     * @return the maximum n-gram length fixed at construction
     */
    public int maxNGram() {
        return mMaxNGram;
    }

    /**
     * Converts the sequence to a character array and computes the conditional
     * estimate over all of it with the given n-gram limit and lambda factor.
     *
     * @param charSequence context followed by the character being estimated
     * @param i            maximum n-gram to use
     * @param d            lambda factor to use
     * @return base-2 log of the conditional estimate
     */
    public double log2ConditionalEstimate(CharSequence charSequence, int i, double d) {
        final char[] chars = Strings.toCharArray(charSequence);
        final int length = chars.length;
        return log2ConditionalEstimate(chars, 0, length, i, d);
    }

    /**
     * Estimates the last character of the slice given the characters before
     * it. The estimate starts at the uniform estimate and is interpolated with
     * the ratio {@code count / extensionCount} for successively longer
     * contexts, stopping at the first context with no extensions or at the
     * n-gram limit.
     *
     * @param cArr underlying characters
     * @param i    start index of the slice
     * @param i2   end index of the slice; must be greater than {@code i}
     * @param i3   maximum n-gram to use; capped at this model's maximum
     * @param d    lambda factor for the interpolation weights
     * @return base-2 log of the interpolated estimate
     * @throws IllegalArgumentException if the slice is empty, or if
     *     {@code checkMaxNGram} or {@code checkLambdaFactor} rejects an argument
     */
    public double log2ConditionalEstimate(char[] cArr, int i, int i2, int i3, double d) {
        if (i2 <= i) {
            throw new IllegalArgumentException("Conditional estimates require at least one character.");
        }
        Strings.checkArgsStartEnd(cArr, i, i2);
        checkMaxNGram(i3);
        checkLambdaFactor(d);
        // "Math" in this file is com.aliasi.util.Math; the JDK helpers are qualified.
        final int maxUsableNGram = java.lang.Math.min(i3, this.mMaxNGram);
        final int contextEnd = i2 - 1;
        final int longestContextStart = java.lang.Math.max(i, i2 - maxUsableNGram);

        double estimate = this.mUniformEstimate;
        for (int contextStart = contextEnd; contextStart >= longestContextStart; contextStart--) {
            final long contextCount = this.mTrieCharSeqCounter.extensionCount(cArr, contextStart, contextEnd);
            if (contextCount == 0) {
                break;
            }
            final long outcomeCount = this.mTrieCharSeqCounter.count(cArr, contextStart, i2);
            final double lambda = lambda(cArr, contextStart, contextEnd, d);
            // Divide as doubles: long/long division truncates the ratio to 0
            // whenever outcomeCount < contextCount.
            final double ratio = ((double) outcomeCount) / ((double) contextCount);
            estimate = (lambda * ratio) + ((1.0d - lambda) * estimate);
        }
        return com.aliasi.util.Math.log2(estimate);
    }

    /**
     * Interpolation weight for the context slice using this model's current
     * lambda factor.
     *
     * @param cArr underlying characters
     * @param i    start index of the context
     * @param i2   end index of the context
     * @return the interpolation weight
     */
    double lambda(char[] cArr, int i, int i2) {
        final double lambdaFactor = getLambdaFactor();
        return lambda(cArr, i, i2, lambdaFactor);
    }

    /**
     * Interpolation weight for the context slice. Returns {@code 0.0} when the
     * counter reports no extensions of the context; otherwise combines the
     * extension count with the number of distinct following characters and
     * the lambda factor.
     *
     * @param cArr underlying characters
     * @param i    start index of the context
     * @param i2   end index of the context
     * @param d    lambda factor
     * @return the interpolation weight
     * @throws IllegalArgumentException if {@code checkLambdaFactor} rejects {@code d}
     */
    double lambda(char[] cArr, int i, int i2, double d) {
        checkLambdaFactor(d);
        Strings.checkArgsStartEnd(cArr, i, i2);
        final double extensionCount = this.mTrieCharSeqCounter.extensionCount(cArr, i, i2);
        // Plain 0.0 replaces a zero-valued constant borrowed from a Swing label class.
        if (extensionCount <= 0.0d) {
            return 0.0d;
        }
        return lambda(extensionCount, this.mTrieCharSeqCounter.numCharactersFollowing(cArr, i, i2), d);
    }

    /**
     * @return the current lambda factor
     */
    public double getLambdaFactor() {
        return mLambdaFactor;
    }

    /**
     * Validates and stores the lambda factor.
     *
     * @param d new lambda factor
     * @throws IllegalArgumentException if {@code checkLambdaFactor} rejects {@code d}
     */
    public final void setLambdaFactor(double d) {
        checkLambdaFactor(d);
        mLambdaFactor = d;
    }

    /**
     * Validates and stores the number of characters, then refreshes the
     * uniform estimate ({@code 1.0 / i}) and its base-2 logarithm.
     *
     * @param i new number of characters
     * @throws IllegalArgumentException if {@code checkNumChars} rejects {@code i}
     */
    public final void setNumChars(int i) {
        checkNumChars(i);
        final double uniform = 1.0d / i;
        mNumChars = i;
        mUniformEstimate = uniform;
        mLog2UniformEstimate = Math.log2(uniform);
    }

    /**
     * @return the text produced by {@code toStringBuilder}
     */
    public String toString() {
        final StringBuilder out = new StringBuilder();
        toStringBuilder(out);
        return out.toString();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Appends the maximum n-gram, the number of characters, and the counter's
     * own rendering to the builder.
     *
     * @param sb builder to append to
     */
    public void toStringBuilder(StringBuilder sb) {
        // Chained appends avoid building throwaway strings; the newline is a plain
        // literal rather than a constant from an unrelated file-helper library.
        sb.append("Max NGram=").append(this.mMaxNGram).append(Strings.SINGLE_SPACE_STRING);
        sb.append("Num characters=").append(this.mNumChars).append("\n");
        sb.append("Trie of counts=\n");
        this.mTrieCharSeqCounter.toStringBuilder(sb);
    }

    /**
     * Decrements the unigram count of the character by one.
     *
     * @param c character whose count is decremented
     */
    void decrementUnigram(char c) {
        final int singleCount = 1;
        decrementUnigram(c, singleCount);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Forwards to the counter's {@code decrementUnigram}.
     *
     * @param c character whose count is decremented
     * @param i amount passed to the counter
     */
    public void decrementUnigram(char c, int i) {
        final TrieCharSeqCounter counter = mTrieCharSeqCounter;
        counter.decrementUnigram(c, i);
    }

    /**
     * Computes {@code d / (d + d3 * d2)}.
     *
     * @param d  count in the numerator and denominator
     * @param d2 value multiplied by the factor
     * @param d3 factor
     * @return the quotient; NaN when numerator and denominator are both zero
     */
    private double lambda(double d, double d2, double d3) {
        final double scaled = d3 * d2;
        final double denominator = d + scaled;
        return d / denominator;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Interpolation weight for a node, computed from the node's context count
     * and number of outcomes for the empty sequence together with this
     * model's lambda factor.
     *
     * @param node node to evaluate
     * @return the interpolation weight
     */
    public double lambda(Node node) {
        final double contextCount = node.contextCount(Strings.EMPTY_CHAR_ARRAY, 0, 0);
        final double outcomes = node.numOutcomes(Strings.EMPTY_CHAR_ARRAY, 0, 0);
        return lambda(contextCount, outcomes, mLambdaFactor);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Walks the counter's nodes breadth-first from the root, numbering them
     * from one, and remembers the number of the last node that reports at
     * least one outcome for the empty sequence.
     *
     * @return that number minus one; {@code 0} when no node has outcomes
     */
    public int lastInternalNodeIndex() {
        final LinkedList<Node> queue = new LinkedList<>();
        queue.add(mTrieCharSeqCounter.mRootNode);
        int lastWithOutcomes = 1;
        for (int position = 1; !queue.isEmpty(); position++) {
            final Node current = queue.removeFirst();
            if (current.numOutcomes(Strings.EMPTY_CHAR_ARRAY, 0, 0) > 0) {
                lastWithOutcomes = position;
            }
            current.addDaughters(queue);
        }
        return lastWithOutcomes - 1;
    }

    /**
     * Serialization hook: substitutes a {@code Serializer} wrapping this model.
     *
     * @return the replacement object to serialize
     */
    private Object writeReplace() {
        final Serializer replacement = new Serializer(this);
        return replacement;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Rejects lambda factors that are negative, infinite, or NaN.
     *
     * @param d lambda factor to validate
     * @throws IllegalArgumentException if {@code d} is negative, infinite, or NaN
     */
    public static void checkLambdaFactor(double d) {
        // Compare against a plain 0.0 instead of a zero-valued Swing label constant.
        if (d < 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new IllegalArgumentException("Lambda factor must be ordinary non-negative double. Found lambdaFactor=" + d);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Rejects maximum n-gram values below one.
     *
     * @param i maximum n-gram to validate
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public static void checkMaxNGram(int i) {
        if (i >= 1) {
            return;
        }
        final String msg = "Maximum n-gram must be greater than zero. Found max n-gram=" + i;
        throw new IllegalArgumentException(msg);
    }

    /**
     * Rejects character counts outside {@code 1..Character.MAX_VALUE}.
     *
     * <p>Zero is rejected as well: the message already states the value must
     * be {@code > 0}, and {@code setNumChars} divides by it to obtain the
     * uniform estimate.
     *
     * @param i number of characters to validate
     * @throws IllegalArgumentException if {@code i < 1} or {@code i > Character.MAX_VALUE}
     */
    private static void checkNumChars(int i) {
        if (i < 1 || i > Character.MAX_VALUE) {
            throw new IllegalArgumentException("Number of characters must be > 0 and  must be less than Character.MAX_VALUE Found numChars=" + i);
        }
    }
}
