package com.aliasi.chunk;

import com.aliasi.hmm.HmmDecoder;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.TagLattice;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.Math;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:lib/palladian.jar:com/aliasi/chunk/HmmChunker.class */
/**
 * Chunker that tokenizes its input with a {@link TokenizerFactory}, tags the
 * tokens with an {@link HmmDecoder}, and converts the resulting tag sequence
 * into chunks.
 *
 * <p>The decoder works over a position-aware tag set produced by
 * {@link #trainNormalize(String[])}: chunk tokens carry a two-character
 * prefix (the begin prefix {@link BioTagChunkCodec#BEGIN_TAG_PREFIX},
 * {@code "M_"}, {@code "E_"} or {@code "W_"}) followed by the chunk type,
 * and non-chunk tokens are tagged {@code "MM_O"}, {@code "EE_O_x"},
 * {@code "BB_O_x"} or {@code "WW_O_x"}.  Decoded tags are mapped back to
 * begin/in/out tags before a chunking is built.
 *
 * <p>All character offsets computed by this class are measured from the
 * start of the tokenized slice, not from the start of the underlying
 * character array.
 */
public class HmmChunker implements NBestChunker, ConfidenceChunker {
    private final TokenizerFactory mTokenizerFactory;
    private final HmmDecoder mDecoder;

    /**
     * Search state for a chunk that has been opened but not yet closed,
     * held on the queue of {@link NBestChunkIt}.
     */
    public static class ChunkItState implements Scored {
        /** Character offset at which the chunk under construction starts. */
        final int mStartCharPos;
        /** Index of the last token covered by this state. */
        final int mTokPos;
        /** Chunk type, i.e. the tag with its two-character prefix removed. */
        final String mTag;
        /** Accumulated forward score (log base 2) up to {@link #mTokPos}. */
        final double mForward;
        /** Backward score (log base 2) at {@link #mTokPos} for the current tag. */
        final double mBack;
        /** Queue priority: {@code mForward + mBack}. */
        final double mScore;
        /** Symbol id of the tag assigned at {@link #mTokPos}. */
        final int mCurrentTagId;
        /** Symbol id of the {@code "M_"} tag for this chunk type. */
        final int mMidTagId;
        /** Symbol id of the {@code "E_"} tag for this chunk type. */
        final int mEndTagId;

        /**
         * Builds a state; the score is the sum of the forward and backward
         * scores.
         *
         * @param i   character offset of the chunk start
         * @param i2  token position of this state
         * @param str chunk type
         * @param i3  symbol id of the tag at the token position
         * @param i4  symbol id of the matching mid tag
         * @param i5  symbol id of the matching end tag
         * @param d   forward score
         * @param d2  backward score
         */
        ChunkItState(int i, int i2, String str, int i3, int i4, int i5, double d, double d2) {
            this.mStartCharPos = i;
            this.mTokPos = i2;
            this.mTag = str;
            this.mCurrentTagId = i3;
            this.mMidTagId = i4;
            this.mEndTagId = i5;
            this.mForward = d;
            this.mBack = d2;
            this.mScore = d + d2;
        }

        /**
         * Returns the combined forward plus backward score of this state.
         *
         * @return the score used to order this state on the queue
         */
        @Override
        public double score() {
            return this.mScore;
        }
    }

    /**
     * Iterator over individual chunks drawn from a tag lattice.
     *
     * <p>The queue holds both finished {@link Chunk}s and open
     * {@link ChunkItState}s.  Polling a finished chunk yields it; polling an
     * open state replaces it with its one-token extensions (continue with the
     * mid tag, or close with the end tag).  Returned scores are the queued
     * score minus {@link #mTotal}.
     */
    private static class NBestChunkIt extends Iterators.Buffered<Chunk> {
        final TagLattice<String> mLattice;
        /** Whitespace strings surrounding the tokens. */
        final String[] mWhites;
        /** Upper bound on the number of chunks returned. */
        final int mMaxNBest;
        /** Character offset of the first character of each token. */
        final int[] mTokenStartIndexes;
        /** Character offset one past the last character of each token. */
        final int[] mTokenEndIndexes;
        /** Chunk types having a begin tag; parallel to the three id arrays below. */
        String[] mBeginTags;
        int[] mBeginTagIds;
        int[] mMidTagIds;
        int[] mEndTagIds;
        /** Chunk types having a {@code "W_"} tag; parallel to {@link #mWholeTagIds}. */
        String[] mWholeTags;
        int[] mWholeTagIds;
        final BoundedPriorityQueue<Scored> mQueue;
        final int mNumToks;
        /** Lattice normalizer {@code logZ()}, converted with {@code naturalLogToBase2Log}. */
        final double mTotal;
        /** Number of chunks returned so far. */
        int mCount = 0;

        /**
         * Computes token character offsets from the whitespace and token
         * lengths, then seeds the queue with every possible chunk start.
         *
         * @param lattice  tag lattice for the tokens
         * @param whites   whitespaces, indexed so that {@code whites[k]}
         *                 precedes token {@code k}
         * @param maxNBest bound on both the queue size and the number of
         *                 chunks returned
         */
        NBestChunkIt(TagLattice<String> lattice, String[] whites, int maxNBest) {
            this.mTotal = Math.naturalLogToBase2Log(lattice.logZ());
            this.mLattice = lattice;
            this.mWhites = whites;
            String[] toks = lattice.tokenList().toArray(Strings.EMPTY_STRING_ARRAY);
            this.mNumToks = toks.length;
            this.mTokenStartIndexes = new int[this.mNumToks];
            this.mTokenEndIndexes = new int[this.mNumToks];
            int charPos = 0;
            for (int k = 0; k < this.mNumToks; k++) {
                charPos += whites[k].length();
                this.mTokenStartIndexes[k] = charPos;
                charPos += toks[k].length();
                this.mTokenEndIndexes[k] = charPos;
            }
            this.mMaxNBest = maxNBest;
            this.mQueue = new BoundedPriorityQueue<>(ScoredObject.comparator(), maxNBest);
            initializeTags();
            initializeQueue();
        }

        /**
         * Scans the lattice's tag symbol table and records, for every begin
         * tag, its chunk type and the ids of the begin, mid and end tags, and
         * for every {@code "W_"} tag its chunk type and id.
         */
        void initializeTags() {
            SymbolTable tagTable = this.mLattice.tagSymbolTable();
            ArrayList<String> beginTypes = new ArrayList<>();
            ArrayList<Integer> beginIds = new ArrayList<>();
            ArrayList<Integer> midIds = new ArrayList<>();
            ArrayList<Integer> endIds = new ArrayList<>();
            ArrayList<String> wholeTypes = new ArrayList<>();
            ArrayList<Integer> wholeIds = new ArrayList<>();
            int numTags = tagTable.numSymbols();
            for (int id = 0; id < numTags; id++) {
                String tag = tagTable.idToSymbol(id);
                if (tag.startsWith(BioTagChunkCodec.BEGIN_TAG_PREFIX)) {
                    // Strip the two-character prefix to obtain the chunk type.
                    String type = tag.substring(2);
                    beginTypes.add(type);
                    beginIds.add(id);
                    midIds.add(tagTable.symbolToID("M_" + type));
                    endIds.add(tagTable.symbolToID("E_" + type));
                } else if (tag.startsWith("W_")) {
                    wholeTypes.add(tag.substring(2));
                    wholeIds.add(id);
                }
            }
            this.mBeginTags = HmmChunker.toStringArray(beginTypes);
            this.mBeginTagIds = HmmChunker.toIntArray(beginIds);
            this.mMidTagIds = HmmChunker.toIntArray(midIds);
            this.mEndTagIds = HmmChunker.toIntArray(endIds);
            this.mWholeTags = HmmChunker.toStringArray(wholeTypes);
            this.mWholeTagIds = HmmChunker.toIntArray(wholeIds);
        }

        /**
         * Offers, for every token position, one open state per begin tag and
         * one finished single-token chunk per whole tag.
         */
        void initializeQueue() {
            // The loop bound is the whitespace count minus one.
            // NOTE(review): presumably equal to the token count; confirm against the tokenizer contract.
            int numPositions = this.mWhites.length - 1;
            for (int tokPos = 0; tokPos < numPositions; tokPos++) {
                for (int j = 0; j < this.mBeginTagIds.length; j++) {
                    initializeBeginTag(tokPos, j);
                }
                for (int j = 0; j < this.mWholeTagIds.length; j++) {
                    initializeWholeTag(tokPos, j);
                }
            }
        }

        /**
         * Offers an open state that starts a chunk at the given token.
         *
         * @param tokPos   token position at which the chunk begins
         * @param tagIndex index into the begin/mid/end tag arrays
         */
        void initializeBeginTag(int tokPos, int tagIndex) {
            int startCharPos = this.mTokenStartIndexes[tokPos];
            String type = this.mBeginTags[tagIndex];
            int beginTagId = this.mBeginTagIds[tagIndex];
            int midTagId = this.mMidTagIds[tagIndex];
            int endTagId = this.mEndTagIds[tagIndex];
            double forward = Math.naturalLogToBase2Log(this.mLattice.logForward(tokPos, beginTagId));
            double backward = Math.naturalLogToBase2Log(this.mLattice.logBackward(tokPos, beginTagId));
            this.mQueue.offer(
                new ChunkItState(startCharPos, tokPos, type, beginTagId, midTagId, endTagId, forward, backward));
        }

        /**
         * Offers a finished chunk covering exactly one token, scored by the
         * lattice's {@code logProbability} for the whole tag at that token.
         *
         * @param tokPos   token position of the chunk
         * @param tagIndex index into the whole-tag arrays
         */
        void initializeWholeTag(int tokPos, int tagIndex) {
            double score =
                Math.naturalLogToBase2Log(this.mLattice.logProbability(tokPos, this.mWholeTagIds[tagIndex]));
            this.mQueue.offer(
                ChunkFactory.createChunk(
                    this.mTokenStartIndexes[tokPos],
                    this.mTokenEndIndexes[tokPos],
                    this.mWholeTags[tagIndex],
                    score));
        }

        /**
         * Produces the next chunk, expanding open states until a finished
         * chunk reaches the head of the queue.
         *
         * @return the next chunk with its score reduced by {@link #mTotal},
         *         or {@code null} once {@link #mMaxNBest} chunks have been
         *         returned or the queue is exhausted
         */
        @Override
        public Chunk bufferNext() {
            // mCount is the number already returned, so stop at equality;
            // a strict '>' test would hand out maxNBest + 1 chunks.
            if (this.mCount >= this.mMaxNBest) {
                return null;
            }
            while (!this.mQueue.isEmpty()) {
                Scored head = this.mQueue.poll();
                if (head instanceof Chunk) {
                    this.mCount++;
                    Chunk chunk = (Chunk) head;
                    return ChunkFactory.createChunk(
                        chunk.start(), chunk.end(), chunk.type(), chunk.score() - this.mTotal);
                }
                ChunkItState state = (ChunkItState) head;
                addNextMidState(state);
                addNextEndState(state);
            }
            return null;
        }

        /**
         * Offers the extension of an open state by one token tagged with the
         * mid tag.  Nothing is offered unless at least one further token
         * remains after the extension to carry the end tag.
         *
         * @param state open state to extend
         */
        void addNextMidState(ChunkItState state) {
            int nextTokPos = state.mTokPos + 1;
            if (nextTokPos + 1 >= this.mNumToks) {
                return;
            }
            int midTagId = state.mMidTagId;
            double forward = state.mForward
                + Math.naturalLogToBase2Log(
                    this.mLattice.logTransition(nextTokPos - 1, state.mCurrentTagId, midTagId));
            double backward = Math.naturalLogToBase2Log(this.mLattice.logBackward(nextTokPos, midTagId));
            this.mQueue.offer(
                new ChunkItState(
                    state.mStartCharPos,
                    nextTokPos,
                    state.mTag,
                    midTagId,
                    state.mMidTagId,
                    state.mEndTagId,
                    forward,
                    backward));
        }

        /**
         * Offers the finished chunk obtained by closing an open state on the
         * next token with the end tag.  Nothing is offered when no next token
         * exists.
         *
         * @param state open state to close
         */
        void addNextEndState(ChunkItState state) {
            int nextTokPos = state.mTokPos + 1;
            if (nextTokPos >= this.mNumToks) {
                return;
            }
            int endTagId = state.mEndTagId;
            double score = state.mForward
                + Math.naturalLogToBase2Log(
                    this.mLattice.logTransition(nextTokPos - 1, state.mCurrentTagId, endTagId))
                + Math.naturalLogToBase2Log(this.mLattice.logBackward(nextTokPos, endTagId));
            this.mQueue.offer(
                ChunkFactory.createChunk(
                    state.mStartCharPos, this.mTokenEndIndexes[nextTokPos], state.mTag, score));
        }
    }

    /**
     * Adapts an iterator over scored tag sequences into an iterator over
     * scored chunkings of a fixed token/whitespace sequence.
     */
    private static class NBestIt implements Iterator<ScoredObject<Chunking>> {
        final Iterator<ScoredObject<String[]>> mIt;
        final String[] mWhites;
        final String[] mToks;

        /**
         * @param it         scored tag sequences
         * @param toksWhites two-element array holding the tokens at index 0
         *                   and the whitespaces at index 1
         */
        NBestIt(Iterator<ScoredObject<String[]>> it, String[][] toksWhites) {
            this.mIt = it;
            this.mToks = toksWhites[0];
            this.mWhites = toksWhites[1];
        }

        @Override
        public boolean hasNext() {
            return this.mIt.hasNext();
        }

        /**
         * Returns the next chunking, carrying over the score of the
         * underlying tag sequence.  The tag array obtained from the wrapped
         * iterator is normalized in place.
         */
        @Override
        public ScoredObject<Chunking> next() {
            ScoredObject<String[]> scoredTags = this.mIt.next();
            double score = scoredTags.score();
            String[] tags = scoredTags.getObject();
            HmmChunker.decodeNormalize(tags);
            Chunking chunking = ChunkTagHandlerAdapter2.toChunkingBIO(this.mToks, this.mWhites, tags);
            return new ScoredObject<>(chunking, score);
        }

        /** Delegates removal to the wrapped iterator. */
        @Override
        public void remove() {
            this.mIt.remove();
        }
    }

    /**
     * Constructs a chunker from a tokenizer factory and a decoder.
     *
     * @param tokenizerFactory factory used to tokenize input
     * @param hmmDecoder       decoder used to tag tokens
     */
    public HmmChunker(TokenizerFactory tokenizerFactory, HmmDecoder hmmDecoder) {
        this.mTokenizerFactory = tokenizerFactory;
        this.mDecoder = hmmDecoder;
    }

    /**
     * Returns the decoder supplied at construction.
     *
     * @return the underlying decoder
     */
    public HmmDecoder getDecoder() {
        return this.mDecoder;
    }

    /**
     * Returns the tokenizer factory supplied at construction.
     *
     * @return the underlying tokenizer factory
     */
    public TokenizerFactory getTokenizerFactory() {
        return this.mTokenizerFactory;
    }

    /**
     * Returns the chunking built from the decoder's first-best tagging of
     * the specified character slice.
     *
     * @param cArr underlying characters
     * @param i    index of the first character of the slice
     * @param i2   index one past the last character of the slice
     * @return the first-best chunking
     */
    @Override
    public Chunking chunk(char[] cArr, int i, int i2) {
        // getToksWhites performs the same bounds check and tokenization this
        // method previously duplicated inline.
        String[][] toksWhites = getToksWhites(cArr, i, i2);
        String[] tokens = toksWhites[0];
        String[] whites = toksWhites[1];
        String[] tags = this.mDecoder.firstBest(tokens);
        decodeNormalize(tags);
        return ChunkTagHandlerAdapter2.toChunkingBIO(tokens, whites, tags);
    }

    /**
     * Returns the first-best chunking of the specified character sequence.
     *
     * @param charSequence text to chunk
     * @return the first-best chunking
     */
    @Override
    public Chunking chunk(CharSequence charSequence) {
        char[] chars = Strings.toCharArray(charSequence);
        return chunk(chars, 0, chars.length);
    }

    /**
     * Returns an iterator over scored chunkings derived from the decoder's
     * n-best tag sequences for the specified character slice.
     *
     * @param cArr underlying characters
     * @param i    index of the first character of the slice
     * @param i2   index one past the last character of the slice
     * @param i3   maximum n-best value passed to the decoder
     * @return iterator over scored chunkings
     * @throws IllegalArgumentException if {@code i3} is less than one
     */
    @Override
    public Iterator<ScoredObject<Chunking>> nBest(char[] cArr, int i, int i2, int i3) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        checkMaxNBest(i3);
        String[][] toksWhites = getToksWhites(cArr, i, i2);
        return new NBestIt(this.mDecoder.nBest(toksWhites[0], i3), toksWhites);
    }

    /**
     * Returns an iterator over scored chunkings derived from the decoder's
     * {@code nBestConditional} tag sequences for the specified slice.
     *
     * <p>The maximum n-best argument is validated but is not passed to the
     * decoder, so it does not bound the returned iterator here.
     *
     * @param cArr underlying characters
     * @param i    index of the first character of the slice
     * @param i2   index one past the last character of the slice
     * @param i3   maximum n-best value (validated only)
     * @return iterator over scored chunkings
     * @throws IllegalArgumentException if {@code i3} is less than one
     */
    public Iterator<ScoredObject<Chunking>> nBestConditional(char[] cArr, int i, int i2, int i3) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        checkMaxNBest(i3);
        String[][] toksWhites = getToksWhites(cArr, i, i2);
        return new NBestIt(this.mDecoder.nBestConditional(toksWhites[0]), toksWhites);
    }

    /**
     * Returns an iterator over at most the specified number of individual
     * chunks, computed from the decoder's tag marginal lattice.
     *
     * @param cArr underlying characters
     * @param i    index of the first character of the slice
     * @param i2   index one past the last character of the slice
     * @param i3   maximum number of chunks to return
     * @return iterator over scored chunks
     * @throws IllegalArgumentException if {@code i3} is less than one
     */
    @Override
    public Iterator<Chunk> nBestChunks(char[] cArr, int i, int i2, int i3) {
        // Validate before tokenizing and decoding, in the same order as
        // nBest and nBestConditional.
        Strings.checkArgsStartEnd(cArr, i, i2);
        checkMaxNBest(i3);
        String[][] toksWhites = getToksWhites(cArr, i, i2);
        TagLattice<String> lattice = this.mDecoder.tagMarginal(Arrays.asList(toksWhites[0]));
        return new NBestChunkIt(lattice, toksWhites[1], i3);
    }

    /**
     * Tokenizes the specified slice.
     *
     * @param cArr underlying characters
     * @param i    index of the first character of the slice
     * @param i2   index one past the last character of the slice
     * @return two-element array: tokens at index 0, whitespaces at index 1
     */
    String[][] getToksWhites(char[] cArr, int i, int i2) {
        Strings.checkArgsStartEnd(cArr, i, i2);
        Tokenizer tokenizer = this.mTokenizerFactory.tokenizer(cArr, i, i2 - i);
        ArrayList<String> tokens = new ArrayList<>();
        ArrayList<String> whites = new ArrayList<>();
        tokenizer.tokenize(tokens, whites);
        // Must be a two-dimensional array creation; a String[] initializer
        // cannot hold String[] elements.
        return new String[][] {toStringArray(tokens), toStringArray(whites)};
    }

    /**
     * Rejects non-positive n-best bounds.
     *
     * @param maxNBest bound to check
     * @throws IllegalArgumentException if {@code maxNBest} is less than one
     */
    private static void checkMaxNBest(int maxNBest) {
        if (maxNBest < 1) {
            throw new IllegalArgumentException(
                "Maximum n-best value must be greater than zero. Found maxNBest=" + maxNBest);
        }
    }

    /**
     * Copies a collection of strings into a new array.
     *
     * @param collection strings to copy
     * @return array of the collection's elements in iteration order
     */
    public static String[] toStringArray(Collection<String> collection) {
        return collection.toArray(Strings.EMPTY_STRING_ARRAY);
    }

    /**
     * Unboxes a collection of integers into a new array.
     *
     * @param collection integers to copy; must not contain {@code null}
     * @return array of the collection's values in iteration order
     */
    public static int[] toIntArray(Collection<Integer> collection) {
        int[] values = new int[collection.size()];
        int next = 0;
        for (Integer value : collection) {
            values[next++] = value.intValue();
        }
        return values;
    }

    /**
     * Returns the chunk type of a tag: out tags are returned unchanged,
     * any other tag has its first two characters removed.
     *
     * @param str tag
     * @return the tag's base type
     */
    public static String baseTag(String str) {
        return ChunkTagHandlerAdapter2.isOutTag(str) ? str : str.substring(2);
    }

    /**
     * Converts a begin/in/out tag sequence into the position-aware tags used
     * by the decoder.
     *
     * @param strArr tags to convert
     * @return a new array of converted tags, or the argument itself when it
     *         is empty
     * @throws IllegalArgumentException if a tag is not an out, begin or in tag
     */
    public static String[] trainNormalize(String[] strArr) {
        if (strArr.length == 0) {
            return strArr;
        }
        String[] normalized = new String[strArr.length];
        for (int k = 0; k < normalized.length; k++) {
            // Positions past either end are represented by the "W_BOS" sentinel.
            String prevTag = k - 1 >= 0 ? strArr[k - 1] : "W_BOS";
            String nextTag = k + 1 < strArr.length ? strArr[k + 1] : "W_BOS";
            normalized[k] = trainNormalize(prevTag, strArr[k], nextTag);
        }
        return normalized;
    }

    /**
     * Replaces every decoder tag in the array, in place, with its
     * begin/in/out equivalent.
     *
     * @param strArr tags to convert in place
     */
    public static void decodeNormalize(String[] strArr) {
        for (int k = 0; k < strArr.length; k++) {
            strArr[k] = decodeNormalize(strArr[k]);
        }
    }

    /**
     * Converts one begin/in/out tag to its position-aware form given its
     * neighbors.
     *
     * <ul>
     * <li>Out tag: {@code "MM_O"} when both neighbors are out;
     *     {@code "EE_O_" + next type} when only the previous is out;
     *     {@code "BB_O_" + previous type} when only the next is out;
     *     {@code "WW_O_" + next type} when neither is out.</li>
     * <li>Begin tag: begin prefix plus type when the next tag is an in tag,
     *     otherwise {@code "W_"} plus type.</li>
     * <li>In tag: {@code "M_"} plus type when the next tag is an in tag,
     *     otherwise {@code "E_"} plus type.</li>
     * </ul>
     *
     * @param str  previous tag
     * @param str2 tag to convert
     * @param str3 next tag
     * @return the converted tag
     * @throws IllegalArgumentException if {@code str2} is not an out, begin
     *         or in tag
     */
    static String trainNormalize(String str, String str2, String str3) {
        if (ChunkTagHandlerAdapter2.isOutTag(str2)) {
            boolean prevIsOut = ChunkTagHandlerAdapter2.isOutTag(str);
            boolean nextIsOut = ChunkTagHandlerAdapter2.isOutTag(str3);
            if (prevIsOut) {
                return nextIsOut ? "MM_O" : "EE_O_" + baseTag(str3);
            }
            return nextIsOut ? "BB_O_" + baseTag(str) : "WW_O_" + baseTag(str3);
        }
        if (ChunkTagHandlerAdapter2.isBeginTag(str2)) {
            if (ChunkTagHandlerAdapter2.isInTag(str3)) {
                return BioTagChunkCodec.BEGIN_TAG_PREFIX + baseTag(str2);
            }
            return "W_" + baseTag(str2);
        }
        if (ChunkTagHandlerAdapter2.isInTag(str2)) {
            if (ChunkTagHandlerAdapter2.isInTag(str3)) {
                return "M_" + baseTag(str2);
            }
            return "E_" + baseTag(str2);
        }
        throw new IllegalArgumentException("Unknown tag triple. prevTag=" + str + " tag=" + str2 + " nextTag=" + str3);
    }

    /**
     * Converts one decoder tag to its begin/in/out equivalent: begin-prefixed
     * and {@code "W_"} tags become begin tags, {@code "M_"} and {@code "E_"}
     * tags become in tags, and everything else becomes the out tag.
     *
     * @param str decoder tag
     * @return the begin, in or out tag
     */
    private static String decodeNormalize(String str) {
        if (str.startsWith(BioTagChunkCodec.BEGIN_TAG_PREFIX) || str.startsWith("W_")) {
            return ChunkTagHandlerAdapter2.toBeginTag(str.substring(2));
        }
        if (str.startsWith("M_") || str.startsWith("E_")) {
            return ChunkTagHandlerAdapter2.toInTag(str.substring(2));
        }
        return ChunkTagHandlerAdapter2.OUT_TAG;
    }
}
