package ws.palladian.classification.featureselection;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.Feature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/featureselection/InformationGainFeatureRanker.class */
/**
 * Ranks features by a Laplace-smoothed information-gain score with respect to the instances'
 * target classes.
 *
 * <p>For every distinct feature {@code f} the score is
 * {@code -SUM_c P(c) log P(c) + P(f) * SUM_c s(f,c) log s(f,c) + P(!f) * SUM_c s'(f,c) log s'(f,c)},
 * where the sums over {@code c} in the last two terms only run over classes in which {@code f}
 * occurs. The smoothed quantities are computed by {@link #laplaceSmooth(int, int, int)}. Scores of
 * features that share the same name are averaged in {@link #rankFeatures(Collection)}.
 */
public final class InformationGainFeatureRanker extends AbstractFeatureRanker {
    private static final Logger LOGGER = LoggerFactory.getLogger(InformationGainFeatureRanker.class);

    /**
     * Computes the information-gain score of every distinct feature occurring in the dataset.
     *
     * @param collection the training instances; must not be {@code null}
     * @return a map from feature to score; empty if {@code collection} is empty or has no features
     */
    private Map<Feature<?>, Double> calculateInformationGain(Collection<? extends Trainable> collection) {
        Validate.notNull(collection);
        Map<Feature<?>, Double> informationGain = CollectionHelper.newHashMap();
        if (collection.isEmpty()) {
            LOGGER.warn("Dataset for feature selection is empty. No feature selection is carried out.");
            return informationGain;
        }
        Collection<Pair<Set<Feature<?>>, String>> instances = prepare(collection);
        Map<String, Double> classPriors = calculateTargetClassPriors(collection);

        // number of instances each feature occurs in
        Map<Feature<?>, Integer> featureCounts = CollectionHelper.newHashMap();
        // per feature: number of instances of each class the feature occurs in
        Map<Feature<?>, Map<String, Integer>> featureClassCounts = CollectionHelper.newHashMap();
        // per class: total number of feature occurrences in instances of that class
        Map<String, Integer> classFeatureCounts = CollectionHelper.newHashMap();
        int totalFeatureOccurrences = 0;

        for (Pair<Set<Feature<?>>, String> instance : instances) {
            String targetClass = instance.getRight();
            for (Feature<?> feature : instance.getLeft()) {
                increment(featureCounts, feature);
                Map<String, Integer> classCounts = featureClassCounts.get(feature);
                if (classCounts == null) {
                    classCounts = CollectionHelper.newHashMap();
                    featureClassCounts.put(feature, classCounts);
                }
                increment(classCounts, targetClass);
                increment(classFeatureCounts, targetClass);
                totalFeatureOccurrences++;
            }
        }

        // SUM_c P(c) log P(c); this is the negated class entropy
        double priorTerm = 0.0d;
        for (double prior : classPriors.values()) {
            priorTerm += prior * Math.log(prior);
        }

        int numFeatures = featureCounts.size();
        int numInstances = collection.size();
        for (Map.Entry<Feature<?>, Integer> featureEntry : featureCounts.entrySet()) {
            Feature<?> feature = featureEntry.getKey();
            int featureCount = featureEntry.getValue();
            double presentTerm = 0.0d;
            double absentTerm = 0.0d;
            for (Map.Entry<String, Integer> classEntry : featureClassCounts.get(feature).entrySet()) {
                int classCount = classFeatureCounts.get(classEntry.getKey());
                int featureClassCount = classEntry.getValue();
                double present = laplaceSmooth(numFeatures, classCount, featureClassCount);
                presentTerm += present * Math.log(present);
                double absent = laplaceSmooth(numFeatures, totalFeatureOccurrences - classCount,
                        numInstances - featureClassCount);
                absentTerm += absent * Math.log(absent);
            }
            // both probabilities must use floating-point division; integer division would truncate
            // P(!f) to 0 for every feature that occurs at least once
            double pFeature = (double) featureCount / numInstances;
            double pNotFeature = (double) (numInstances - featureCount) / numInstances;
            informationGain.put(feature, -priorTerm + pFeature * presentTerm + pNotFeature * absentTerm);
        }
        return informationGain;
    }

    /**
     * Converts each instance into a (feature set, target class) pair.
     *
     * <p>A list is used rather than a set, so that identical instances are each counted. This keeps
     * the feature counts consistent with the class priors and the {@code collection.size()}
     * denominators, which count every instance.
     *
     * @param collection the training instances
     * @return one pair per instance, in iteration order of {@code collection}
     */
    private Collection<Pair<Set<Feature<?>>, String>> prepare(Collection<? extends Trainable> collection) {
        List<Pair<Set<Feature<?>>, String>> instances = CollectionHelper.newArrayList();
        for (Trainable trainable : collection) {
            Set<Feature<?>> features = convertToSet(trainable.getFeatureVector(), collection);
            instances.add(new ImmutablePair<Set<Feature<?>>, String>(features, trainable.getTargetClass()));
        }
        return instances;
    }

    /**
     * Computes the relative frequency of each target class among the instances.
     *
     * @param collection the training instances; assumed non-empty (the only caller checks this)
     * @return a map from target class to {@code count(class) / collection.size()}
     */
    private Map<String, Double> calculateTargetClassPriors(Collection<? extends Trainable> collection) {
        Map<String, Integer> classCounts = CollectionHelper.newHashMap();
        for (Trainable trainable : collection) {
            increment(classCounts, trainable.getTargetClass());
        }
        Map<String, Double> priors = CollectionHelper.newHashMap();
        for (Map.Entry<String, Integer> entry : classCounts.entrySet()) {
            priors.put(entry.getKey(), entry.getValue().doubleValue() / collection.size());
        }
        return priors;
    }

    /**
     * Adds one to the count stored under {@code key}, treating a missing key as zero.
     */
    private static <K> void increment(Map<K, Integer> counts, K key) {
        Integer current = counts.get(key);
        counts.put(key, current == null ? 1 : current + 1);
    }

    /**
     * Add-one (Laplace) smoothed ratio {@code (1 + count) / (vocabularySize + total)}.
     *
     * @param vocabularySize number of distinct features, added to the denominator
     * @param total the unsmoothed denominator
     * @param count the unsmoothed numerator
     * @return the smoothed ratio
     */
    private static double laplaceSmooth(int vocabularySize, int total, int count) {
        return (1.0d + count) / (vocabularySize + total);
    }

    /**
     * Ranks features by information gain. Scores of all features that share a name (for example,
     * different values of the same feature) are averaged into one score for that name.
     *
     * @param collection the training instances
     * @return the ranking of feature names
     */
    @Override // ws.palladian.classification.featureselection.FeatureRanker
    public FeatureRanking rankFeatures(Collection<? extends Trainable> collection) {
        FeatureRanking featureRanking = new FeatureRanking();
        Map<Feature<?>, Double> informationGain = calculateInformationGain(collection);

        Map<String, List<Double>> scoresByName = CollectionHelper.newHashMap();
        for (Map.Entry<Feature<?>, Double> entry : informationGain.entrySet()) {
            String name = entry.getKey().getName();
            List<Double> scores = scoresByName.get(name);
            if (scores == null) {
                scores = CollectionHelper.newArrayList();
                scoresByName.put(name, scores);
            }
            scores.add(entry.getValue());
        }

        for (Map.Entry<String, List<Double>> entry : scoresByName.entrySet()) {
            List<Double> scores = entry.getValue();
            double sum = 0.0d;
            for (double score : scores) {
                sum += score;
            }
            featureRanking.add(entry.getKey(), sum / scores.size());
        }
        return featureRanking;
    }
}
