package ws.palladian.classification.nb;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.Validate;
import org.javatuples.Pair;
import org.javatuples.Triplet;
import org.jdesktop.swingx.JXLabel;
import ws.palladian.classification.CategoryEntries;
import ws.palladian.classification.CategoryEntriesMap;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Learner;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.CountMap;
import ws.palladian.helper.collection.Factory;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.processing.Classifiable;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.Feature;
import ws.palladian.processing.features.NominalFeature;
import ws.palladian.processing.features.NumericFeature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/nb/NaiveBayesClassifier.class */
/**
 * Naive Bayes learner and classifier operating on {@link NaiveBayesModel}s.
 *
 * <p>Training counts how often each category occurs and how often each (feature name, nominal value, category)
 * combination occurs. For numeric features it records the sample mean and sample standard deviation per
 * (feature name, category). Classification multiplies each category's prior by one factor per feature:
 * <ul>
 * <li>nominal features contribute the model's Laplace-smoothed probability;</li>
 * <li>numeric features contribute the model's density, presumably a Gaussian derived from the stored mean and
 * standard deviation (not visible here).</li>
 * </ul>
 */
public final class NaiveBayesClassifier implements Learner<NaiveBayesModel>, Classifier<NaiveBayesModel> {

    /** Laplace corrector used by the no-argument constructor. */
    private static final double DEFAULT_LAPLACE_CORRECTOR = 1.0E-5d;

    /** Laplace corrector handed to {@link NaiveBayesModel#getProbability} when scoring nominal features. */
    private final double laplace;

    /**
     * Accumulates the observed values of one numeric feature within one category. It derives their sample mean and
     * sample standard deviation.
     */
    public static final class Stats {

        private final List<Double> values = CollectionHelper.newArrayList();

        /**
         * Records one observed value.
         *
         * @param d the value to record. A {@code null} value makes the statistics methods fail on unboxing.
         */
        public void add(Double d) {
            values.add(d);
        }

        /**
         * Computes the arithmetic mean of the recorded values.
         *
         * @return the mean, or {@code NaN} if no value has been recorded.
         */
        public double getMean() {
            double sum = 0.0;
            for (double value : values) {
                sum += value;
            }
            return sum / values.size();
        }

        /**
         * Computes the sample standard deviation, dividing the sum of squared deviations by {@code n - 1}.
         *
         * @return the standard deviation, or {@code 0} when fewer than two values have been recorded, because the
         *         sample variance is undefined then.
         */
        public double getStandardDeviation() {
            int n = values.size();
            if (n < 2) {
                return 0.0;
            }
            double mean = getMean();
            double sumOfSquares = 0.0;
            for (double value : values) {
                double deviation = value - mean;
                sumOfSquares += deviation * deviation;
            }
            return Math.sqrt(sumOfSquares / (n - 1));
        }
    }

    /** Creates a classifier that uses the default Laplace corrector of {@code 1.0E-5}. */
    public NaiveBayesClassifier() {
        this(DEFAULT_LAPLACE_CORRECTOR);
    }

    /**
     * Creates a classifier with the given Laplace corrector.
     *
     * @param laplace the Laplace corrector passed to the model when scoring nominal features; must be {@code >= 0}.
     * @throws IllegalArgumentException if {@code laplace} is negative or {@code NaN}.
     */
    public NaiveBayesClassifier(double laplace) {
        Validate.isTrue(laplace >= 0.0, "The Laplace corrector must be equal or greater than zero.");
        this.laplace = laplace;
    }

    /**
     * Builds a {@link NaiveBayesModel} from the given training instances.
     *
     * @param iterable the training instances. Each instance's target class is counted, and each of its nominal and
     *            numeric features is recorded under that class. Other feature types are ignored.
     * @return the trained model.
     */
    @Override
    public NaiveBayesModel train(Iterable<? extends Trainable> iterable) {
        // number of training instances per category
        CountMap<String> categoryCounts = CountMap.create();
        // number of occurrences per (feature name, nominal value, category)
        CountMap<Triplet<String, String, String>> nominalCounts = CountMap.create();
        // value statistics per (numeric feature name, category); an entry is created on first access
        LazyMap<Pair<String, String>, Stats> numericStats = LazyMap.create(new Factory<Stats>() {
            @Override
            public Stats create() {
                return new Stats();
            }
        });

        for (Trainable trainable : iterable) {
            String category = trainable.getTargetClass();
            categoryCounts.add(category);
            for (Iterator<Feature<?>> it = trainable.getFeatureVector().iterator(); it.hasNext();) {
                Feature<?> feature = it.next();
                String featureName = feature.getName();
                if (feature instanceof NominalFeature) {
                    String nominalValue = ((NominalFeature) feature).getValue();
                    nominalCounts.add(new Triplet<String, String, String>(featureName, nominalValue, category));
                }
                if (feature instanceof NumericFeature) {
                    Pair<String, String> key = new Pair<String, String>(featureName, category);
                    numericStats.get(key).add(((NumericFeature) feature).getValue());
                }
            }
        }

        Map<Pair<String, String>, Double> means = CollectionHelper.newHashMap();
        Map<Pair<String, String>, Double> standardDeviations = CollectionHelper.newHashMap();
        for (Map.Entry<Pair<String, String>, Stats> entry : numericStats.entrySet()) {
            means.put(entry.getKey(), entry.getValue().getMean());
            standardDeviations.put(entry.getKey(), entry.getValue().getStandardDeviation());
        }
        return new NaiveBayesModel(nominalCounts, categoryCounts, means, standardDeviations);
    }

    /**
     * Scores every category of the model for the given instance.
     *
     * <p>Each category's score is its prior multiplied by one factor per nominal or numeric feature. The raw
     * products are handed to {@link CategoryEntriesMap}; whether they are normalized there is not visible here.
     * NOTE(review): the plain product can underflow to 0 for instances with many features; a log-space sum would be
     * more robust if callers only need the ranking.
     *
     * @param classifiable the instance to classify.
     * @param naiveBayesModel the trained model.
     * @return the score of each category known to the model.
     */
    @Override
    public CategoryEntries classify(Classifiable classifiable, NaiveBayesModel naiveBayesModel) {
        Map<String, Double> scores = CollectionHelper.newHashMap();
        for (String category : naiveBayesModel.getCategoryNames()) {
            double score = naiveBayesModel.getPrior(category);
            for (Iterator<Feature<?>> it = classifiable.getFeatureVector().iterator(); it.hasNext();) {
                Feature<?> feature = it.next();
                String featureName = feature.getName();
                if (feature instanceof NominalFeature) {
                    score *= naiveBayesModel.getProbability(featureName, ((NominalFeature) feature).getValue(),
                            category, laplace);
                }
                if (feature instanceof NumericFeature) {
                    double value = ((NumericFeature) feature).getValue().doubleValue();
                    double density = naiveBayesModel.getDensity(featureName, value, category);
                    // Only strictly positive densities are applied. A zero (or NaN) density is skipped instead of
                    // wiping out the whole product for this category.
                    if (density > 0.0) {
                        score *= density;
                    }
                }
            }
            scores.put(category, score);
        }
        return new CategoryEntriesMap(scores);
    }
}
