package ws.palladian.classification.universal;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.jdesktop.swingx.JXLabel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.CategoryEntries;
import ws.palladian.classification.CategoryEntriesMap;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Instance;
import ws.palladian.classification.Learner;
import ws.palladian.classification.nb.NaiveBayesClassifier;
import ws.palladian.classification.nb.NaiveBayesModel;
import ws.palladian.classification.numeric.KnnClassifier;
import ws.palladian.classification.numeric.KnnModel;
import ws.palladian.classification.text.DictionaryModel;
import ws.palladian.classification.text.FeatureSetting;
import ws.palladian.classification.text.PalladianTextClassifier;
import ws.palladian.extraction.token.BaseTokenizer;
import ws.palladian.helper.ProgressHelper;
import ws.palladian.helper.collection.ConstantFactory;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.processing.Classifiable;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.FeatureVector;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/universal/UniversalClassifier.class */
/**
 * Classifier that combines up to three underlying classifiers: a {@link PalladianTextClassifier}
 * for text, a {@link KnnClassifier} for numeric features and a {@link NaiveBayesClassifier} for
 * nominal features. Which of them are trained is controlled by {@link ClassifierSetting}.
 *
 * <p>At classification time each trained classifier produces its own {@link CategoryEntries};
 * these are merged by weighting every category probability with the per-classifier weight stored in
 * the {@link UniversalClassifierModel} and normalizing the weighted sums.
 */
public class UniversalClassifier implements Learner<UniversalClassifierModel>, Classifier<UniversalClassifierModel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(UniversalClassifier.class);

    /** Position of the text classifier in hit-count arrays and in the model's weight array. */
    private static final int TEXT_INDEX = 0;
    /** Position of the numeric classifier in hit-count arrays and in the model's weight array. */
    private static final int NUMERIC_INDEX = 1;
    /** Position of the nominal classifier in hit-count arrays and in the model's weight array. */
    private static final int NOMINAL_INDEX = 2;
    /** Number of underlying classifiers. */
    private static final int CLASSIFIER_COUNT = 3;

    private final PalladianTextClassifier textClassifier;
    private final KnnClassifier numericClassifier;
    private final NaiveBayesClassifier nominalClassifier;
    private final FeatureSetting featureSetting;
    private final EnumSet<ClassifierSetting> settings;

    /** Selects which of the underlying classifiers are trained by {@link #train(Iterable)}. */
    public enum ClassifierSetting {
        NUMERIC,
        TEXT,
        NOMINAL
    }

    /** Creates a classifier with all three underlying classifiers enabled and the default text feature setting. */
    public UniversalClassifier() {
        this(EnumSet.allOf(ClassifierSetting.class), defaultFeatureSetting());
    }

    /**
     * Creates a classifier with all three underlying classifiers enabled.
     *
     * @param featureSetting feature setting handed to the text classifier
     */
    public UniversalClassifier(FeatureSetting featureSetting) {
        this(EnumSet.allOf(ClassifierSetting.class), featureSetting);
    }

    /**
     * Creates a classifier with the default text feature setting.
     *
     * @param enumSet the underlying classifiers to train
     */
    public UniversalClassifier(EnumSet<ClassifierSetting> enumSet) {
        this(enumSet, defaultFeatureSetting());
    }

    /**
     * Creates a classifier.
     *
     * @param enumSet the underlying classifiers to train; copied, so later changes to the passed set
     *        do not affect this instance
     * @param featureSetting feature setting handed to the text classifier
     */
    public UniversalClassifier(EnumSet<ClassifierSetting> enumSet, FeatureSetting featureSetting) {
        this.textClassifier = new PalladianTextClassifier(featureSetting);
        this.featureSetting = featureSetting;
        this.numericClassifier = new KnnClassifier();
        this.nominalClassifier = new NaiveBayesClassifier();
        // Defensive copy: EnumSet is mutable and the caller keeps a reference to its own instance.
        this.settings = EnumSet.copyOf(enumSet);
    }

    /** Builds the feature setting used when the caller does not supply one: character n-grams with bounds 3 and 7. */
    private static FeatureSetting defaultFeatureSetting() {
        return new FeatureSetting(FeatureSetting.TextFeatureType.CHAR_NGRAMS, 3, 7);
    }

    /**
     * Sets the model's classifier weights to the fraction of the given instances that each underlying
     * classifier labels correctly (text, numeric, nominal — in that order).
     *
     * <p>NOTE(review): nothing in this class calls this method; {@link #train(Iterable)} only logs
     * that weights are being learned.
     *
     * @param instances labelled instances used to measure each classifier's accuracy; when empty the
     *        model's weights are left untouched
     * @param model the model whose sub-models are evaluated and whose weights are updated
     */
    private void learnClassifierWeights(List<Instance> instances, UniversalClassifierModel model) {
        if (instances.isEmpty()) {
            // Nothing to measure; dividing by the instance count would fail.
            LOGGER.debug("no instances given, classifier weights left unchanged");
            return;
        }
        int[] hits = new int[CLASSIFIER_COUNT];
        int processed = 1;
        for (Instance instance : instances) {
            int[] instanceHits = evaluateResults(instance, internalClassify(instance.getFeatureVector(), model));
            for (int k = 0; k < CLASSIFIER_COUNT; k++) {
                hits[k] += instanceHits[k];
            }
            ProgressHelper.printProgress(processed++, instances.size(), 0.0d);
        }
        // Floating-point division: integer division would truncate every accuracy to 0 or 1.
        double total = instances.size();
        model.setWeights(hits[TEXT_INDEX] / total, hits[NUMERIC_INDEX] / total, hits[NOMINAL_INDEX] / total);
        LOGGER.debug("weight text   : {}", model.getWeights()[TEXT_INDEX]);
        LOGGER.debug("weight numeric: {}", model.getWeights()[NUMERIC_INDEX]);
        LOGGER.debug("weight nominal: {}", model.getWeights()[NOMINAL_INDEX]);
    }

    /**
     * Checks, per underlying classifier, whether its most likely category equals the instance's
     * target class.
     *
     * @param instance the labelled instance
     * @param result the per-classifier results for that instance
     * @return an array of three values (text, numeric, nominal), each 1 for a correct prediction and
     *         0 otherwise
     */
    private int[] evaluateResults(Instance instance, UniversalClassificationResult result) {
        Object expected = instance.getTargetClass();
        int[] hits = new int[CLASSIFIER_COUNT];
        hits[TEXT_INDEX] = hitCount(result.getTextResults(), expected);
        hits[NUMERIC_INDEX] = hitCount(result.getNumericResults(), expected);
        hits[NOMINAL_INDEX] = hitCount(result.getNominalResults(), expected);
        return hits;
    }

    /**
     * Returns 1 if the given entries exist and their most likely category equals the expected one,
     * 0 otherwise.
     */
    private static int hitCount(CategoryEntries entries, Object expectedCategory) {
        if (entries == null) {
            return 0;
        }
        Object predicted = entries.getMostLikelyCategory();
        return predicted != null && predicted.equals(expectedCategory) ? 1 : 0;
    }

    /**
     * Runs every underlying classifier for which the model holds a sub-model.
     *
     * <p>The text classifier receives the classifiable's own feature vector, while the numeric and
     * nominal classifiers receive a copy from which {@link BaseTokenizer#PROVIDED_FEATURE} has been
     * removed.
     *
     * @param classifiable the item to classify
     * @param universalClassifierModel the trained model
     * @return the individual results; the entry of a classifier without a sub-model is {@code null}
     */
    protected UniversalClassificationResult internalClassify(Classifiable classifiable, UniversalClassifierModel universalClassifierModel) {
        FeatureVector withoutTokenFeature = new FeatureVector(classifiable.getFeatureVector());
        withoutTokenFeature.remove(BaseTokenizer.PROVIDED_FEATURE);

        CategoryEntries textResults = null;
        if (universalClassifierModel.getDictionaryModel() != null) {
            textResults = this.textClassifier.classify((Classifiable) classifiable.getFeatureVector(), universalClassifierModel.getDictionaryModel());
        }
        CategoryEntries numericResults = null;
        if (universalClassifierModel.getKnnModel() != null) {
            numericResults = this.numericClassifier.classify((Classifiable) withoutTokenFeature, universalClassifierModel.getKnnModel());
        }
        CategoryEntries nominalResults = null;
        if (universalClassifierModel.getBayesModel() != null) {
            nominalResults = this.nominalClassifier.classify((Classifiable) withoutTokenFeature, universalClassifierModel.getBayesModel());
        }
        return new UniversalClassificationResult(textResults, numericResults, nominalResults);
    }

    /**
     * Pairs the result of every classifier that has a sub-model with that classifier's weight and
     * merges the pairs via {@link #normalize(Map)}.
     */
    private CategoryEntries mergeResults(UniversalClassificationResult result, UniversalClassifierModel model) {
        Map<CategoryEntries, Double> weightedResults = LazyMap.create(ConstantFactory.create(Double.valueOf(0.0d)));
        if (model.getDictionaryModel() != null) {
            weightedResults.put(result.getTextResults(), Double.valueOf(model.getWeights()[TEXT_INDEX]));
        }
        if (model.getKnnModel() != null) {
            weightedResults.put(result.getNumericResults(), Double.valueOf(model.getWeights()[NUMERIC_INDEX]));
        }
        if (model.getBayesModel() != null) {
            weightedResults.put(result.getNominalResults(), Double.valueOf(model.getWeights()[NOMINAL_INDEX]));
        }
        return normalize(weightedResults);
    }

    /**
     * Merges several weighted classification results into one.
     *
     * <p>For every category the probabilities of all results are multiplied by the respective weight
     * and summed; each sum is then divided by the total over all categories. If that total is zero,
     * no normalization is possible and every category is assigned 0 rather than NaN.
     *
     * @param map classification results mapped to their weights
     * @return the merged category entries
     */
    protected CategoryEntries normalize(Map<CategoryEntries, Double> map) {
        // The lazy map yields 0.0 for categories that have not been seen yet.
        Map<String, Double> scores = LazyMap.create(ConstantFactory.create(Double.valueOf(0.0d)));
        for (Map.Entry<CategoryEntries, Double> weighted : map.entrySet()) {
            CategoryEntries entries = weighted.getKey();
            double weight = weighted.getValue().doubleValue();
            for (String category : entries) {
                double updated = scores.get(category).doubleValue() + entries.getProbability(category) * weight;
                scores.put(category, Double.valueOf(updated));
            }
        }

        double total = 0.0d;
        for (Map.Entry<String, Double> score : scores.entrySet()) {
            total += score.getValue().doubleValue();
        }

        CategoryEntriesMap merged = new CategoryEntriesMap();
        for (Map.Entry<String, Double> score : scores.entrySet()) {
            double normalized = total == 0.0d ? 0.0d : score.getValue().doubleValue() / total;
            merged.set(score.getKey(), normalized);
        }
        return merged;
    }

    /**
     * Trains the underlying classifiers enabled in the settings and bundles their models.
     *
     * <p>Side effect: after the text classifier has been trained,
     * {@link BaseTokenizer#PROVIDED_FEATURE} is removed from the feature vector of every passed
     * trainable, i.e. the caller's data is modified.
     *
     * @param iterable the training data
     * @return a model holding the sub-models of the enabled classifiers; disabled ones are {@code null}
     */
    @Override // ws.palladian.classification.Learner
    public UniversalClassifierModel train(Iterable<? extends Trainable> iterable) {
        DictionaryModel dictionaryModel = null;
        if (this.settings.contains(ClassifierSetting.TEXT)) {
            LOGGER.debug("training text classifier");
            dictionaryModel = this.textClassifier.train(iterable);
        }

        // Strip the tokenizer feature before the numeric and nominal classifiers are trained.
        for (Trainable trainable : iterable) {
            trainable.getFeatureVector().remove(BaseTokenizer.PROVIDED_FEATURE);
        }

        KnnModel knnModel = null;
        if (this.settings.contains(ClassifierSetting.NUMERIC)) {
            LOGGER.debug("training numeric classifier");
            knnModel = this.numericClassifier.train(iterable);
        }
        NaiveBayesModel naiveBayesModel = null;
        if (this.settings.contains(ClassifierSetting.NOMINAL)) {
            LOGGER.debug("training nominal classifier");
            naiveBayesModel = this.nominalClassifier.train(iterable);
        }

        UniversalClassifierModel model = new UniversalClassifierModel(naiveBayesModel, knnModel, dictionaryModel);
        // NOTE(review): only the message is logged here; no weights are learned in this method, so
        // the model keeps whatever weights its constructor assigns.
        LOGGER.debug("learning classifier weights");
        return model;
    }

    /**
     * Classifies the item with every available underlying classifier and merges the individual
     * results using the model's weights.
     *
     * @param classifiable the item to classify
     * @param universalClassifierModel the trained model
     * @return the merged category entries
     */
    @Override // ws.palladian.classification.Classifier
    public CategoryEntries classify(Classifiable classifiable, UniversalClassifierModel universalClassifierModel) {
        UniversalClassificationResult individualResults = internalClassify(classifiable, universalClassifierModel);
        return mergeResults(individualResults, universalClassifierModel);
    }

    /**
     * @return the feature setting this classifier was constructed with
     */
    public FeatureSetting getFeatureSetting() {
        return this.featureSetting;
    }
}
