package com.aliasi.chunk;

import com.aliasi.symbol.SymbolTableCompiler;
import com.aliasi.tokenizer.TokenCategorizer;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.TreeSet;
import org.jdesktop.swingx.JXLabel;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:lib/palladian.jar:com/aliasi/chunk/TrainableEstimator.class */
/**
 * Accumulates counts for a chunker's tag and token models from tagged token
 * sequences and compiles them into a form that is read back as a
 * {@link CompiledEstimator}.
 *
 * <p>Two count trees are maintained:
 * <ul>
 * <li>the <i>tag</i> tree predicts a tag from, with increasing context, the
 *     previous (inner) tag, the previous token, and the token before that;</li>
 * <li>the <i>token</i> tree predicts a token from its tag, the previous
 *     (inner) tag, and the previous token.</li>
 * </ul>
 *
 * <p>Instances are not synchronized.
 */
public final class TrainableEstimator implements Compilable {

    /** Tag used for the sentence-boundary pseudo-token; also the first tag symbol registered. */
    private static final String OUT_TAG = "O";

    /** Pseudo-token used as context before the first and after the last token of a sentence. */
    private static final String BOUNDARY_TOKEN = ".";

    private final Node mRootTagNode;
    private final Node mRootTokenNode;
    private final SymbolTableCompiler mTokenSymbolTable;
    private final SymbolTableCompiler mTagSymbolTable;
    private double mLambdaFactor;
    private double mLogUniformVocabEstimate;
    private final TokenCategorizer mTokenCategorizer;

    /**
     * Serialization proxy. Writing stores the compiled estimator; reading
     * produces a {@link CompiledEstimator}, not a {@code TrainableEstimator}.
     */
    static class Externalizer extends AbstractExternalizable {
        private static final long serialVersionUID = 4179100933315980535L;
        final TrainableEstimator mEstimator;

        /** Public no-arg constructor, as required of {@code Externalizable} classes. */
        public Externalizer() {
            this(null);
        }

        /**
         * @param trainableEstimator the estimator to compile when written
         */
        public Externalizer(TrainableEstimator trainableEstimator) {
            this.mEstimator = trainableEstimator;
        }

        @Override
        public Object read(ObjectInput objectInput) throws ClassNotFoundException, IOException {
            return new CompiledEstimator(objectInput);
        }

        /**
         * Writes, in order: the token categorizer, the tag and token symbol
         * tables, the tag and token estimator trees, and the log uniform
         * vocabulary estimate.
         */
        @Override
        public void writeExternal(ObjectOutput objectOutput) throws IOException {
            TrainableEstimator estimator = this.mEstimator;
            AbstractExternalizable.compileOrSerialize(estimator.mTokenCategorizer, objectOutput);
            // Register every symbol before the tables are compiled.
            estimator.generateSymbols();
            estimator.mTagSymbolTable.compileTo(objectOutput);
            estimator.mTokenSymbolTable.compileTo(objectOutput);
            estimator.writeEstimator(estimator.mRootTagNode, objectOutput);
            estimator.writeEstimator(estimator.mRootTokenNode, objectOutput);
            objectOutput.writeDouble(estimator.mLogUniformVocabEstimate);
        }
    }

    /**
     * Creates an estimator with explicit smoothing parameters.
     *
     * @param lambdaFactor multiplier handed to {@code Node.compileEstimates} at compile time
     * @param logUniformVocabEstimate log estimate written into the compiled model
     * @param tokenCategorizer categorizer whose categories are added as token symbols
     */
    public TrainableEstimator(double lambdaFactor, double logUniformVocabEstimate,
                              TokenCategorizer tokenCategorizer) {
        this.mTokenSymbolTable = new SymbolTableCompiler();
        this.mTagSymbolTable = new SymbolTableCompiler();
        this.mLambdaFactor = lambdaFactor;
        this.mLogUniformVocabEstimate = logUniformVocabEstimate;
        this.mTokenCategorizer = tokenCategorizer;
        this.mRootTagNode = new Node(null, this.mTagSymbolTable, null);
        this.mRootTokenNode = new Node(null, this.mTokenSymbolTable, null);
        this.mTagSymbolTable.addSymbol(OUT_TAG);
    }

    /**
     * Creates an estimator with lambda factor 4.0 and log uniform
     * vocabulary estimate {@code ln(1e-6)}.
     *
     * @param tokenCategorizer categorizer whose categories are added as token symbols
     */
    public TrainableEstimator(TokenCategorizer tokenCategorizer) {
        this(4.0d, Math.log(1.0E-6d), tokenCategorizer);
    }

    /**
     * Sets the lambda factor used when compiling estimates.
     *
     * @param d the new factor; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is not finite or is {@code <= 0}
     */
    public void setLambdaFactor(double d) {
        // The contract is "> 0"; the previous check (d < 0) wrongly admitted 0.
        if (d <= 0.0 || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Lambda factor must be > 0. Was=" + d);
        }
        this.mLambdaFactor = d;
    }

    /**
     * Sets the log uniform vocabulary estimate written into the compiled model.
     *
     * @param d the new estimate; must be finite and strictly negative
     * @throws IllegalArgumentException if {@code d} is not finite or is {@code >= 0}
     */
    public void setLogUniformVocabularyEstimate(double d) {
        if (d >= 0.0 || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Log vocab estimate must be < 0. Was=" + d);
        }
        this.mLogUniformVocabEstimate = d;
    }

    /**
     * Trains both models on one tagged sentence. The sentence is padded with
     * the boundary token {@code "."} (tagged {@code "O"}) on both sides, and a
     * final transition into that boundary is trained as well.
     *
     * @param tokens the tokens of the sentence
     * @param tags the tag for each token; must be the same length as {@code tokens}
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public void handle(String[] tokens, String[] tags) {
        if (tokens.length != tags.length) {
            throw new IllegalArgumentException(
                "Tokens and tags must be the same length."
                + " Found tokens.length=" + tokens.length
                + " tags.length=" + tags.length);
        }
        int n = tokens.length;
        if (n == 0) {
            return;
        }
        trainOutcome(tokens[0], tags[0], OUT_TAG, BOUNDARY_TOKEN, BOUNDARY_TOKEN);
        if (n == 1) {
            trainOutcome(BOUNDARY_TOKEN, OUT_TAG, tags[0], tokens[0], BOUNDARY_TOKEN);
            return;
        }
        trainOutcome(tokens[1], tags[1], tags[0], tokens[0], BOUNDARY_TOKEN);
        for (int i = 2; i < n; i++) {
            trainOutcome(tokens[i], tags[i], tags[i - 1], tokens[i - 1], tokens[i - 2]);
        }
        trainOutcome(BOUNDARY_TOKEN, OUT_TAG, tags[n - 1], tokens[n - 1], tokens[n - 2]);
    }

    /** Writes a serialization proxy that compiles this estimator. */
    @Override
    public void compileTo(ObjectOutput objectOutput) throws IOException {
        objectOutput.writeObject(new Externalizer(this));
    }

    /**
     * Registers the token and tag symbols, then trains the token and tag
     * models for one position. The previous tag is converted with
     * {@code Tags.toInnerTag} before being used as context.
     *
     * @param token the token at this position
     * @param tag the tag at this position
     * @param prevTag the previous tag, or {@code null}
     * @param prevToken the previous token
     * @param prevPrevToken the token two positions back
     */
    public void trainOutcome(String token, String tag, String prevTag,
                             String prevToken, String prevPrevToken) {
        this.mTagSymbolTable.addSymbol(tag);
        this.mTokenSymbolTable.addSymbol(token);
        String innerPrevTag = prevTag == null ? null : Tags.toInnerTag(prevTag);
        trainTokenModel(token, tag, innerPrevTag, prevToken);
        trainTagModel(tag, innerPrevTag, prevToken, prevPrevToken);
    }

    /**
     * Adds symbols held by the count trees and all token-categorizer
     * categories to the symbol tables. Kept public for compatibility; within
     * this class only the nested {@code Externalizer} calls it.
     */
    public void generateSymbols() {
        this.mRootTagNode.generateSymbols();
        this.mRootTokenNode.generateSymbols();
        for (String category : this.mTokenCategorizer.categories()) {
            this.mTokenSymbolTable.addSymbol(category);
        }
    }

    /**
     * Increments the token count under contexts tag, tag+prevTag and
     * tag+prevTag+prevToken. Training stops at the first {@code null} context.
     * Does nothing if {@code token} or {@code tag} is {@code null}.
     */
    public void trainTokenModel(String token, String tag, String prevTag, String prevToken) {
        if (tag == null || token == null) {
            return;
        }
        Node tagNode = this.mRootTokenNode.getOrCreateChild(tag, null, this.mTagSymbolTable);
        tagNode.incrementOutcome(token, this.mTokenSymbolTable);
        if (prevTag == null) {
            return;
        }
        // The parent is passed as the second argument (presumably the backoff node).
        Node prevTagNode = tagNode.getOrCreateChild(prevTag, tagNode, this.mTagSymbolTable);
        prevTagNode.incrementOutcome(token, this.mTokenSymbolTable);
        if (prevToken == null) {
            return;
        }
        prevTagNode.getOrCreateChild(prevToken, prevTagNode, this.mTokenSymbolTable)
            .incrementOutcome(token, this.mTokenSymbolTable);
    }

    /**
     * Increments the tag count under contexts prevTag, prevTag+prevToken and
     * prevTag+prevToken+prevPrevToken. Training stops at the first
     * {@code null} context. Does nothing if {@code tag} or {@code prevTag}
     * is {@code null}.
     */
    public void trainTagModel(String tag, String prevTag, String prevToken, String prevPrevToken) {
        if (tag == null || prevTag == null) {
            return;
        }
        Node prevTagNode = this.mRootTagNode.getOrCreateChild(prevTag, null, this.mTagSymbolTable);
        prevTagNode.incrementOutcome(tag, this.mTagSymbolTable);
        if (prevToken == null) {
            return;
        }
        Node prevTokenNode = prevTagNode.getOrCreateChild(prevToken, prevTagNode, this.mTokenSymbolTable);
        prevTokenNode.incrementOutcome(tag, this.mTagSymbolTable);
        if (prevPrevToken == null) {
            return;
        }
        prevTokenNode.getOrCreateChild(prevPrevToken, prevTokenNode, this.mTokenSymbolTable)
            .incrementOutcome(tag, this.mTagSymbolTable);
    }

    /** Trains the token model on a token given only its tag. */
    public void trainTokenOutcome(String token, String tag) {
        trainTokenModel(token, tag, null, null);
    }

    /** @return the number of nodes in the tag tree */
    public int numTagNodes() {
        return this.mRootTagNode.numNodes();
    }

    /** @return the number of outcome counters in the tag tree */
    public int numTagOutcomes() {
        return this.mRootTagNode.numCounters();
    }

    /** @return the number of nodes in the token tree */
    public int numTokenNodes() {
        return this.mRootTokenNode.numNodes();
    }

    /** @return the number of outcome counters in the token tree */
    public int numTokenOutcomes() {
        return this.mRootTokenNode.numCounters();
    }

    /**
     * Prunes both trees via {@code Node.prune}.
     *
     * @param tagThreshold threshold passed to the tag tree
     * @param tokenThreshold threshold passed to the token tree
     */
    public void prune(int tagThreshold, int tokenThreshold) {
        this.mRootTagNode.prune(tagThreshold);
        this.mRootTokenNode.prune(tokenThreshold);
    }

    /**
     * Adds {@code countMultiplier} observations of every (previous tag, tag)
     * pair that {@code Tags.illegalSequence} does not reject to the tag model.
     */
    public void smoothTags(int countMultiplier) {
        String[] tags = this.mTagSymbolTable.symbols();
        for (String prevTag : tags) {
            for (String tag : tags) {
                if (Tags.illegalSequence(prevTag, tag)) {
                    continue;
                }
                for (int k = 0; k < countMultiplier; k++) {
                    trainTagModel(tag, prevTag, null, null);
                }
            }
        }
    }

    /**
     * Compiles estimates for the tree rooted at {@code node}, then writes its
     * node count, node records, counter count and outcome records. Kept
     * public for compatibility; within this class only the nested
     * {@code Externalizer} calls it.
     */
    public void writeEstimator(Node node, ObjectOutput objectOutput) throws IOException {
        node.compileEstimates(this.mLambdaFactor);
        // One shared ordering guarantees that indices, child ranges and
        // outcome offsets all agree.
        List<Node> nodes = breadthFirstOrder(node);
        indexNodes(nodes);
        objectOutput.writeInt(node.numNodes());
        writeNodes(nodes, objectOutput);
        objectOutput.writeInt(node.numCounters());
        writeOutcomes(nodes, objectOutput);
    }

    /**
     * Lists the tree's nodes in breadth-first order, visiting each node's
     * children in sorted key order. This is the single order used for
     * indexing and for writing nodes and outcomes.
     */
    private static List<Node> breadthFirstOrder(Node root) {
        List<Node> order = new ArrayList<Node>();
        order.add(root);
        // The output list doubles as the FIFO queue.
        for (int i = 0; i < order.size(); i++) {
            Node current = order.get(i);
            for (String key : new TreeSet<String>(current.children())) {
                order.add(current.getChild(key));
            }
        }
        return order;
    }

    /** Assigns each node its position in {@code nodes} as its index. */
    private static void indexNodes(List<Node> nodes) {
        for (int i = 0; i < nodes.size(); i++) {
            nodes.get(i).setIndex(i);
        }
    }

    /**
     * Writes one record per node containing: symbol ID; offset of its first
     * outcome; index of its first child (for a leaf, the running next-child
     * index); {@code oneMinusLambda}; and backoff node index, or -1 if none.
     */
    private static void writeNodes(List<Node> nodes, ObjectOutput objectOutput) throws IOException {
        int outcomeOffset = 0;
        int nextChildIndex = 0;
        for (Node current : nodes) {
            objectOutput.writeInt(current.getSymbolID());
            objectOutput.writeInt(outcomeOffset);
            outcomeOffset += current.outcomes().size();
            TreeSet<String> childKeys = new TreeSet<String>(current.children());
            if (childKeys.isEmpty()) {
                objectOutput.writeInt(nextChildIndex);
            } else {
                int firstChildIndex = current.getChild(childKeys.first()).index();
                objectOutput.writeInt(firstChildIndex);
                nextChildIndex = firstChildIndex + current.children().size();
            }
            objectOutput.writeFloat(current.oneMinusLambda());
            Node backoff = current.backoffNode();
            objectOutput.writeInt(backoff == null ? -1 : backoff.index());
        }
    }

    /**
     * Writes (symbol ID, estimate) for every outcome of every node, in the
     * same node order that {@link #writeNodes} used when computing outcome
     * offsets.
     */
    private static void writeOutcomes(List<Node> nodes, ObjectOutput objectOutput) throws IOException {
        for (Node current : nodes) {
            for (String outcome : current.outcomes()) {
                OutcomeCounter counter = current.getOutcome(outcome);
                objectOutput.writeInt(counter.getSymbolID());
                objectOutput.writeFloat(counter.estimate());
            }
        }
    }
}
