package ws.palladian.extraction.location.disambiguation;

import java.io.File;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickdt.randomForest.RandomForestBuilder;
import ws.palladian.classification.Instance;
import ws.palladian.classification.dt.QuickDtLearner;
import ws.palladian.classification.utils.ClassificationUtils;
import ws.palladian.extraction.location.AnnotationFilter;
import ws.palladian.extraction.location.ContextClassifier;
import ws.palladian.extraction.location.EntityPreprocessingTagger;
import ws.palladian.extraction.location.GeoUtils;
import ws.palladian.extraction.location.Location;
import ws.palladian.extraction.location.LocationAnnotation;
import ws.palladian.extraction.location.LocationExtractorUtils;
import ws.palladian.extraction.location.LocationSource;
import ws.palladian.extraction.location.PalladianLocationExtractor;
import ws.palladian.extraction.location.disambiguation.LocationFeatureExtractor;
import ws.palladian.extraction.location.persistence.LocationDatabase;
import ws.palladian.helper.collection.CollectionHelper;
import ws.palladian.helper.collection.CompositeIterator;
import ws.palladian.helper.io.FileHelper;
import ws.palladian.helper.math.MathHelper;
import ws.palladian.persistence.DatabaseManagerFactory;
import ws.palladian.processing.Trainable;

/* loaded from: input_file:lib/palladian.jar:ws/palladian/extraction/location/disambiguation/FeatureBasedDisambiguationLearner.class */
/**
 * Creates training data for feature-based location disambiguation and trains a QuickDt random forest model from it.
 * <p>
 * For every document of a dataset the text is tagged, filtered and context-classified, candidate locations are
 * fetched from the configured {@link LocationSource}, and one feature instance per candidate is produced by the
 * {@link LocationFeatureExtractor}. Each instance is labeled positive when it matches one of the document's
 * annotations (see {@link #matchesAnyAnnotation}), negative otherwise. The labeled instances are written as CSV
 * and the trained model is serialized next to it.
 */
public class FeatureBasedDisambiguationLearner {
    private static final Logger LOGGER = LoggerFactory.getLogger(FeatureBasedDisambiguationLearner.class);

    /**
     * Upper bound (exclusive) for the value returned by {@code GeoUtils.getDistance} between a candidate and an
     * annotated location for the two to be considered the same place. NOTE(review): unit is whatever
     * {@code GeoUtils.getDistance} returns (presumably kilometers) -- confirm.
     */
    private static final double MAX_MATCH_DISTANCE = 50.0d;

    /** Random forest learner, configured with 10 trees. */
    private final QuickDtLearner learner = new QuickDtLearner(new RandomForestBuilder().numTrees(10));
    private final LocationFeatureExtractor featureExtraction = new LocationFeatureExtractor();
    // NOTE(review): meaning of the constructor argument 3 is not visible from this file -- confirm.
    private final EntityPreprocessingTagger tagger = new EntityPreprocessingTagger(3);
    private final AnnotationFilter filter = new AnnotationFilter();
    private final ContextClassifier contextClassifier = new ContextClassifier(ContextClassifier.ClassificationMode.PROPAGATION);
    private final LocationSource locationSource;

    /**
     * @param locationSource The source from which candidate locations are fetched, not <code>null</code>.
     */
    public FeatureBasedDisambiguationLearner(LocationSource locationSource) {
        Validate.notNull(locationSource, "locationSource must not be null");
        this.locationSource = locationSource;
    }

    /**
     * Learn from a single dataset.
     *
     * @param file The dataset location, handed to {@code LocationExtractorUtils.iterateDataset}.
     */
    public void learn(File file) {
        learn(LocationExtractorUtils.iterateDataset(file));
    }

    /**
     * Learn one combined model from several datasets; the documents of all datasets are iterated one dataset
     * after the other.
     *
     * @param fileArr The dataset locations, not <code>null</code>.
     */
    public void learn(File... fileArr) {
        Validate.notNull(fileArr, "datasetDirectories must not be null");
        List<Iterator<LocationExtractorUtils.LocationDocument>> datasetIterators = CollectionHelper.newArrayList();
        for (File datasetDirectory : fileArr) {
            datasetIterators.add(LocationExtractorUtils.iterateDataset(datasetDirectory));
        }
        learn(new CompositeIterator<LocationExtractorUtils.LocationDocument>(datasetIterators));
    }

    /**
     * Learn from the given documents. The training data is written to
     * <code>data/temp/location_disambiguation_&lt;timestamp&gt;.csv</code> and the trained model is serialized to
     * the file with the same base name and the extension <code>.model</code>.
     *
     * @param it Iterator over the training documents.
     */
    public void learn(Iterator<LocationExtractorUtils.LocationDocument> it) {
        Set<Trainable> trainingData = createTrainingData(it);
        // the timestamp keeps the outputs of consecutive runs apart
        String baseFileName = String.format("data/temp/location_disambiguation_%s", System.currentTimeMillis());
        ClassificationUtils.writeCsv(trainingData, new File(baseFileName + ".csv"));
        FileHelper.serialize(this.learner.train((Iterable<? extends Trainable>) trainingData), baseFileName + ".model");
    }

    /**
     * Run the extraction pipeline on every document and collect the labeled instances of all documents.
     *
     * @param documents Iterator over the training documents.
     * @return The labeled training instances of all documents.
     */
    private Set<Trainable> createTrainingData(Iterator<LocationExtractorUtils.LocationDocument> documents) {
        Set<Trainable> trainingData = CollectionHelper.newHashSet();
        while (documents.hasNext()) {
            LocationExtractorUtils.LocationDocument document = documents.next();
            String text = document.getText();
            // pipeline: tag -> filter -> classify context -> fetch candidate locations -> build feature instances
            Set<LocationFeatureExtractor.LocationInstance> instances = this.featureExtraction.makeInstances(
                    text,
                    PalladianLocationExtractor.fetchLocations(
                            this.locationSource,
                            this.contextClassifier.classify(this.filter.filter(this.tagger.getAnnotations(text)), text)));
            trainingData.addAll(createTrainData(instances, document.getAnnotations()));
        }
        return trainingData;
    }

    /**
     * Label each instance as positive or negative, depending on whether it matches one of the given annotations,
     * and log the share of positive instances.
     *
     * @param instances The candidate instances of one document.
     * @param annotations The annotations of the same document.
     * @return One labeled training instance per candidate.
     */
    private Set<Trainable> createTrainData(Set<LocationFeatureExtractor.LocationInstance> instances, List<LocationAnnotation> annotations) {
        Set<Trainable> labeledInstances = CollectionHelper.newHashSet();
        int positiveCount = 0;
        for (LocationFeatureExtractor.LocationInstance instance : instances) {
            boolean positive = matchesAnyAnnotation(instance, annotations);
            if (positive) {
                positiveCount++;
            }
            labeledInstances.add(new Instance(positive, instance));
        }
        // floating point division (integer division would truncate to 0); an empty set yields 0 instead of
        // a division by zero
        double positivePercentage = instances.isEmpty() ? 0.0d : 100.0d * positiveCount / instances.size();
        LOGGER.info("{} positive instances in {} ({}%)", positiveCount, instances.size(), MathHelper.round(positivePercentage, 2));
        return labeledInstances;
    }

    /**
     * Check whether the candidate matches at least one annotation. A candidate without coordinates never matches.
     * Otherwise it matches an annotation when its distance to the annotated location is below
     * {@link #MAX_MATCH_DISTANCE}, both share a common name and both have an equal type.
     *
     * @param instance The candidate instance.
     * @param annotations The annotations to compare against.
     * @return <code>true</code> if a matching annotation exists, <code>false</code> otherwise.
     */
    private static boolean matchesAnyAnnotation(LocationFeatureExtractor.LocationInstance instance, List<LocationAnnotation> annotations) {
        if (instance.getLatitude() == null || instance.getLongitude() == null) {
            return false;
        }
        for (LocationAnnotation annotation : annotations) {
            Location annotatedLocation = annotation.getLocation();
            boolean closeEnough = GeoUtils.getDistance(instance, annotatedLocation) < MAX_MATCH_DISTANCE;
            boolean sameName = instance.commonName(annotatedLocation);
            boolean sameType = instance.getType().equals(annotatedLocation.getType());
            if (closeEnough && sameName && sameType) {
                return true;
            }
        }
        return false;
    }

    /**
     * Train one model per dataset and afterwards one model on all three datasets combined. The dataset paths are
     * hard-coded; the command line arguments are ignored.
     *
     * @param strArr Ignored.
     */
    public static void main(String[] strArr) {
        LocationSource locationSource = (LocationSource) DatabaseManagerFactory.create(LocationDatabase.class, "locations");
        FeatureBasedDisambiguationLearner disambiguationLearner = new FeatureBasedDisambiguationLearner(locationSource);
        File datasetTud = new File("/Users/pk/Dropbox/Uni/Datasets/TUD-Loc-2013/TUD-Loc-2013_V2/1-training");
        File datasetLgl = new File("/Users/pk/Dropbox/Uni/Dissertation_LocationLab/LGL-converted/1-train");
        File datasetClust = new File("/Users/pk/Dropbox/Uni/Dissertation_LocationLab/CLUST-converted/1-train");
        disambiguationLearner.learn(datasetTud);
        disambiguationLearner.learn(datasetLgl);
        disambiguationLearner.learn(datasetClust);
        disambiguationLearner.learn(datasetTud, datasetLgl, datasetClust);
    }
}
