package ws.palladian.classification.numeric;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import ws.palladian.classification.Instance;
import ws.palladian.classification.Model;
import ws.palladian.classification.utils.MinMaxNormalization;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.FeatureVector;
import ws.palladian.processing.features.NumericFeature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/numeric/KnnModel.class */
/**
 * Model for a k-nearest-neighbour classifier.
 *
 * <p>The model keeps its own copy of the training data. Each example is reduced to its target class
 * and the values of its {@link NumericFeature}s; features of any other type are not retained. The
 * model can optionally be normalized with a {@link MinMaxNormalization} fitted on the training data,
 * and that same normalization can then be applied to feature vectors via
 * {@link #normalize(FeatureVector)}.
 */
public final class KnnModel implements Model {
    private static final long serialVersionUID = -6528509220813706056L;

    /** Training data: target class plus numeric feature values keyed by feature name. */
    private List<TrainingExample> trainingExamples;

    /** Whether {@link #normalize()} has been called on this model. */
    private boolean isNormalized = false;

    /** Normalization fitted on the training data by {@link #normalize()}; {@code null} until then. */
    private MinMaxNormalization normalizationInformation;

    /**
     * Creates a model from the given training data.
     *
     * @param iterable the training instances; only their target class and numeric features are kept.
     */
    public KnnModel(Iterable<? extends Trainable> iterable) {
        this.trainingExamples = initTrainingInstances(iterable);
    }

    // Converts Trainables into the compact internal representation, keeping only numeric features.
    private List<TrainingExample> initTrainingInstances(Iterable<? extends Trainable> iterable) {
        List<TrainingExample> result = new ArrayList<TrainingExample>();
        for (Trainable trainable : iterable) {
            TrainingExample trainingExample = new TrainingExample();
            trainingExample.targetClass = trainable.getTargetClass();
            trainingExample.features = new HashMap<String, Double>();
            for (NumericFeature numericFeature : trainable.getFeatureVector().getAll(NumericFeature.class)) {
                trainingExample.features.put(numericFeature.getName(), numericFeature.getValue());
            }
            result.add(trainingExample);
        }
        return result;
    }

    /**
     * Returns the training data held by this model, rebuilt as {@link Instance}s.
     *
     * <p>A new list of new instances is created on every call, so modifying the result does not affect
     * the model. If the model has been normalized, the returned values are the normalized ones.
     *
     * @return the training examples, containing only their target class and numeric features.
     */
    public List<Trainable> getTrainingExamples() {
        return convertTrainingInstances(this.trainingExamples);
    }

    // Rebuilds Trainable instances from the given internal examples (uses the argument, not the field).
    private List<Trainable> convertTrainingInstances(List<TrainingExample> list) {
        List<Trainable> result = new ArrayList<Trainable>(list.size());
        for (TrainingExample trainingExample : list) {
            FeatureVector featureVector = new FeatureVector();
            for (Map.Entry<String, Double> entry : trainingExample.features.entrySet()) {
                featureVector.add(new NumericFeature(entry.getKey(), entry.getValue()));
            }
            result.add(new Instance(trainingExample.targetClass, featureVector));
        }
        return result;
    }

    /**
     * Normalizes the training data in place.
     *
     * <p>A {@link MinMaxNormalization} is fitted on the current training data, applied to it, and
     * kept so that {@link #normalize(FeatureVector)} can later apply the same normalization.
     */
    public void normalize() {
        // Round-trip through Trainables because MinMaxNormalization operates on those, not on the
        // internal representation.
        List<Trainable> instances = convertTrainingInstances(this.trainingExamples);
        this.normalizationInformation = new MinMaxNormalization(instances);
        this.normalizationInformation.normalize(instances);
        this.trainingExamples = initTrainingInstances(instances);
        this.isNormalized = true;
    }

    /**
     * Applies the normalization fitted by {@link #normalize()} to the given feature vector.
     *
     * @param featureVector the vector to normalize; modified in place by the normalization.
     * @throws IllegalStateException if {@link #normalize()} has not been called on this model.
     */
    public void normalize(FeatureVector featureVector) {
        if (!this.isNormalized) {
            throw new IllegalStateException("Tried calling normalize for an unnormalized model. Please normalize this model before you try this again.");
        }
        this.normalizationInformation.normalize(featureVector);
    }

    /**
     * @return {@code true} if {@link #normalize()} has been called on this model.
     */
    public boolean isNormalized() {
        return this.isNormalized;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("KnnModel [");
        sb.append("# trainingInstances=").append(this.trainingExamples.size());
        sb.append(", isNormalized=").append(this.isNormalized);
        // Closing bracket; the decompiled original referenced an unrelated library constant whose
        // value is "]".
        sb.append(']');
        return sb.toString();
    }
}
