package ws.palladian.classification.featureselection;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.Validate;
import org.jdesktop.swingx.JXLabel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickdt.randomForest.RandomForestBuilder;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Learner;
import ws.palladian.classification.Model;
import ws.palladian.classification.dt.QuickDtClassifier;
import ws.palladian.classification.dt.QuickDtLearner;
import ws.palladian.classification.utils.ClassificationUtils;
import ws.palladian.classification.utils.ClassifierEvaluation;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.ConstantFactory;
import ws.palladian.helper.collection.EqualsFilter;
import ws.palladian.helper.collection.Factory;
import ws.palladian.helper.collection.Filter;
import ws.palladian.helper.collection.Function;
import ws.palladian.helper.collection.InverseFilter;
import ws.palladian.helper.math.ConfusionMatrix;
import ws.palladian.processing.Trainable;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/featureselection/BackwardFeatureElimination.class */
/**
 * Ranks features by greedy backward elimination.
 * <p>
 * Starting from the full feature set, every round evaluates each remaining feature as an elimination candidate. For
 * each candidate, a learner from the learner factory is trained on the training data with all previously eliminated
 * features plus the candidate filtered out. The resulting model is then evaluated with a classifier from the
 * classifier factory on the equally filtered test data, and the {@link ConfusionMatrix} is turned into a score by the
 * configured scorer. The candidate whose removal yields the highest score is eliminated (ties go to the candidate
 * evaluated last). Rounds repeat until no features remain. The candidate runs of one round are executed on a fixed
 * pool of {@code numThreads} threads.
 *
 * @param <M> the model type produced by the learner and consumed by the classifier.
 */
public final class BackwardFeatureElimination<M extends Model> implements FeatureRanker {

    private static final Logger LOGGER = LoggerFactory.getLogger(BackwardFeatureElimination.class);

    /** Scorer which uses the accuracy of the confusion matrix. */
    public static final Function<ConfusionMatrix, Double> ACCURACY_SCORER = new Function<ConfusionMatrix, Double>() {
        @Override
        public Double compute(ConfusionMatrix confusionMatrix) {
            return Double.valueOf(confusionMatrix.getAccuracy());
        }
    };

    /** Supplies the learner; {@code create()} is called once per evaluation run. */
    private final Factory<? extends Learner<M>> learnerFactory;

    /** Supplies the classifier; {@code create()} is called once per evaluation run. */
    private final Factory<? extends Classifier<M>> classifierFactory;

    /** Maps the confusion matrix of one evaluation run to a score; higher is treated as better. */
    private final Function<ConfusionMatrix, Double> scorer;

    /** Number of threads used to evaluate the candidates of one elimination round. */
    private final int numThreads;

    /**
     * One evaluation run: trains on the training data without the given features and scores the classification of
     * the equally filtered test data. Increments the shared progress monitor once finished.
     */
    public final class TestRun implements Callable<TestRunResult> {
        private final Collection<? extends Trainable> trainData;
        private final Collection<? extends Trainable> testData;
        /** Features removed for this run; the last entry is reported as the eliminated feature. */
        private final List<String> featuresToEliminate;
        private final ProgressMonitor monitor;

        public TestRun(Collection<? extends Trainable> trainData, Collection<? extends Trainable> testData,
                List<String> featuresToEliminate, ProgressMonitor monitor) {
            this.trainData = trainData;
            this.testData = testData;
            this.featuresToEliminate = featuresToEliminate;
            this.monitor = monitor;
        }

        @Override
        public TestRunResult call() throws Exception {
            // callers always pass a non-empty list; the last element is the candidate of this run
            String candidate = featuresToEliminate.get(featuresToEliminate.size() - 1);
            LOGGER.debug("Starting elimination for {}", candidate);
            // keep every feature that is NOT in featuresToEliminate
            InverseFilter featureFilter = InverseFilter.create(EqualsFilter.create((Collection) featuresToEliminate));
            List<Trainable> filteredTrainData = ClassificationUtils.filterFeatures(trainData, featureFilter);
            List<Trainable> filteredTestData = ClassificationUtils.filterFeatures(testData, featureFilter);
            Classifier<M> classifier = classifierFactory.create();
            M model = learnerFactory.create().train(filteredTrainData);
            ConfusionMatrix confusionMatrix = ClassifierEvaluation.evaluate(classifier, filteredTestData, model);
            Double score = scorer.compute(confusionMatrix);
            LOGGER.debug("Finished elimination for {}", candidate);
            monitor.incrementAndPrintProgress();
            return new TestRunResult(score, candidate);
        }
    }

    /** Score obtained when eliminating {@link #eliminatedFeature} (together with all previously eliminated ones). */
    public static final class TestRunResult {
        private final Double score;
        private final String eliminatedFeature;

        public TestRunResult(Double score, String eliminatedFeature) {
            this.score = score;
            this.eliminatedFeature = eliminatedFeature;
        }
    }

    /**
     * Creates a single-threaded elimination which always uses the given learner and classifier instances and scores
     * by accuracy ({@link #ACCURACY_SCORER}).
     *
     * @param learner The learner, not <code>null</code>.
     * @param classifier The classifier, not <code>null</code>.
     */
    public BackwardFeatureElimination(Learner<M> learner, Classifier<M> classifier) {
        this(ConstantFactory.create(learner), ConstantFactory.create(classifier), ACCURACY_SCORER, 1);
    }

    /**
     * @param learnerFactory Supplies the learner for each evaluation run, not <code>null</code>.
     * @param classifierFactory Supplies the classifier for each evaluation run, not <code>null</code>.
     * @param scorer Maps a confusion matrix to a score (higher is better), not <code>null</code>.
     * @param numThreads Number of threads for evaluating candidates, greater zero.
     */
    public BackwardFeatureElimination(Factory<? extends Learner<M>> learnerFactory,
            Factory<? extends Classifier<M>> classifierFactory, Function<ConfusionMatrix, Double> scorer,
            int numThreads) {
        Validate.notNull(learnerFactory, "learnerFactory must not be null", new Object[0]);
        Validate.notNull(classifierFactory, "classifierFactory must not be null", new Object[0]);
        Validate.notNull(scorer, "scorer must not be null", new Object[0]);
        Validate.isTrue(numThreads > 0, "numThreads must be greater zero", new Object[0]);
        this.learnerFactory = learnerFactory;
        this.classifierFactory = classifierFactory;
        this.scorer = scorer;
        this.numThreads = numThreads;
    }

    /**
     * Shuffles the given data, uses the first half (rounded down) for training and the rest for testing, and ranks
     * the features via {@link #rankFeatures(Collection, Collection)}.
     */
    @Override
    public FeatureRanking rankFeatures(Collection<? extends Trainable> dataset) {
        List<Trainable> shuffled = new ArrayList<Trainable>(dataset);
        Collections.shuffle(shuffled);
        int split = shuffled.size() / 2;
        return rankFeatures(shuffled.subList(0, split), shuffled.subList(split, shuffled.size()));
    }

    /**
     * Ranks the features of the training data by greedy backward elimination.
     *
     * @param trainSet The data used for training in every evaluation run.
     * @param testSet The data used for scoring in every evaluation run.
     * @return The ranking; a feature eliminated in round k (0-based) is added with value k, so earlier-eliminated
     *         features get smaller values.
     * @throws IllegalStateException wrapping any exception raised during evaluation (including interruption).
     */
    public FeatureRanking rankFeatures(Collection<? extends Trainable> trainSet,
            Collection<? extends Trainable> testSet) {
        FeatureRanking featureRanking = new FeatureRanking();
        Set<String> featureNames = ClassificationUtils.getFeatureNames(trainSet);
        int numFeatures = featureNames.size();
        // round k evaluates (numFeatures - k) candidates: numFeatures + (numFeatures - 1) + ... + 1
        int numIterations = (numFeatures * (numFeatures + 1)) / 2;
        // +1: the baseline run with all features increments the monitor as well.
        // NOTE(review): JXLabel.NORMAL looks like a decompiler substitution for a numeric literal (presumably 0);
        // confirm and replace it with the literal to drop the Swing dependency.
        ProgressMonitor progressMonitor = new ProgressMonitor(numIterations + 1, JXLabel.NORMAL);
        LOGGER.info("# of features in dataset: {}", Integer.valueOf(numFeatures));
        LOGGER.info("# of iterations: {}", Integer.valueOf(numIterations));

        ExecutorService executor = null;
        try {
            // "<none>" matches no real feature, so this run uses the complete feature set
            Double baselineScore = new TestRun(trainSet, testSet, Arrays.asList("<none>"), progressMonitor).call().score;
            LOGGER.info("Score with all features {}", baselineScore);

            executor = Executors.newFixedThreadPool(numThreads);
            List<String> eliminated = new ArrayList<String>();
            Set<String> remaining = new HashSet<String>(featureNames);
            int rank = 0;
            while (!remaining.isEmpty()) {
                List<TestRun> runs = new ArrayList<TestRun>();
                for (String candidate : remaining) {
                    List<String> toEliminate = new ArrayList<String>(eliminated);
                    toEliminate.add(candidate);
                    runs.add(new TestRun(trainSet, testSet, toEliminate, progressMonitor));
                }

                // pick the candidate whose removal scores best; ">=" lets later results win ties
                String selected = null;
                double bestScore = 0.0;
                for (Future<TestRunResult> future : executor.invokeAll(runs)) {
                    TestRunResult result = future.get();
                    if (result.score.doubleValue() >= bestScore || selected == null) {
                        bestScore = result.score.doubleValue();
                        selected = result.eliminatedFeature;
                    }
                }

                LOGGER.info("Selected {} for elimination, score {}", selected, Double.valueOf(bestScore));
                eliminated.add(selected);
                remaining.remove(selected);
                featureRanking.add(selected, rank++);
            }
            return featureRanking;
        } catch (InterruptedException e) {
            // preserve the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (Exception e) {
            // includes ExecutionException from a failed evaluation run
            throw new IllegalStateException(e);
        } finally {
            // always release the worker threads, also when an evaluation run failed
            if (executor != null) {
                executor.shutdown();
            }
        }
    }

    /**
     * Developer harness: merges three training and three validation CSV files (with random subsets of two of them)
     * and writes the merged data to the desktop, then exits the JVM.
     * <p>
     * NOTE(review): because of {@code System.exit(0)}, the feature elimination code after it never runs; remove the
     * exit call to run the elimination on the merged data. All paths are hard-coded to a local machine.
     */
    public static void main(String[] strArr) {
        List<Trainable> readCsv = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_tud_train_1376394038036.csv", true);
        List<Trainable> readCsv2 = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_lgl_train_1376399225449.csv", true);
        List<Trainable> readCsv3 = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_clust_train_1376413884470.csv", true);
        List drawRandomSubset = ClassificationUtils.drawRandomSubset(readCsv2, 30);
        List drawRandomSubset2 = ClassificationUtils.drawRandomSubset(readCsv3, 15);
        ArrayList newArrayList = CollectionHelper.newArrayList();
        newArrayList.addAll(readCsv);
        newArrayList.addAll(drawRandomSubset);
        newArrayList.addAll(drawRandomSubset2);
        List<Trainable> readCsv4 = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_tud_validation_1376419927925.csv", true);
        List<Trainable> readCsv5 = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_lgl_validation_1376420924580.csv", true);
        List<Trainable> readCsv6 = ClassificationUtils.readCsv("/Users/pk/Dropbox/temp_bfe_location/fd_clust_validation_1376422975187.csv", true);
        List drawRandomSubset3 = ClassificationUtils.drawRandomSubset(readCsv5, 30);
        List drawRandomSubset4 = ClassificationUtils.drawRandomSubset(readCsv6, 15);
        ArrayList newArrayList2 = CollectionHelper.newArrayList();
        newArrayList2.addAll(readCsv4);
        newArrayList2.addAll(drawRandomSubset3);
        newArrayList2.addAll(drawRandomSubset4);
        ClassificationUtils.writeCsv(newArrayList, new File("/Users/pk/Desktop/fd_merged_train.csv"));
        ClassificationUtils.writeCsv(newArrayList2, new File("/Users/pk/Desktop/fd_merged_validation.csv"));
        System.exit(0);
        // ---- never reached while the System.exit(0) above is in place ----
        // drop the "indexScore" feature from both sets
        List<Trainable> filterFeatures = ClassificationUtils.filterFeatures(newArrayList, InverseFilter.create(EqualsFilter.create("indexScore")));
        List<Trainable> filterFeatures2 = ClassificationUtils.filterFeatures(newArrayList2, InverseFilter.create(EqualsFilter.create("indexScore")));
        // of all "containsMarker..." features keep only "containsMarker(*)"
        Filter<String> filter = new Filter<String>() {
            @Override
            public boolean accept(String str) {
                if (str.startsWith("containsMarker")) {
                    return str.equals("containsMarker(*)");
                }
                return true;
            }
        };
        CollectionHelper.print(new BackwardFeatureElimination(new Factory<QuickDtLearner>() {
            @Override
            public QuickDtLearner create() {
                return new QuickDtLearner(new RandomForestBuilder().numTrees(10));
            }
        }, ConstantFactory.create(new QuickDtClassifier()), new Function<ConfusionMatrix, Double>() {
            @Override
            public Double compute(ConfusionMatrix confusionMatrix) {
                return Double.valueOf(confusionMatrix.getF(1.0d, "true"));
            }
        }, 4).rankFeatures(ClassificationUtils.filterFeatures(filterFeatures, filter), ClassificationUtils.filterFeatures(filterFeatures2, filter)).getAll());
    }
}
