package ws.palladian.classification.nb;

import java.util.Map;
import java.util.Set;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.commons.lang3.Validate;
import org.javatuples.Pair;
import org.javatuples.Triplet;
import org.jdesktop.swingx.JXLabel;
import ws.palladian.classification.Model;
import ws.palladian.helper.collection.CountMap;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/nb/NaiveBayesModel.class */
/**
 * Trained state of a Naive Bayes classifier.
 *
 * <p>Holds four pieces of training statistics:
 * <ul>
 * <li>co-occurrence counts of (feature name, nominal feature value, category);</li>
 * <li>per-category counts;</li>
 * <li>per-(feature name, category) sample means for numeric features;</li>
 * <li>per-(feature name, category) standard deviations for numeric features.</li>
 * </ul>
 *
 * <p>From these it derives three quantities:
 * <ul>
 * <li>category priors;</li>
 * <li>Laplace-smoothed conditional probabilities for nominal features;</li>
 * <li>Gaussian densities for numeric features.</li>
 * </ul>
 *
 * <p>The collections passed to the constructor are stored by reference, without copying. Mutating
 * them afterwards changes this model.
 */
public final class NaiveBayesModel implements Model {

    private static final long serialVersionUID = 1L;

    /** sqrt(2 * pi), the normalization factor of the Gaussian probability density. */
    private static final double SQRT_TWO_PI = Math.sqrt(2.0 * Math.PI);

    /** Counts keyed by (featureName, featureValue, category) for nominal features. */
    private final CountMap<Triplet<String, String, String>> nominalCounts;

    /** Number of training instances per category. */
    private final CountMap<String> categories;

    /** Sample mean of each numeric feature, keyed by (featureName, category). */
    private final Map<Pair<String, String>, Double> sampleMeans;

    /** Standard deviation of each numeric feature, keyed by (featureName, category). */
    private final Map<Pair<String, String>, Double> standardDeviations;

    /**
     * Creates a model from precomputed training statistics. The arguments are stored as-is.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally package-private. It
     * is kept public here so that existing callers keep compiling.
     *
     * @param nominalCounts counts keyed by (featureName, featureValue, category)
     * @param categories number of instances per category
     * @param sampleMeans numeric feature means keyed by (featureName, category)
     * @param standardDeviations numeric feature standard deviations keyed by (featureName, category)
     */
    public NaiveBayesModel(CountMap<Triplet<String, String, String>> nominalCounts,
            CountMap<String> categories, Map<Pair<String, String>, Double> sampleMeans,
            Map<Pair<String, String>, Double> standardDeviations) {
        this.nominalCounts = nominalCounts;
        this.categories = categories;
        this.sampleMeans = sampleMeans;
        this.standardDeviations = standardDeviations;
    }

    /**
     * Returns the distinct category names seen during training.
     *
     * @return the set returned by the underlying category {@code CountMap}
     */
    public Set<String> getCategoryNames() {
        return categories.uniqueItems();
    }

    /**
     * Returns the prior probability of a category: its count divided by the total count of all
     * categories.
     *
     * @param category the category, must not be {@code null}
     * @return the relative frequency of {@code category}; {@code 0.0} for an unseen category whose
     *         count is zero
     * @throws NullPointerException if {@code category} is {@code null}
     */
    public double getPrior(String category) {
        Validate.notNull(category, "category must not be null");
        // Cast before dividing. Without the cast, integral counts would produce a truncated integer
        // quotient (only ever 0 or 1) instead of a real fraction.
        return (double) categories.getCount(category) / categories.totalSize();
    }

    /**
     * Returns the Laplace-smoothed conditional probability of a nominal feature value given a
     * category.
     *
     * <p>Computes {@code (count(feature, value, category) + laplace) /
     * (count(category) + laplace * numberOfCategories)}.
     *
     * @param featureName name of the feature, must not be {@code null}
     * @param featureValue nominal value of the feature, must not be {@code null}
     * @param category the category to condition on, must not be {@code null}
     * @param laplace Laplace correction term, must be {@code >= 0}
     * @return the smoothed conditional probability
     * @throws NullPointerException if any string argument is {@code null}
     * @throws IllegalArgumentException if {@code laplace} is negative or NaN
     */
    public double getProbability(String featureName, String featureValue, String category,
            double laplace) {
        Validate.notNull(featureName, "featureName must not be null");
        Validate.notNull(featureValue, "featureValue must not be null");
        Validate.notNull(category, "category must not be null");
        Validate.isTrue(laplace >= 0.0, "laplace corrector must be equal or greater than zero");
        double featureCount = nominalCounts.getCount(
                new Triplet<String, String, String>(featureName, featureValue, category));
        double categoryCount = categories.getCount(category);
        return (featureCount + laplace) / (categoryCount + laplace * categories.uniqueSize());
    }

    /** Looks up the stored standard deviation for (featureName, category); may be {@code null}. */
    private Double getStandardDeviation(String featureName, String category) {
        return standardDeviations.get(new Pair<String, String>(featureName, category));
    }

    /** Looks up the stored sample mean for (featureName, category); may be {@code null}. */
    private Double getMean(String featureName, String category) {
        return sampleMeans.get(new Pair<String, String>(featureName, category));
    }

    /**
     * Evaluates the Gaussian (normal) probability density of a numeric feature value.
     *
     * <p>The density uses the mean and standard deviation stored for (featureName, category).
     *
     * @param featureName name of the numeric feature, must not be {@code null}
     * @param value the observed feature value
     * @param category the category to condition on, must not be {@code null}
     * @return the density. Returns {@code 0.0} in three cases: no standard deviation is stored, no
     *         mean is stored, or the standard deviation is zero (the density would be undefined).
     * @throws NullPointerException if {@code featureName} or {@code category} is {@code null}
     */
    public double getDensity(String featureName, double value, String category) {
        Validate.notNull(featureName, "featureName must not be null");
        Validate.notNull(category, "category must not be null");
        Double standardDeviation = getStandardDeviation(featureName, category);
        Double mean = getMean(featureName, category);
        // Missing statistics or a degenerate (zero-width) distribution: treat as zero density
        // rather than dividing by zero or unboxing null.
        if (standardDeviation == null || mean == null || standardDeviation.doubleValue() == 0.0) {
            return 0.0;
        }
        double sigma = standardDeviation.doubleValue();
        double deviation = value - mean.doubleValue();
        return Math.exp(-(deviation * deviation) / (2.0 * sigma * sigma)) / (SQRT_TWO_PI * sigma);
    }

    @Override
    public String toString() {
        return "NaiveBayesModel [nominalCounts=" + nominalCounts + ", categories=" + categories
                + ", sampleMeans=" + sampleMeans + ", standardDeviations=" + standardDeviations + "]";
    }
}
