package ws.palladian.extraction.keyphrase.extractors;

import com.aliasi.util.Strings;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.jdesktop.swingx.JXLabel;
import ws.palladian.classification.Instance;
import ws.palladian.classification.dt.BaggedDecisionTreeClassifier;
import ws.palladian.classification.dt.BaggedDecisionTreeModel;
import ws.palladian.extraction.feature.DuplicateTokenConsolidator;
import ws.palladian.extraction.feature.DuplicateTokenRemover;
import ws.palladian.extraction.feature.HtmlCleaner;
import ws.palladian.extraction.feature.IdfAnnotator;
import ws.palladian.extraction.feature.LengthTokenRemover;
import ws.palladian.extraction.feature.NGramCreator;
import ws.palladian.extraction.feature.RegExTokenRemover;
import ws.palladian.extraction.feature.StemmerAnnotator;
import ws.palladian.extraction.feature.StopTokenRemover;
import ws.palladian.extraction.feature.TermCorpus;
import ws.palladian.extraction.feature.TextDocumentPipelineProcessor;
import ws.palladian.extraction.feature.TfIdfAnnotator;
import ws.palladian.extraction.feature.TokenMetricsCalculator;
import ws.palladian.extraction.keyphrase.Keyphrase;
import ws.palladian.extraction.keyphrase.KeyphraseExtractor;
import ws.palladian.extraction.keyphrase.features.AdditionalFeatureExtractor;
import ws.palladian.extraction.keyphrase.temp.CooccurrenceMatrix;
import ws.palladian.extraction.token.BaseTokenizer;
import ws.palladian.extraction.token.RegExTokenizer;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.constants.Language;
import ws.palladian.processing.Classifiable;
import ws.palladian.processing.DocumentUnprocessableException;
import ws.palladian.processing.PerformanceCheckProcessingPipeline;
import ws.palladian.processing.PipelineDocument;
import ws.palladian.processing.ProcessingPipeline;
import ws.palladian.processing.TextDocument;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.FeatureVector;
import ws.palladian.processing.features.ListFeature;
import ws.palladian.processing.features.NominalFeature;
import ws.palladian.processing.features.NumericFeature;
import ws.palladian.processing.features.PositionAnnotation;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/extraction/keyphrase/extractors/MachineLearningBasedExtractor.class */
/**
 * Supervised keyphrase extractor: candidates produced by a token/n-gram pipeline are described by
 * feature vectors and scored with a bagged decision tree classifier.
 *
 * <p>Life cycle as implemented here: {@link #train(String, Set)} is called once per training
 * document and only collects corpus statistics (term corpus, keyphrase corpus, keyphrase
 * co-occurrences) while retaining a bounded number of documents; {@link #endTraining()} then turns
 * the retained documents into labelled instances and trains the classifier; {@link #extract(String)}
 * scores the candidates of a new document with the trained model.</p>
 *
 * <p>The class keeps mutable training state and performs no synchronization.</p>
 */
public final class MachineLearningBasedExtractor extends KeyphraseExtractor {

    /** Name of the nominal feature carrying the training label ({@code "true"} / {@code "false"}). */
    static final String IS_KEYWORD = "isKeyword";

    /**
     * Threshold for retaining training documents. The comparison in {@link #train(String, Set)} is
     * inclusive and the counter starts at zero, so up to {@code TRAIN_DOC_LIMIT + 1} documents are
     * kept; that historical behaviour is preserved on purpose.
     */
    private static final int TRAIN_DOC_LIMIT = 100;

    /** Label value (as produced by {@code String.valueOf(true)}) that marks a positive sample. */
    private static final String POSITIVE_LABEL = "true";

    /** Name of the numeric feature holding the smoothed keyphrase prior of a candidate. */
    private static final String PRIOR_FEATURE = "prior";

    /** Regular expression used to split phrases into words and to strip whitespace. */
    private static final String WHITESPACE_REGEX = "\\s";

    /** Separator used when re-joining the words of a phrase. */
    private static final String WORD_SEPARATOR = " ";

    private final ProcessingPipeline candidateGenerationPipeline;
    private final StemmerAnnotator stemmer;
    private BaggedDecisionTreeModel model;
    private final TermCorpus termCorpus = new TermCorpus();
    private final TermCorpus keyphraseCorpus = new TermCorpus();
    private final CooccurrenceMatrix<String> cooccurrenceMatrix = new CooccurrenceMatrix<>();
    private int trainCount = 0;
    private final Map<PipelineDocument<String>, Set<String>> trainDocuments = new HashMap<>();
    private BaggedDecisionTreeClassifier classifier = createClassifier();
    private final ProcessingPipeline corpusGenerationPipeline = new PerformanceCheckProcessingPipeline();

    /**
     * Builds the two processing pipelines. Both start with the same cleaning, tokenizing, filtering
     * and stemming stages and share one stemmer instance; the corpus pipeline then removes duplicate
     * tokens, while the candidate pipeline creates n-grams and annotates each candidate with the
     * features used for classification.
     */
    public MachineLearningBasedExtractor() {
        this.stemmer = new StemmerAnnotator(Language.ENGLISH, StemmerAnnotator.Mode.MODIFY);

        addSharedPreprocessing(this.corpusGenerationPipeline);
        this.corpusGenerationPipeline.add(new DuplicateTokenRemover());

        this.candidateGenerationPipeline = new ProcessingPipeline();
        addSharedPreprocessing(this.candidateGenerationPipeline);
        this.candidateGenerationPipeline.add(new NGramCreator(3, StemmerAnnotator.UNSTEM));
        this.candidateGenerationPipeline.add(new TokenMetricsCalculator());
        this.candidateGenerationPipeline.add(new DuplicateTokenConsolidator());
        this.candidateGenerationPipeline.add(new IdfAnnotator(this.termCorpus));
        this.candidateGenerationPipeline.add(new TfIdfAnnotator());
        this.candidateGenerationPipeline.add(new AdditionalFeatureExtractor());
        this.candidateGenerationPipeline.add(new TextDocumentPipelineProcessor() {
            @Override
            public void processDocument(TextDocument textDocument) throws DocumentUnprocessableException {
                for (Object token : BaseTokenizer.getTokenAnnotations(textDocument)) {
                    PositionAnnotation annotation = (PositionAnnotation) token;
                    // Add-one smoothed share of training documents that listed this value as a
                    // keyphrase. The explicit cast keeps the division in floating point whatever
                    // numeric type the corpus counters are declared with.
                    double prior = (double) (keyphraseCorpus.getCount(annotation.getValue()) + 1)
                            / keyphraseCorpus.getNumDocs();
                    annotation.getFeatureVector().add(new NumericFeature(PRIOR_FEATURE, Double.valueOf(prior)));
                }
            }
        });
    }

    /**
     * Appends the stages common to both pipelines, in their original order. Must only be called
     * after {@link #stemmer} has been assigned.
     *
     * @param pipeline the pipeline to extend
     */
    private void addSharedPreprocessing(ProcessingPipeline pipeline) {
        pipeline.add(new HtmlCleaner());
        pipeline.add(new RegExTokenizer());
        pipeline.add(new StopTokenRemover(Language.ENGLISH));
        pipeline.add(new LengthTokenRemover(4));
        pipeline.add(new RegExTokenRemover("[^A-Za-z0-9-]+"));
        pipeline.add(this.stemmer);
    }

    /** Creates a fresh, untrained classifier. */
    private BaggedDecisionTreeClassifier createClassifier() {
        return new BaggedDecisionTreeClassifier();
    }

    /**
     * Reads the token annotation list stored on a processed document.
     *
     * @param document a document that has been run through one of the pipelines
     * @return the list stored under {@code BaseTokenizer.PROVIDED_FEATURE}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static List<PositionAnnotation> tokenAnnotations(PipelineDocument<String> document) {
        // The feature is looked up with a raw class token, hence the raw cast (as in the original code).
        return (List) document.get(ListFeature.class, BaseTokenizer.PROVIDED_FEATURE);
    }

    /**
     * Divides two counters in floating point.
     *
     * @return {@code numerator / denominator}; {@code NaN} or infinity (instead of an
     *         {@code ArithmeticException}) when the denominator is zero
     */
    private static double ratio(int numerator, int denominator) {
        return (double) numerator / denominator;
    }

    /** @return always {@code true}; this extractor has to be trained before use */
    @Override
    public boolean needsTraining() {
        return true;
    }

    @Override
    public void startTraining() {
        System.out.println("Building corpus ...");
        super.startTraining();
    }

    /**
     * Records the statistics of one training document and, while the retention threshold is not
     * exceeded, keeps the processed document for {@link #endTraining()}.
     *
     * @param str the document text
     * @param set the keyphrases assigned to the document
     * @throws IllegalStateException if the corpus pipeline cannot process the document
     */
    @Override
    public void train(String str, Set<String> set) {
        TextDocument document = new TextDocument(str);
        try {
            this.corpusGenerationPipeline.process(document);
        } catch (DocumentUnprocessableException e) {
            throw new IllegalStateException(e);
        }

        Set<String> terms = new HashSet<>();
        for (PositionAnnotation annotation : tokenAnnotations(document)) {
            terms.add(annotation.getValue());
        }
        this.termCorpus.addTermsFromDocument(terms);

        // Stem the keyphrases once and feed the same set to both statistics.
        Set<String> stemmedKeyphrases = stem(set);
        this.keyphraseCorpus.addTermsFromDocument(stemmedKeyphrases);
        this.cooccurrenceMatrix.addAll(stemmedKeyphrases);

        if (this.trainCount <= TRAIN_DOC_LIMIT) {
            this.trainDocuments.put(document, set);
        }
        this.trainCount++;
    }

    /**
     * Generates labelled candidates for every retained training document, prints some statistics
     * and trains the classifier. Retained documents are removed as they are consumed.
     *
     * @throws IllegalStateException if the candidate pipeline cannot process a document
     */
    @Override
    public void endTraining() {
        System.out.println("finished building corpus, # train docs: " + this.trainDocuments.size());

        List<PositionAnnotation> annotations = new ArrayList<>();
        int totalKeyphrases = 0;
        int totalMarked = 0;
        Iterator<Map.Entry<PipelineDocument<String>, Set<String>>> entries =
                this.trainDocuments.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<PipelineDocument<String>, Set<String>> entry = entries.next();
            PipelineDocument<String> document = entry.getKey();
            Set<String> keyphrases = entry.getValue();
            try {
                this.candidateGenerationPipeline.process(document);
            } catch (DocumentUnprocessableException e) {
                throw new IllegalStateException(e);
            }
            List<PositionAnnotation> candidates = tokenAnnotations(document);
            totalKeyphrases += keyphrases.size();
            totalMarked += markCandidates(candidates, keyphrases);
            annotations.addAll(candidates);
            // Drop the document right away; only its annotations are needed from here on.
            entries.remove();
            System.out.println(this.trainDocuments.size());
        }
        System.out.println("# annotations: " + annotations.size());
        System.out.println("% sample coverage: " + ratio(totalMarked, totalKeyphrases));

        int positives = 0;
        int negatives = 0;
        List<Instance> instances = new ArrayList<>();
        for (PositionAnnotation annotation : annotations) {
            FeatureVector features = annotation.getFeatureVector();
            String label = ((NominalFeature) features.get(NominalFeature.class, IS_KEYWORD)).getValue();
            if (POSITIVE_LABEL.equals(label)) {
                positives++;
            } else {
                negatives++;
            }
            instances.add(new Instance(label, cleanFeatureVector(features)));
        }
        System.out.println("# negative samples: " + negatives);
        System.out.println("# positive samples: " + positives);
        System.out.println("% positive sample rate: " + ratio(positives, negatives + positives));

        System.out.println("building classifier ...");
        this.model = this.classifier.train((Iterable<? extends Trainable>) instances);
        System.out.println(this.model.toString());
        System.out.println("... finished building classifier.");
    }

    /**
     * Returns a copy of the given vector without the label and the other features that are removed
     * before training and classification.
     *
     * @param featureVector the annotated vector; not modified
     * @return the reduced copy
     */
    private FeatureVector cleanFeatureVector(FeatureVector featureVector) {
        FeatureVector cleaned = new FeatureVector(featureVector);
        cleaned.remove(IS_KEYWORD);
        cleaned.remove(StemmerAnnotator.UNSTEM);
        cleaned.remove(BaseTokenizer.PROVIDED_FEATURE);
        cleaned.remove(AdditionalFeatureExtractor.CASE_SIGNATURE);
        return cleaned;
    }

    /**
     * Labels every candidate with an {@link #IS_KEYWORD} feature by comparing several normalized
     * forms of the candidate against several normalized forms of the assigned keyphrases.
     *
     * @param list the candidates of one document; each receives a label feature
     * @param set  the keyphrases assigned to that document
     * @return the number of candidates labelled as keyphrase
     */
    private int markCandidates(List<PositionAnnotation> list, Set<String> set) {
        Set<String> targets = new HashSet<>();
        for (String keyphrase : set) {
            String lowered = keyphrase.toLowerCase();
            addVariants(targets, lowered.trim());
            addVariants(targets, stem(lowered).trim());
        }

        int marked = 0;
        for (PositionAnnotation candidate : list) {
            String value = candidate.getValue();
            String unstemValue = ((NominalFeature) candidate.getFeatureVector()
                    .get(NominalFeature.class, StemmerAnnotator.UNSTEM)).getValue();
            // Same forms as before, but evaluation stops at the first hit.
            boolean isKeyword = matchesAnyForm(targets, value) || matchesAnyForm(targets, unstemValue);
            candidate.getFeatureVector().add(new NominalFeature(IS_KEYWORD, String.valueOf(isKeyword)));
            if (isKeyword) {
                marked++;
            }
        }
        return marked;
    }

    /**
     * Adds a phrase in four forms: as given, without whitespace, and the word-sorted version of each.
     */
    private static void addVariants(Set<String> targets, String phrase) {
        String collapsed = phrase.replaceAll(WHITESPACE_REGEX, "");
        targets.add(phrase);
        targets.add(collapsed);
        targets.add(canonicalize(phrase));
        targets.add(canonicalize(collapsed));
    }

    /**
     * Checks the value as given, lower-cased, without whitespace, and lower-cased without
     * whitespace, each both verbatim and word-sorted.
     */
    private static boolean matchesAnyForm(Set<String> targets, String value) {
        String lowered = value.toLowerCase();
        return matches(targets, value)
                || matches(targets, lowered)
                || matches(targets, value.replaceAll(WHITESPACE_REGEX, ""))
                || matches(targets, lowered.replaceAll(WHITESPACE_REGEX, ""));
    }

    /** Checks one form verbatim and with its words sorted. */
    private static boolean matches(Set<String> targets, String form) {
        return targets.contains(form) || targets.contains(canonicalize(form));
    }

    /**
     * Sorts the whitespace-separated words of a phrase so that word order no longer matters.
     *
     * @param str the phrase
     * @return the words in natural order, joined by single spaces
     */
    private static String canonicalize(String str) {
        List<String> words = new ArrayList<>();
        Collections.addAll(words, str.split(WHITESPACE_REGEX));
        Collections.sort(words);
        return StringUtils.join(words, WORD_SEPARATOR);
    }

    /** Applies {@link #canonicalize(String)} to every element, keeping iteration order. */
    private static List<String> canonicalize(Collection<String> collection) {
        List<String> result = new ArrayList<>();
        for (String phrase : collection) {
            result.add(canonicalize(phrase));
        }
        return result;
    }

    /**
     * Stems every whitespace-separated word of a phrase.
     *
     * @param str the phrase
     * @return the stemmed words joined by single spaces
     */
    private String stem(String str) {
        List<String> stemmedWords = new ArrayList<>();
        for (String word : str.split(WHITESPACE_REGEX)) {
            stemmedWords.add(this.stemmer.stem(word));
        }
        return StringUtils.join(stemmedWords, WORD_SEPARATOR);
    }

    /** Applies {@link #stem(String)} to every phrase and collects the results in a new set. */
    private Set<String> stem(Set<String> set) {
        Set<String> stemmed = new HashSet<>();
        for (String phrase : set) {
            stemmed.add(stem(phrase));
        }
        return stemmed;
    }

    /**
     * Clears the collected corpus statistics and retained documents and replaces the classifier
     * with a fresh instance. The {@code model} field is not touched here.
     */
    @Override
    public void reset() {
        this.termCorpus.reset();
        this.keyphraseCorpus.reset();
        this.cooccurrenceMatrix.reset();
        this.trainCount = 0;
        this.trainDocuments.clear();
        this.classifier = createClassifier();
        super.reset();
    }

    /**
     * Extracts keyphrases from a document: candidates are classified, re-ranked using keyphrase
     * co-occurrences, extended by co-occurring phrases, sorted, and cut to
     * {@code getKeyphraseCount()} entries.
     *
     * @param str the document text
     * @return the extracted keyphrases in sorted order
     * @throws IllegalStateException if no model has been trained yet, or if a pipeline cannot
     *         process the document
     */
    @Override
    public List<Keyphrase> extract(String str) {
        if (this.model == null) {
            throw new IllegalStateException("No trained model available; call endTraining() before extract().");
        }
        TextDocument document = new TextDocument(str);
        try {
            this.corpusGenerationPipeline.process(document);
            this.candidateGenerationPipeline.process(document);
        } catch (DocumentUnprocessableException e) {
            // Keep the cause, consistent with train() and endTraining().
            throw new IllegalStateException(e);
        }

        List<Keyphrase> keyphrases = new ArrayList<>();
        for (PositionAnnotation candidate : tokenAnnotations(document)) {
            double probability = this.classifier
                    .classify((Classifiable) cleanFeatureVector(candidate.getFeatureVector()), this.model)
                    .getProbability(POSITIVE_LABEL);
            if (probability != 0.0) {
                keyphrases.add(new Keyphrase(candidate.getValue(), probability));
            }
        }

        reRankCooccurrences(keyphrases);
        synthetesize(keyphrases);
        Collections.sort(keyphrases);
        int limit = getKeyphraseCount();
        if (keyphrases.size() > limit) {
            keyphrases.subList(limit, keyphrases.size()).clear();
        }
        return keyphrases;
    }

    /**
     * Adds phrases that were not extracted themselves but co-occurred, in the training keyphrases,
     * with the leading {@code sqrt(n)} entries of the sorted list. A phrase is added when its
     * co-occurrence score is at least 0.01 and the pair was counted at least twice; its weight is
     * the sum of {@code sourceWeight * score} over all sources that suggested it.
     *
     * @param list the extracted keyphrases; sorted and extended in place
     * @return the number of phrases added
     */
    private int synthetesize(List<Keyphrase> list) {
        Collections.sort(list);
        Set<String> existingValues = new HashSet<>();
        for (Keyphrase keyphrase : list) {
            existingValues.add(keyphrase.getValue());
        }

        Map<String, Keyphrase> synthesized = new HashMap<>();
        int topCount = (int) Math.sqrt(list.size());
        for (Keyphrase source : list.subList(0, topCount)) {
            for (Pair<String, Double> neighbour : this.cooccurrenceMatrix.getHighest(source.getValue(), 5)) {
                String phrase = neighbour.getLeft();
                double score = neighbour.getRight().doubleValue();
                if (existingValues.contains(phrase)
                        || score < 0.01d
                        || this.cooccurrenceMatrix.getCount(source.getValue(), phrase) < 2) {
                    continue;
                }
                double contribution = source.getWeight() * score;
                Keyphrase known = synthesized.get(phrase);
                if (known != null) {
                    known.setWeight(known.getWeight() + contribution);
                } else {
                    Keyphrase created = new Keyphrase(phrase);
                    created.setWeight(contribution);
                    synthesized.put(phrase, created);
                }
            }
        }
        list.addAll(synthesized.values());
        return synthesized.size();
    }

    /**
     * For every keyphrase with positive weight, subtracts that weight from each other keyphrase
     * whose value is contained in it. Not called from within this class.
     *
     * @param list the keyphrases; weights are changed in place
     */
    private void reRankOverlaps(List<Keyphrase> list) {
        for (Keyphrase outer : list) {
            if (outer.getWeight() > 0.0) {
                for (Keyphrase inner : list) {
                    if (!outer.getValue().equals(inner.getValue()) && outer.getValue().contains(inner.getValue())) {
                        inner.setWeight(inner.getWeight() - outer.getWeight());
                    }
                }
            }
        }
    }

    /**
     * Raises the weight of each of the leading {@code sqrt(n)} keyphrases by the weights of the
     * other leading keyphrases, scaled by the Laplace-smoothed conditional probability reported by
     * the co-occurrence matrix. Weights are updated in place, so later updates already see earlier
     * ones.
     *
     * @param list the keyphrases; sorted and re-weighted in place
     */
    private void reRankCooccurrences(List<Keyphrase> list) {
        int topCount = (int) Math.sqrt(list.size());
        Collections.sort(list);
        List<Keyphrase> top = list.subList(0, topCount);
        for (Keyphrase target : top) {
            String targetValue = target.getValue();
            for (Keyphrase other : top) {
                String otherValue = other.getValue();
                if (!targetValue.equals(otherValue)) {
                    double boost = other.getWeight()
                            * this.cooccurrenceMatrix.getConditionalProbabilityLaplace(otherValue, targetValue);
                    target.setWeight(target.getWeight() + boost);
                }
            }
        }
    }

    /** @return the simple class name of this extractor */
    @Override
    public String getExtractorName() {
        return getClass().getSimpleName();
    }

    @Override
    public String toString() {
        return getExtractorName();
    }
}
