package de.julielab.jnet.tagger;

import de.julielab.jnet.evaluation.IOEvaluation;
import de.julielab.jnet.utils.Utils;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.umass.cs.mallet.base.fst.CRF4;
import edu.umass.cs.mallet.base.fst.MultiSegmentationEvaluator;
import edu.umass.cs.mallet.base.fst.Segment;
import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.fst.confidence.ConstrainedForwardBackwardConfidenceEstimator;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Sequence;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Properties;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:lib/palladian.jar:de/julielab/jnet/tagger/NETagger.class */
/**
 * Sequence tagger built around a MALLET {@link CRF4} model.
 *
 * <p>An instance holds the model, the feature configuration used to build the
 * feature pipe, and a flag telling whether a model has been trained or loaded.
 * A model becomes available either through {@link #train(ArrayList, Tags)} or
 * through {@link #readModel(String)}; the predict methods refuse to run before
 * that.
 *
 * <p>No synchronization is performed by this class.
 */
public class NETagger {
    /** Exit status used when the tagger aborts the JVM because of an error. */
    private static final int ERROR_EXIT_STATUS = -1;

    /** Classpath location of the built-in feature configuration. */
    private static final String DEFAULT_FEATURE_CONFIG = "/defaultFeatureConf.conf";

    private CRF4 model;
    private Properties featureConfig;
    private boolean trained;

    /**
     * Creates an untrained tagger whose feature configuration falls back to the
     * defaults read from {@code /defaultFeatureConf.conf} on the classpath.
     *
     * <p>Loading is best-effort: if the resource is missing or unreadable the
     * problem is reported on {@code System.err} and the tagger starts with an
     * empty set of defaults.
     */
    public NETagger() {
        this.trained = false;
        Properties defaults = new Properties();
        // try-with-resources tolerates a null resource, so the null check can live inside.
        try (InputStream in = getClass().getResourceAsStream(DEFAULT_FEATURE_CONFIG)) {
            if (in == null) {
                System.err.println("Default feature configuration not found on classpath: " + DEFAULT_FEATURE_CONFIG);
            } else {
                defaults.load(in);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        // The loaded values act as defaults; explicitly set keys override them.
        this.featureConfig = new Properties(defaults);
    }

    /**
     * Creates an untrained tagger and overlays the built-in defaults with the
     * properties read from {@code file}.
     *
     * <p>Loading is best-effort: an I/O failure is printed and the tagger keeps
     * whatever configuration was loaded so far.
     *
     * @param file properties file with feature configuration overrides
     */
    public NETagger(File file) {
        this();
        try (FileInputStream in = new FileInputStream(file)) {
            this.featureConfig.load(in);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * @return {@code true} once a model has been trained or loaded
     */
    public boolean isTrained() {
        return this.trained;
    }

    /**
     * Trains a new CRF on the given sentences and installs it as the current model.
     * Progress information is printed to {@code System.out}.
     *
     * @param arrayList training sentences
     * @param tags      label set; its alphabet is handed to the feature generator
     */
    public void train(ArrayList<Sentence> arrayList, Tags tags) {
        System.out.println("   * training model... on " + arrayList.size() + " sentences");
        InstanceList trainingData = FeatureGenerator.createFeatureData(arrayList, tags.getAlphabet(), this.featureConfig);
        System.out.println("  * no of features for training: " + trainingData.getDataAlphabet().size());
        long start = System.currentTimeMillis();
        CRF4 crf = new CRF4(trainingData.getPipe(), (Pipe) null);
        crf.addStatesForLabelsConnectedAsIn(trainingData);
        // The casts on null select the intended train(...) overload.
        // NOTE(review): the numeric arguments are presumably iteration count, iterations per
        // proportion and training proportions - confirm against the MALLET CRF4 API.
        crf.train(trainingData, (InstanceList) null, (InstanceList) null, (MultiSegmentationEvaluator) null, 99999, 10, new double[]{0.2d, 0.5d, 0.8d});
        System.out.println("  * learning took (sec): " + ((System.currentTimeMillis() - start) / 1000));
        this.model = crf;
        this.trained = true;
    }

    /**
     * Labels every unit of {@code sentence} in place using the current model.
     *
     * @param sentence sentence whose units receive the predicted labels
     * @param z        if {@code true}, a confidence value is also stored on each unit
     * @throws JNETException if no model is available or the model returned a label
     *                       sequence whose length differs from the number of units
     */
    public void predict(Sentence sentence, boolean z) throws JNETException {
        if (!this.trained || this.model == null) {
            throw new JNETException("No model available. Train or load trained model first.");
        }
        Sequence input = (Sequence) new Instance(sentence, "", "", "", this.model.getInputPipe()).getData();
        Sequence output = this.model.viterbiPath(input).output();
        int unitCount = sentence.getUnits().size();
        if (output.size() != unitCount) {
            throw new JNETException("Wrong number of labels predicted.");
        }
        double[] confidence = z ? getSegmentConfidence(input, output) : null;
        for (int i = 0; i < unitCount; i++) {
            Unit unit = sentence.get(i);
            unit.setLabel((String) output.get(i));
            if (z) {
                unit.setConfidence(confidence[i]);
            }
        }
    }

    /**
     * Labels all sentences in place and additionally returns the result as text lines.
     *
     * <p>For each unit one line is produced: the unit's representation, the delimiter
     * {@code LinearClassifier.TEXT_SERIALIZATION_DELIMITER}, the predicted label and,
     * when {@code z} is set, the delimiter again followed by the confidence. After
     * each sentence the line {@code "O\tO"} is appended.
     *
     * @param arrayList sentences to label
     * @param z         if {@code true}, confidences are computed, stored on the units
     *                  and appended to the output lines
     * @return the generated lines in sentence order
     * @throws JNETException if no model is available or the model returned a label
     *                       sequence whose length differs from the number of units
     */
    public ArrayList<String> predictIOB(ArrayList<Sentence> arrayList, boolean z) throws JNETException {
        if (!this.trained || this.model == null) {
            throw new JNETException("no model available. Train or load trained model first.");
        }
        System.out.println("  * predicting with crf model...");
        Pipe inputPipe = this.model.getInputPipe();
        ArrayList<String> lines = new ArrayList<>();
        for (Sentence sentence : arrayList) {
            Sequence input = (Sequence) new Instance(sentence, "", "", "", inputPipe).getData();
            Sequence output = this.model.viterbiPath(input).output();
            ArrayList<Unit> units = sentence.getUnits();
            int unitCount = units.size();
            if (output.size() != unitCount) {
                throw new JNETException("Wrong number of labels predicted.");
            }
            double[] confidence = z ? getSegmentConfidence(input, output) : null;
            for (int i = 0; i < unitCount; i++) {
                String label = (String) output.get(i);
                Unit unit = sentence.get(i);
                unit.setLabel(label);
                // NOTE(review): the delimiter constant is presumably a tab, matching the
                // "O\tO" separator below - confirm against the Stanford LinearClassifier source.
                String line = units.get(i).getRep() + LinearClassifier.TEXT_SERIALIZATION_DELIMITER + label;
                if (z) {
                    unit.setConfidence(confidence[i]);
                    line = line + LinearClassifier.TEXT_SERIALIZATION_DELIMITER + confidence[i];
                }
                lines.add(line);
            }
            // Sentence separator line.
            lines.add("O\tO");
        }
        return lines;
    }

    /**
     * Computes a per-position confidence for a predicted label sequence.
     *
     * <p>Positions that belong to a chunk reported by
     * {@code IOEvaluation.getChunksIO} all receive the confidence estimated for
     * that chunk; every other position keeps the marker value {@code -1.0}.
     *
     * @param input     feature sequence that was fed to the model
     * @param predicted label sequence predicted for {@code input}
     * @return array with one confidence value per predicted label
     */
    private double[] getSegmentConfidence(Sequence input, Sequence predicted) {
        ConstrainedForwardBackwardConfidenceEstimator estimator = new ConstrainedForwardBackwardConfidenceEstimator(this.model);
        Transducer.Lattice lattice = this.model.forwardBackward(input);
        double[] confidence = new double[predicted.size()];
        Arrays.fill(confidence, -1.0d);
        ArrayList<String> labels = new ArrayList<>(predicted.size());
        for (int i = 0; i < predicted.size(); i++) {
            labels.add((String) predicted.get(i));
        }
        // Each key is split on "," into the first and last index (inclusive) of a chunk.
        for (Object key : IOEvaluation.getChunksIO(labels).keySet()) {
            String[] bounds = ((String) key).split(",");
            int first = Integer.parseInt(bounds[0]);
            int last = Integer.parseInt(bounds[1]);
            double chunkConfidence = estimator.estimateConfidenceFor(new Segment(input, predicted, predicted, first, last, "entity", "entity"), lattice);
            for (int i = first; i <= last; i++) {
                confidence[i] = chunkConfidence;
            }
        }
        return confidence;
    }

    /**
     * Serializes the current model together with its feature configuration to the
     * gzip-compressed file {@code str + ".gz"}.
     *
     * <p>If no model is available, or writing fails, the problem is reported on
     * {@code System.err} and the JVM is terminated with a non-zero exit status.
     *
     * @param str target path without the {@code .gz} suffix
     */
    public void writeModel(String str) {
        if (!this.trained || this.model == null || this.featureConfig == null) {
            System.err.println("train or load trained model first.");
            System.exit(ERROR_EXIT_STATUS);
        }
        // Each stream is its own resource so that none leaks if a wrapping constructor fails.
        try (FileOutputStream fileOut = new FileOutputStream(new File(str + ".gz"));
                GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut);
                ObjectOutputStream objectOut = new ObjectOutputStream(gzipOut)) {
            objectOut.writeObject(new FeatureSubsetModel(this.model, this.featureConfig));
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(ERROR_EXIT_STATUS);
        }
    }

    /**
     * Loads a model and its feature configuration from a gzip-compressed file
     * written by {@link #writeModel(String)} and marks the tagger as trained.
     *
     * <p>SECURITY NOTE: this uses native Java deserialization; only load model
     * files from trusted sources.
     *
     * @param str path of the model file (including any {@code .gz} suffix)
     * @throws FileNotFoundException  if the file cannot be opened
     * @throws IOException            if reading or decompressing fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public void readModel(String str) throws IOException, FileNotFoundException, ClassNotFoundException {
        FeatureSubsetModel stored;
        try (FileInputStream fileIn = new FileInputStream(new File(str));
                GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
                ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
            stored = (FeatureSubsetModel) objectIn.readObject();
        }
        this.model = stored.getModel();
        this.featureConfig = stored.getFeatureConfig();
        this.trained = true;
    }

    /**
     * @return the current model, or {@code null} if none has been trained or loaded
     */
    public CRF4 getModel() {
        return this.model;
    }

    /**
     * Replaces the feature configuration.
     *
     * @param properties new feature configuration
     */
    public void setFeatureConfig(Properties properties) {
        this.featureConfig = properties;
    }

    /**
     * @return the feature configuration currently in use
     */
    public Properties getFeatureConfig() {
        return this.featureConfig;
    }

    /**
     * Parses one sentence given in piped format into a {@link Sentence}.
     *
     * <p>Tokens are separated by blanks or tabs. Each token consists of fields
     * separated by one or more {@code |}: the first field is the token text, the
     * last field is the label, and the fields in between carry the meta data
     * enabled in the feature configuration. A meta value equal to the configured
     * {@code gap_character} is skipped.
     *
     * <p>If a token does not have exactly {@code metas + 2} fields, an error is
     * printed and the JVM is terminated with status {@code -1}.
     *
     * @param str one sentence in piped format
     * @return the parsed sentence; all units are created with offsets {@code 0, 0}
     * @throws JNETException declared for API compatibility; not thrown by this method itself
     */
    public Sentence PPDtoUnits(String str) throws JNETException {
        String[] tokens = str.trim().split("[\t ]+");
        ArrayList<Unit> units = new ArrayList<>(tokens.length);
        String[] trueMetas = Utils.getTrueMetas(this.featureConfig);
        String gapCharacter = this.featureConfig.getProperty("gap_character");
        for (String token : tokens) {
            String[] fields = token.split("\\|+");
            // Validate before indexing: a token made only of '|' splits into an empty array.
            if (trueMetas.length + 2 != fields.length) {
                System.err.println("Error in input format (PipedFormat)! Mal-formatted sentence: " + str + "\n token: " + token);
                System.err.println("Perhaps your configuration file uses more or less meta datas as are available in your input files? If you don't use a config file, you should check whether your input files fit to the default configuration.");
                System.exit(-1);
            }
            String text = fields[0];
            String label = fields[fields.length - 1];
            HashMap<String, String> metaInfo = new HashMap<>();
            for (String meta : trueMetas) {
                int position = Integer.parseInt(this.featureConfig.getProperty(meta + "_feat_position"));
                String unitName = this.featureConfig.getProperty(meta + "_feat_unit");
                if (!fields[position].equals(gapCharacter)) {
                    metaInfo.put(unitName, fields[position]);
                }
            }
            units.add(new Unit(0, 0, text, label, metaInfo));
        }
        return new Sentence(units);
    }
}
