package ws.palladian.classification.featureselection;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.discretization.Binner;
import ws.palladian.helper.ProgressMonitor;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.CountMap;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.Feature;
import ws.palladian.processing.features.FeatureVector;
import ws.palladian.processing.features.ListFeature;
import ws.palladian.processing.features.NumericFeature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/featureselection/AbstractFeatureRanker.class */
/**
 * Base class for {@link FeatureRanker} implementations.
 *
 * <p>Offers {@link #convertToSet(FeatureVector, Collection)}, which flattens a feature vector into
 * a set of features and replaces every numeric feature by its binned counterpart. The
 * {@link Binner} for each feature name is computed once from the supplied dataset and then cached
 * in this instance.
 */
public abstract class AbstractFeatureRanker implements FeatureRanker {
    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractFeatureRanker.class);

    /**
     * Sort key used for instances that do not carry the feature being discretized. Negative
     * infinity orders them before every real value. ({@code Double.MIN_VALUE} must not be used for
     * this: it is the smallest <em>positive</em> double and would sort after all negative values.)
     */
    private static final double MISSING_VALUE = Double.NEGATIVE_INFINITY;

    /** Natural logarithm of two, used to convert natural logarithms to base two. */
    private static final double LN_2 = Math.log(2.0d);

    /** Binners computed so far, keyed by feature name. */
    private final Map<String, Binner> binnerCache = CollectionHelper.newHashMap();

    /**
     * Converts a feature vector into a flat set of features with all numeric features binned.
     *
     * <p>Elements of a {@link ListFeature} are added individually. Every {@link NumericFeature},
     * whether top-level or nested in a list feature, is replaced by the result of
     * {@link Binner#bin}. All other features are added unchanged. Progress is only reported when
     * the binner cache is empty on entry, i.e. while the binners are being built.
     *
     * @param featureVector the feature vector to convert
     * @param collection the dataset used to compute binners for features not yet cached
     * @return the set of (binned) features
     */
    public Set<Feature<?>> convertToSet(FeatureVector featureVector, Collection<? extends Trainable> collection) {
        Set<Feature<?>> result = CollectionHelper.newHashSet();
        boolean reportProgress = this.binnerCache.isEmpty();
        if (reportProgress) {
            LOGGER.info("Converting {} features to set", Integer.valueOf(featureVector.size()));
        }
        ProgressMonitor progressMonitor = new ProgressMonitor(featureVector.size(), 1.0d);
        for (Feature<?> feature : featureVector) {
            if (feature instanceof ListFeature) {
                String listName = feature.getName();
                for (Object element : (ListFeature<?>) feature) {
                    Feature<?> nested = (Feature<?>) element;
                    if (nested instanceof NumericFeature) {
                        Binner binner = binnerFor(nested.getName(), collection,
                                byNestedValue(listName, nested.getName()), reportProgress, progressMonitor);
                        result.add(binner.bin((NumericFeature) nested));
                    } else {
                        result.add(nested);
                    }
                }
            } else if (feature instanceof NumericFeature) {
                Binner binner = binnerFor(feature.getName(), collection, byTopLevelValue(feature.getName()),
                        reportProgress, progressMonitor);
                result.add(binner.bin((NumericFeature) feature));
            } else {
                result.add(feature);
                if (reportProgress) {
                    progressMonitor.incrementAndPrintProgress();
                }
            }
        }
        return result;
    }

    /**
     * Returns the cached binner for a feature name, computing and caching it on a cache miss.
     *
     * @param featureName name of the numeric feature (also the cache key)
     * @param dataset dataset to discretize on a cache miss
     * @param order ordering of the dataset by the feature's value
     * @param reportProgress whether to advance {@code monitor} on a cache miss
     * @param monitor progress monitor to advance
     * @return the binner for {@code featureName}
     */
    private Binner binnerFor(String featureName, Collection<? extends Trainable> dataset, Comparator<Trainable> order,
            boolean reportProgress, ProgressMonitor monitor) {
        Binner binner = this.binnerCache.get(featureName);
        if (binner == null) {
            if (reportProgress) {
                monitor.incrementAndPrintProgress();
            }
            binner = discretize(featureName, dataset, order);
            this.binnerCache.put(featureName, binner);
        }
        return binner;
    }

    /** Orders instances by the value of a top-level numeric feature; instances without it come first. */
    private static Comparator<Trainable> byTopLevelValue(final String featureName) {
        return new Comparator<Trainable>() {
            @Override
            public int compare(Trainable left, Trainable right) {
                return Double.compare(topLevelValue(left, featureName), topLevelValue(right, featureName));
            }
        };
    }

    /** Orders instances by the value of a numeric feature nested in a list feature; instances without it come first. */
    private static Comparator<Trainable> byNestedValue(final String listName, final String featureName) {
        return new Comparator<Trainable>() {
            @Override
            public int compare(Trainable left, Trainable right) {
                return Double.compare(nestedValue(left, listName, featureName),
                        nestedValue(right, listName, featureName));
            }
        };
    }

    /** Reads the value of a top-level numeric feature, or {@link #MISSING_VALUE} if the instance lacks it. */
    private static double topLevelValue(Trainable instance, String featureName) {
        return sortKey((NumericFeature) instance.getFeatureVector().get(NumericFeature.class, featureName));
    }

    /**
     * Reads the value of a numeric feature nested in a list feature, or {@link #MISSING_VALUE} if
     * either the list feature or the nested feature is absent from the instance.
     */
    private static double nestedValue(Trainable instance, String listName, String featureName) {
        ListFeature<?> list = (ListFeature<?>) instance.getFeatureVector().get(ListFeature.class, listName);
        if (list == null) {
            return MISSING_VALUE;
        }
        return sortKey((NumericFeature) list.getFeatureWithName(featureName));
    }

    /** Maps a possibly absent numeric feature to the double used for sorting. */
    private static double sortKey(NumericFeature feature) {
        return feature == null ? MISSING_VALUE : feature.getValue().doubleValue();
    }

    /**
     * Builds a {@link Binner} for one numeric feature from a dataset.
     *
     * <p>The dataset is copied, sorted with {@code comparator}, and the boundary points found in
     * the sorted copy are handed to the binner. The supplied collection itself is not modified.
     *
     * @param str name of the feature to discretize
     * @param collection the dataset
     * @param comparator ordering of the instances by the feature's value
     * @return the binner for the feature
     */
    public static Binner discretize(String str, Collection<? extends Trainable> collection, Comparator<Trainable> comparator) {
        List<Trainable> sorted = new ArrayList<Trainable>(collection);
        Collections.sort(sorted, comparator);
        return createBinner(sorted, str);
    }

    /**
     * Creates the binner from an already sorted dataset.
     *
     * <p>NOTE(review): the per-instance feature is looked up at the top level of the feature
     * vector only; for features nested inside a list feature this lookup presumably yields
     * {@code null} entries. Verify against {@code Binner}'s expectations.
     */
    private static Binner createBinner(List<Trainable> list, String str) {
        List<Integer> boundaryPoints = findBoundaryPoints(list);
        List<NumericFeature> sortedFeatures = new ArrayList<NumericFeature>(list.size());
        for (Trainable instance : list) {
            sortedFeatures.add((NumericFeature) instance.getFeatureVector().get(NumericFeature.class, str));
        }
        return new Binner(str, boundaryPoints, sortedFeatures);
    }

    /**
     * Finds the indices at which the sorted dataset should be cut.
     *
     * <p>An index {@code i} is a candidate when the target class changes between positions
     * {@code i - 1} and {@code i}. A candidate is accepted when it satisfies the minimum
     * description length criterion of Fayyad and Irani:
     * {@code gain > (log2(N - 1) + delta) / N}, with {@code N} the dataset size.
     *
     * @param list the dataset, sorted by the feature being discretized
     * @return accepted cut indices in ascending order
     */
    private static List<Integer> findBoundaryPoints(List<Trainable> list) {
        List<Integer> boundaryPoints = new ArrayList<Integer>();
        int size = list.size();
        for (int i = 1; i < size; i++) {
            // Compare labels by value; identity comparison would report a class change for
            // equal labels held in distinct String objects.
            if (sameLabel(list.get(i - 1).getTargetClass(), list.get(i).getTargetClass())) {
                continue;
            }
            // Both terms of the MDL threshold are divided by N.
            double threshold = (log2(size - 1) + delta(i, list)) / size;
            if (gain(i, list) > threshold) {
                boundaryPoints.add(Integer.valueOf(i));
            }
        }
        return boundaryPoints;
    }

    /** Null-safe equality of two class labels. */
    private static boolean sameLabel(String first, String second) {
        return first == null ? second == null : first.equals(second);
    }

    /** Base-two logarithm. */
    private static double log2(double value) {
        return Math.log(value) / LN_2;
    }

    /**
     * Information gain of splitting the dataset at index {@code i}: the entropy of the whole list
     * minus the size-weighted entropies of {@code [0, i)} and {@code [i, size)}.
     */
    private static double gain(int i, List<Trainable> list) {
        List<Trainable> lower = list.subList(0, i);
        List<Trainable> upper = list.subList(i, list.size());
        double total = list.size();
        double weightedLower = (lower.size() * entropy(lower)) / total;
        double weightedUpper = (upper.size() * entropy(upper)) / total;
        return entropy(list) - weightedLower - weightedUpper;
    }

    /**
     * The delta term of the MDL criterion for a split at index {@code i}:
     * {@code log2(3^k - 2) - (k * Ent(S) - k1 * Ent(S1) - k2 * Ent(S2))}, where {@code k},
     * {@code k1}, {@code k2} are the numbers of distinct classes in the whole list and in the two
     * parts.
     */
    private static double delta(int i, List<Trainable> list) {
        List<Trainable> lower = list.subList(0, i);
        List<Trainable> upper = list.subList(i, list.size());
        double classes = calculateNumberOfClasses(list);
        double entropyTerm = classes * entropy(list)
                - calculateNumberOfClasses(lower) * entropy(lower)
                - calculateNumberOfClasses(upper) * entropy(upper);
        return log2(Math.pow(3.0d, classes) - 2.0d) - entropyTerm;
    }

    /**
     * Shannon entropy (base two) of the target class distribution of the given instances.
     *
     * @return the entropy; {@code 0} for an empty list
     */
    private static double entropy(List<Trainable> list) {
        CountMap<String> classCounts = CountMap.create();
        for (Trainable instance : list) {
            classCounts.add(instance.getTargetClass());
        }
        double result = 0.0d;
        for (String targetClass : classCounts.uniqueItems()) {
            // Explicit floating-point division: getCount presumably returns an integer count, in
            // which case dividing by the int size would truncate the probability to 0 or 1.
            double probability = (double) classCounts.getCount(targetClass) / list.size();
            result -= probability * log2(probability);
        }
        return result;
    }

    /** Number of distinct target classes among the given instances (as a double for use in formulas). */
    private static double calculateNumberOfClasses(List<Trainable> list) {
        Set<String> classes = new HashSet<String>();
        for (Trainable instance : list) {
            classes.add(instance.getTargetClass());
        }
        return classes.size();
    }
}
