package ws.palladian.classification.featureselection;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.jdesktop.swingx.JXLabel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Learner;
import ws.palladian.classification.Model;
import ws.palladian.classification.nb.NaiveBayesClassifier;
import ws.palladian.classification.utils.ClassificationUtils;
import ws.palladian.classification.utils.ClassifierEvaluation;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.EqualsFilter;
import ws.palladian.helper.collection.Function;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.processing.Trainable;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/featureselection/SingleFeatureClassification.class */
/**
 * A {@link FeatureRanker} which scores every feature on its own: for each feature name found in the training
 * data, the training and validation instances are passed through
 * {@link ClassificationUtils#filterFeatures} with an {@link EqualsFilter} for that name, a model is trained on the
 * filtered training data, evaluated on the filtered validation data, and the resulting {@link ConfusionMatrix} is
 * turned into a score by the supplied scorer function.
 *
 * @param <M> Type of the model produced by the learner and consumed by the classifier.
 */
public final class SingleFeatureClassification<M extends Model> implements FeatureRanker {
    private static final Logger LOGGER = LoggerFactory.getLogger(SingleFeatureClassification.class);

    /** Default training file used by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_TRAIN_CSV = "/Users/pk/Dropbox/LocationExtraction/BFE/fd_merged_train.csv";

    /** Default validation file used by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_VALIDATION_CSV = "/Users/pk/Dropbox/LocationExtraction/BFE/fd_merged_validation.csv";

    private final Learner<M> learner;
    private final Classifier<M> classifier;
    private final Function<ConfusionMatrix, Double> scorer;

    /**
     * Creates a new ranker.
     *
     * @param learner The learner used to train one model per feature, not <code>null</code>.
     * @param classifier The classifier used to evaluate each trained model, not <code>null</code>.
     * @param function The scorer which maps a confusion matrix to the ranking score, not <code>null</code>.
     * @throws NullPointerException If any argument is <code>null</code>.
     */
    public SingleFeatureClassification(Learner<M> learner, Classifier<M> classifier, Function<ConfusionMatrix, Double> function) {
        Validate.notNull(learner, "learner must not be null");
        Validate.notNull(classifier, "classifier must not be null");
        Validate.notNull(function, "scorer must not be null");
        this.learner = learner;
        this.classifier = classifier;
        this.scorer = function;
    }

    /**
     * Ranks the features using a random split of the given data: the data is copied, shuffled, and the first half
     * (<code>size / 2</code> instances) is used for training while the remainder is used for validation.
     *
     * @param collection The data to split and rank on, not <code>null</code>. The given collection itself is not
     *            modified.
     * @return The ranking with one score per feature.
     * @throws NullPointerException If the collection is <code>null</code>.
     */
    @Override
    public FeatureRanking rankFeatures(Collection<? extends Trainable> collection) {
        Validate.notNull(collection, "dataset must not be null");
        // Work on a private copy so that shuffling does not disturb the caller's collection.
        List<Trainable> shuffled = new ArrayList<Trainable>(collection);
        Collections.shuffle(shuffled);
        int splitIndex = shuffled.size() / 2;
        List<Trainable> trainSet = shuffled.subList(0, splitIndex);
        List<Trainable> validationSet = shuffled.subList(splitIndex, shuffled.size());
        return rankFeatures(trainSet, validationSet);
    }

    /**
     * Ranks the features using explicitly given training and validation data. The set of features to rank is taken
     * from the training data only.
     *
     * @param collection The training data, not <code>null</code>.
     * @param collection2 The validation data, not <code>null</code>.
     * @return The ranking with one score per feature found in the training data.
     * @throws NullPointerException If any argument is <code>null</code>.
     */
    public FeatureRanking rankFeatures(Collection<? extends Trainable> collection, Collection<? extends Trainable> collection2) {
        Validate.notNull(collection, "training set must not be null");
        Validate.notNull(collection2, "validation set must not be null");
        FeatureRanking featureRanking = new FeatureRanking();
        Set<String> featureNames = ClassificationUtils.getFeatureNames(collection);
        // The second argument was the numeric constant zero in the compiled class; it used to be spelled via an
        // unrelated Swing constant (JXLabel.NORMAL == 0). NOTE(review): presumably the progress print interval.
        ProgressMonitor progressMonitor = new ProgressMonitor(featureNames.size(), 0.0d);
        for (String featureName : featureNames) {
            // The same filter is applied to both data sets so that training and evaluation see the same feature.
            EqualsFilter nameFilter = EqualsFilter.create(featureName);
            List<Trainable> filteredTrainSet = ClassificationUtils.filterFeatures(collection, nameFilter);
            List<Trainable> filteredValidationSet = ClassificationUtils.filterFeatures(collection2, nameFilter);
            M model = this.learner.train(filteredTrainSet);
            ConfusionMatrix confusionMatrix = ClassifierEvaluation.evaluate(this.classifier, filteredValidationSet, model);
            Double score = this.scorer.compute(confusionMatrix);
            LOGGER.info("Finished testing with {}: {}", featureName, score);
            progressMonitor.incrementAndPrintProgress();
            featureRanking.add(featureName, score.doubleValue());
        }
        return featureRanking;
    }

    /**
     * Example run: ranks the features of a CSV data set with a {@link NaiveBayesClassifier}, scoring each feature
     * by the F1 value of the class <code>"true"</code> (a NaN F1 is counted as zero), and prints the ranking.
     *
     * @param strArr Optional command line arguments: <code>[0]</code> path to the training CSV,
     *            <code>[1]</code> path to the validation CSV. Missing arguments fall back to the built-in
     *            default paths.
     */
    public static void main(String[] strArr) {
        String trainPath = strArr.length > 0 ? strArr[0] : DEFAULT_TRAIN_CSV;
        String validationPath = strArr.length > 1 ? strArr[1] : DEFAULT_VALIDATION_CSV;
        List<Trainable> trainSet = ClassificationUtils.readCsv(trainPath);
        List<Trainable> validationSet = ClassificationUtils.readCsv(validationPath);

        Function<ConfusionMatrix, Double> f1Scorer = new Function<ConfusionMatrix, Double>() {
            @Override
            public Double compute(ConfusionMatrix confusionMatrix) {
                double f = confusionMatrix.getF(1.0d, "true");
                // An undefined F-measure is treated as the worst possible score.
                return Double.valueOf(Double.isNaN(f) ? 0.0d : f);
            }
        };

        // The same instance serves as learner and classifier. The raw type is kept on purpose, because the model
        // type of NaiveBayesClassifier is not referenced in this file.
        NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
        @SuppressWarnings({"rawtypes", "unchecked"})
        FeatureRanking ranking = new SingleFeatureClassification(naiveBayesClassifier, naiveBayesClassifier, f1Scorer)
                .rankFeatures(trainSet, validationSet);
        CollectionHelper.print(ranking.getAll());
    }
}
