package ws.palladian.classification.utils;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.lang3.Validate;
import org.apache.log4j.spi.LocationInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.CategoryEntries;
import ws.palladian.classification.CategoryEntriesMap;
import ws.palladian.classification.Classifier;
import ws.palladian.classification.Instance;
import ws.palladian.classification.Model;
import ws.palladian.helper.StopWatch;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.Filter;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.io.LineAction;
import ws.palladian.processing.Classifiable;
import ws.palladian.processing.Trainable;
import ws.palladian.processing.features.Feature;
import ws.palladian.processing.features.FeatureVector;
import ws.palladian.processing.features.NominalFeature;
import ws.palladian.processing.features.NumericFeature;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/classification/utils/ClassificationUtils.class */
/**
 * Static helper methods for working with classification data: reading and writing instances as CSV, drawing random
 * subsets, filtering features by name and classifying with several models.
 */
public final class ClassificationUtils {

    /** Logger for this class. */
    private static final Logger LOGGER = LoggerFactory.getLogger(ClassificationUtils.class);

    /** Column separator used by {@link #writeCsv} and by the {@code readCsv} overloads which take no separator. */
    public static final String DEFAULT_SEPARATOR = ";";

    /** Cell content which marks a missing value; such cells are skipped by {@code readCsv}. */
    private static final String MISSING_VALUE = "?";

    private ClassificationUtils() {
        // utility class, not meant to be instantiated
    }

    /**
     * Reads instances from a CSV file which has a header line and uses {@link #DEFAULT_SEPARATOR}.
     *
     * @param filePath Path to the CSV file.
     * @return The instances which were read.
     * @see #readCsv(String, boolean, String)
     */
    public static List<Trainable> readCsv(String filePath) {
        return readCsv(filePath, true, DEFAULT_SEPARATOR);
    }

    /**
     * Reads instances from a CSV file which uses {@link #DEFAULT_SEPARATOR}.
     *
     * @param filePath Path to the CSV file.
     * @param readHeader <code>true</code> to treat the first line as header with the feature names.
     * @return The instances which were read.
     * @see #readCsv(String, boolean, String)
     */
    public static List<Trainable> readCsv(String filePath, boolean readHeader) {
        return readCsv(filePath, readHeader, DEFAULT_SEPARATOR);
    }

    /**
     * Reads instances from a CSV file. The last column of each line is taken as target class, all columns before it
     * become features. A cell which can be parsed by {@link Double#valueOf(String)} becomes a {@link NumericFeature},
     * every other cell a {@link NominalFeature}; cells which consist of a single <code>?</code> are skipped. When no
     * header is read, the features are named by their zero-based column index.
     * <p>
     * The separator is handed to {@link String#split(String)} unchanged, i.e. it is interpreted as a regular
     * expression, and trailing empty cells of a line are dropped by the split.
     *
     * @param filePath Path to the CSV file, not <code>null</code>.
     * @param readHeader <code>true</code> to treat the first line as header with the feature names; that line then
     *            yields no instance.
     * @param separator The (regular expression) separator between the columns, not <code>null</code>.
     * @return The instances which were read, in the order of the lines.
     * @throws IllegalArgumentException In case the file cannot be read.
     * @throws IllegalStateException In case a line has less than two columns, or a different number of columns than
     *             the first line.
     */
    public static List<Trainable> readCsv(String filePath, final boolean readHeader, final String separator) {
        Validate.notNull(filePath, "filePath must not be null");
        Validate.notNull(separator, "separator must not be null");
        if (!new File(filePath).canRead()) {
            throw new IllegalArgumentException("Cannot find or read file \"" + filePath + "\"");
        }
        StopWatch stopWatch = new StopWatch();
        final List<Trainable> instances = CollectionHelper.newArrayList();
        FileHelper.performActionOnEveryLine(filePath, new LineAction() {
            /** Feature names from the header line; stays null when no header is read. */
            private String[] headNames;
            /** Number of columns in the first line, every following line must match it. */
            private int expectedColumns;

            @Override
            public void performAction(String line, int lineNumber) {
                String[] parts = line.split(separator);
                if (parts.length < 2) {
                    throw new IllegalStateException("Separator '" + separator
                            + "' was not found, lines cannot be split ('" + line + "').");
                }
                // the first line is passed with number zero (see the check below)
                if (lineNumber == 0) {
                    expectedColumns = parts.length;
                    if (readHeader) {
                        headNames = parts;
                        return;
                    }
                } else if (expectedColumns != parts.length) {
                    throw new IllegalStateException("Unexpected number of entries in line " + lineNumber + " ("
                            + parts.length + ", but should be " + expectedColumns + ")");
                }
                instances.add(parseInstance(parts, headNames));
                if (lineNumber % 10000 == 0) {
                    LOGGER.debug("Read {} lines", lineNumber);
                }
            }
        });
        LOGGER.info("Read {} instances from {} in {}", instances.size(), filePath, stopWatch);
        return instances;
    }

    /**
     * Converts the cells of one CSV line to an instance; the last cell is the target class.
     *
     * @param parts The cells of the line, at least two.
     * @param headNames The feature names, or <code>null</code> to name the features by column index.
     * @return The instance for the line.
     */
    private static Trainable parseInstance(String[] parts, String[] headNames) {
        FeatureVector featureVector = new FeatureVector();
        int targetIndex = parts.length - 1;
        for (int column = 0; column < targetIndex; column++) {
            String name = headNames == null ? String.valueOf(column) : headNames[column];
            String value = parts[column];
            if (value.equals(MISSING_VALUE)) {
                continue; // missing value, the instance gets no feature for this column
            }
            try {
                featureVector.add(new NumericFeature(name, Double.valueOf(value)));
            } catch (NumberFormatException e) {
                // not a number, so keep the cell as nominal value
                featureVector.add(new NominalFeature(name, value));
            }
        }
        return new Instance(parts[targetIndex], featureVector);
    }

    /**
     * Writes the given instances to a CSV file, encoded as UTF-8 and separated by {@link #DEFAULT_SEPARATOR}. The
     * header line is created from the feature names of the first instance; when that instance is a {@link Trainable},
     * a <code>targetClass</code> column is appended. Every feature value is followed by the separator, so lines of
     * instances which are not {@link Trainable} end with a separator. Values are written as given by their
     * <code>toString</code>, no quoting or escaping takes place.
     *
     * @param trainData The instances to write, not <code>null</code>.
     * @param outputFile The file to write to, not <code>null</code>.
     * @throws IllegalStateException In case writing fails with an {@link IOException}.
     */
    public static void writeCsv(Iterable<? extends Classifiable> trainData, File outputFile) {
        Validate.notNull(trainData, "trainData must not be null");
        Validate.notNull(outputFile, "outputFile must not be null");
        BufferedWriter writer = null;
        try {
            writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile), "UTF-8"));
            boolean headerPending = true;
            int instanceCount = 0;
            int featureCount = 0;
            for (Classifiable instance : trainData) {
                boolean trainable = instance instanceof Trainable;
                if (headerPending) {
                    Iterator<Feature<?>> names = instance.getFeatureVector().iterator();
                    while (names.hasNext()) {
                        writer.write(names.next().getName());
                        writer.write(DEFAULT_SEPARATOR);
                        featureCount++;
                    }
                    if (trainable) {
                        writer.write("targetClass");
                    }
                    writer.write(FileHelper.NEWLINE_CHARACTER);
                    headerPending = false;
                }
                Iterator<Feature<?>> values = instance.getFeatureVector().iterator();
                while (values.hasNext()) {
                    writer.write(values.next().getValue().toString());
                    writer.write(DEFAULT_SEPARATOR);
                }
                if (trainable) {
                    writer.write(((Trainable) instance).getTargetClass());
                }
                writer.write(FileHelper.NEWLINE_CHARACTER);
                instanceCount++;
            }
            LOGGER.info("Wrote {} train instances with {} features.", instanceCount, featureCount);
        } catch (IOException e) {
            throw new IllegalStateException("Encountered " + e + " while writing to '" + outputFile + "'", e);
        } finally {
            // release the file on success as well as on any failure
            FileHelper.close(writer);
        }
    }

    /**
     * Draws a random subset of the given list without replacement. The given list is not modified.
     *
     * @param <T> Type of the list's elements.
     * @param list The list to draw from, not <code>null</code>.
     * @param percentage The size of the subset as percentage of the list's size, in range [0,100]. The resulting
     *            number of elements is rounded down.
     * @return A new list with the randomly drawn elements.
     * @throws IllegalArgumentException In case the percentage is not in range [0,100].
     */
    public static <T> List<T> drawRandomSubset(List<T> list, int percentage) {
        Validate.notNull(list, "list must not be null");
        if (percentage < 0 || percentage > 100) {
            throw new IllegalArgumentException("percentage must be in range [0,100], but was " + percentage);
        }
        List<T> shuffled = new ArrayList<T>(list);
        // calculate with long, the product exceeds the int range for big lists
        int subsetSize = (int) ((long) percentage * list.size() / 100);
        // no time-based seed, it gives identical subsets for calls within the same millisecond
        Random random = new Random();
        // partial Fisher-Yates shuffle: only the first subsetSize positions need to be randomized
        for (int i = 0; i < subsetSize; i++) {
            int j = i + random.nextInt(shuffled.size() - i);
            T temp = shuffled.get(i);
            shuffled.set(i, shuffled.get(j));
            shuffled.set(j, temp);
        }
        return new ArrayList<T>(shuffled.subList(0, subsetSize));
    }

    /**
     * Creates a new feature vector with those features of the given classifiable, whose names are accepted by the
     * filter. The feature vector of the classifiable itself is not modified.
     *
     * @param classifiable The classifiable whose features to filter, not <code>null</code>.
     * @param nameFilter The filter which decides by feature name, not <code>null</code>.
     * @return A new feature vector with the accepted features.
     */
    public static FeatureVector filterFeatures(Classifiable classifiable, Filter<String> nameFilter) {
        Validate.notNull(classifiable, "classifiable must not be null");
        Validate.notNull(nameFilter, "nameFilter must not be null");
        FeatureVector filtered = new FeatureVector();
        Iterator<Feature<?>> features = classifiable.getFeatureVector().iterator();
        while (features.hasNext()) {
            Feature<?> feature = features.next();
            if (nameFilter.accept(feature.getName())) {
                filtered.add(feature);
            }
        }
        LOGGER.trace("Reduced from {} to {}", classifiable.getFeatureVector().size(), filtered.size());
        return filtered;
    }

    /**
     * Applies {@link #filterFeatures(Classifiable, Filter)} to every instance of a dataset. The target classes are
     * kept.
     *
     * @param instances The instances to filter, not <code>null</code>.
     * @param nameFilter The filter which decides by feature name, not <code>null</code>.
     * @return A list with new instances, which have the target classes of the given ones and the accepted features.
     */
    public static List<Trainable> filterFeatures(Iterable<? extends Trainable> instances,
            Filter<String> nameFilter) {
        Validate.notNull(instances, "instances must not be null");
        Validate.notNull(nameFilter, "nameFilter must not be null");
        List<Trainable> result = CollectionHelper.newArrayList();
        for (Trainable instance : instances) {
            result.add(new Instance(instance.getTargetClass(), filterFeatures(instance, nameFilter)));
        }
        return result;
    }

    /**
     * Gets the feature names of a dataset. Only the first instance is inspected; features which occur solely in
     * later instances are not contained in the result.
     *
     * @param dataset The dataset, not <code>null</code>.
     * @return The names of the first instance's features as sorted set, or an empty set in case the dataset is empty.
     */
    public static Set<String> getFeatureNames(Collection<? extends Trainable> dataset) {
        Validate.notNull(dataset, "dataset must not be null");
        Set<String> featureNames = CollectionHelper.newTreeSet();
        Iterator<? extends Trainable> instances = dataset.iterator();
        if (!instances.hasNext()) {
            return featureNames; // nothing to inspect
        }
        Iterator<Feature<?>> features = instances.next().getFeatureVector().iterator();
        while (features.hasNext()) {
            featureNames.add(features.next().getName());
        }
        return featureNames;
    }

    /**
     * Classifies one item with each of the given models and merges the results via
     * {@link CategoryEntriesMap#merge}.
     *
     * @param <M> Type of the models.
     * @param <T> Type of the item to classify.
     * @param classifier The classifier to use.
     * @param classifiable The item to classify.
     * @param models The models to classify with.
     * @return The merged classification results; empty entries as created by {@link CategoryEntriesMap}'s default
     *         constructor in case no models were given.
     */
    public static <M extends Model, T extends Classifiable> CategoryEntries classifyWithMultipleModels(
            Classifier<M> classifier, T classifiable, M... models) {
        CategoryEntriesMap merged = new CategoryEntriesMap();
        for (M model : models) {
            merged = CategoryEntriesMap.merge(classifier.classify(classifiable, model), merged);
        }
        return merged;
    }
}
