package ws.palladian.classification.numeric;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.tuple.Pair;
import org.jdesktop.swingx.JXLabel;

import ws.palladian.classification.CategoryEntries;
import ws.palladian.classification.CategoryEntriesMap;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Learner;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.EntryValueComparator;
import ws.palladian.processing.Classifiable;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.FeatureVector;
import ws.palladian.processing.features.NumericFeature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/numeric/KnnClassifier.class */
public final class KnnClassifier implements Learner<KnnModel>, Classifier<KnnModel> {
    private final int k;

    public KnnClassifier(int i) {
        this.k = i;
    }

    public KnnClassifier() {
        this(3);
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // ws.palladian.classification.Learner
    public KnnModel train(Iterable<? extends Trainable> iterable) {
        return new KnnModel(iterable);
    }

    @Override // ws.palladian.classification.Classifier
    public CategoryEntries classify(Classifiable classifiable, KnnModel knnModel) {
        if (knnModel.isNormalized()) {
            knnModel.normalize(classifiable.getFeatureVector());
        }
        Set<String> possibleCategories = getPossibleCategories(knnModel.getTrainingExamples());
        HashMap newHashMap = CollectionHelper.newHashMap();
        Iterator<String> it = possibleCategories.iterator();
        while (it.hasNext()) {
            newHashMap.put(it.next(), Double.valueOf(JXLabel.NORMAL));
        }
        ArrayList<Pair> newArrayList = CollectionHelper.newArrayList();
        for (Trainable trainable : knnModel.getTrainingExamples()) {
            newArrayList.add(Pair.of(trainable, Double.valueOf(getDistanceBetween(classifiable.getFeatureVector(), trainable.getFeatureVector()))));
        }
        Collections.sort(newArrayList, EntryValueComparator.ascending());
        double d = -1.0d;
        int i = 0;
        for (Pair pair : newArrayList) {
            if (i >= this.k && ((Double) pair.getValue()).doubleValue() != d) {
                break;
            }
            double doubleValue = ((Double) pair.getValue()).doubleValue();
            double d2 = 1.0d / (doubleValue + 1.0E-9d);
            String targetClass = ((Trainable) pair.getKey()).getTargetClass();
            newHashMap.put(targetClass, Double.valueOf(((Double) newHashMap.get(targetClass)).doubleValue() + d2));
            d = doubleValue;
            i++;
        }
        CategoryEntriesMap categoryEntriesMap = new CategoryEntriesMap();
        for (Map.Entry entry : newHashMap.entrySet()) {
            categoryEntriesMap.set((String) entry.getKey(), ((Double) entry.getValue()).doubleValue());
        }
        return categoryEntriesMap;
    }

    private Set<String> getPossibleCategories(List<Trainable> list) {
        HashSet newHashSet = CollectionHelper.newHashSet();
        Iterator<Trainable> it = list.iterator();
        while (it.hasNext()) {
            newHashSet.add(it.next().getTargetClass());
        }
        return newHashSet;
    }

    private double getDistanceBetween(FeatureVector featureVector, FeatureVector featureVector2) {
        double d = 0.0d;
        for (NumericFeature numericFeature : featureVector.getAll(NumericFeature.class)) {
            d += Math.pow(numericFeature.getValue().doubleValue() - ((NumericFeature) featureVector2.get(NumericFeature.class, numericFeature.getName())).getValue().doubleValue(), 2.0d);
        }
        return Math.sqrt(d);
    }

    @Override // ws.palladian.classification.Learner
    public /* bridge */ /* synthetic */ KnnModel train(Iterable iterable) {
        return train((Iterable<? extends Trainable>) iterable);
    }
}
