/*
 * Decompiled with CFR 0.152.
 */
package sc.fiji.analyzeSkeleton.ita;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import net.imglib2.util.ValuePair;
import org.joml.Vector3d;
import org.joml.Vector3dc;
import sc.fiji.analyzeSkeleton.Edge;
import sc.fiji.analyzeSkeleton.Graph;
import sc.fiji.analyzeSkeleton.Point;
import sc.fiji.analyzeSkeleton.Vertex;
import sc.fiji.analyzeSkeleton.ita.PointUtils;

/**
 * Static helpers that simplify a {@link Graph} by removing loops, parallel
 * edges, short dead ends and short inner edges.
 * <p>
 * The class is a pure utility holder and cannot be instantiated.
 * </p>
 */
public final class GraphPruning {
    private GraphPruning() {
    }

    /**
     * Prunes short edges assuming an isotropic voxel size of 1.0 in every dimension.
     *
     * @param graph       the graph to prune; pruning works on {@code graph.clone()}.
     * @param minDistance edges whose length is strictly below this value count as short.
     * @param iterate     if {@code true}, pruning repeats until a pass no longer changes the vertex count.
     * @param clustered   if {@code true}, connected groups of short edges are collapsed together,
     *                    otherwise short inner edges are collapsed one at a time.
     * @return the pruned graph paired with the statistics array described in
     *         {@link #pruneShortEdges(Graph, double, boolean, boolean, double[])}.
     */
    public static ValuePair<Graph, double[]> pruneShortEdges(Graph graph, double minDistance, boolean iterate, boolean clustered) {
        return GraphPruning.pruneShortEdges(graph, minDistance, iterate, clustered, new double[]{1.0, 1.0, 1.0});
    }

    /**
     * Removes loops and parallel edges from a clone of the graph, then repeatedly prunes
     * short dead ends and collapses short edges.
     * <p>
     * Edge lengths are recomputed as the scaled straight-line distance between the
     * centroids of the end vertices before the pruning loop starts.
     * </p>
     *
     * @param graph       the graph to prune; pruning works on {@code graph.clone()}
     *                    (how deep that copy is depends on {@code Graph.clone()}).
     * @param minDistance edges whose length is strictly below this value count as short.
     * @param iterate     if {@code true}, pruning repeats until a pass no longer changes the vertex count.
     * @param clustered   if {@code true}, connected groups of short edges are collapsed together,
     *                    otherwise short inner edges are collapsed one at a time.
     * @param voxelSize   per-axis scale factors; indices 0, 1 and 2 are read as x, y and z.
     * @return the pruned graph paired with a five element array:
     *         {@code {loop %, dead end %, parallel %, other pruned %, total edge count}},
     *         where percentages are relative to the edge count of the input graph.
     */
    public static ValuePair<Graph, double[]> pruneShortEdges(Graph graph, double minDistance, boolean iterate, boolean clustered, double[] voxelSize) {
        Graph pruned = graph.clone();
        GraphPruning.removeLoops(pruned);
        final int sansLoops = pruned.getEdges().size();
        final int loops = graph.getEdges().size() - sansLoops;
        GraphPruning.removeParallelEdges(pruned);
        final int parallel = sansLoops - pruned.getEdges().size();
        int otherPruned = 0;
        int deadEnds = 0;
        pruned.getEdges().forEach(e -> GraphPruning.euclideanDistance(e, voxelSize));
        boolean prune = true;
        while (prune) {
            final int startSize = pruned.getVertices().size();
            final int startEdges = pruned.getEdges().size();
            GraphPruning.pruneDeadEnds(pruned, minDistance);
            // Snapshot the edge count here so that dead ends are counted separately
            // from the edges removed by the collapsing step that follows.
            final int edgesAfterDeadEnds = pruned.getEdges().size();
            deadEnds += startEdges - edgesAfterDeadEnds;
            pruned = clustered
                ? GraphPruning.clusteredPruning(pruned, minDistance, voxelSize)
                : GraphPruning.edgewisePruning(pruned, minDistance, voxelSize);
            GraphPruning.removeParallelEdges(pruned);
            // Everything removed in this pass after the dead-end step, including
            // parallel edges created by collapsing, is attributed to "other".
            otherPruned += edgesAfterDeadEnds - pruned.getEdges().size();
            final int cleanedSize = pruned.getVertices().size();
            prune = iterate && startSize != cleanedSize;
        }
        final int total = graph.getEdges().size();
        final double[] stats = GraphPruning.calculateStats(loops, deadEnds, parallel, otherPruned, total);
        return new ValuePair<>(pruned, stats);
    }

    /**
     * Removes every edge whose two end vertices are the same non-null object,
     * both from the graph's edge list and from the branch list of its vertex.
     *
     * @param graph the graph to modify in place.
     */
    public static void removeLoops(Graph graph) {
        List<Edge> loops = graph.getEdges().stream().filter(GraphPruning::isLoop).collect(Collectors.toList());
        loops.forEach(GraphPruning::removeBranchFromEndpoints);
        graph.getEdges().removeAll(loops);
    }

    /**
     * Removes edges that connect a vertex pair already connected by an earlier edge
     * in the graph's edge list. The first edge between each pair is kept.
     * <p>
     * Vertices are identified by their index in {@code graph.getVertices()}; see
     * {@link #mapVertexIds(List)} for the constraints this puts on that list.
     * </p>
     *
     * @param graph the graph to modify in place.
     */
    public static void removeParallelEdges(Graph graph) {
        Map<Vertex, Integer> idMap = GraphPruning.mapVertexIds(graph.getVertices());
        Set<Long> connections = new HashSet<>();
        List<Edge> parallelEdges = new ArrayList<>();
        graph.getEdges().forEach(edge -> {
            long hash = GraphPruning.connectionHash(edge, idMap);
            // add() returns false when this vertex pair has been seen before.
            if (!connections.add(hash)) {
                parallelEdges.add(edge);
            }
        });
        parallelEdges.forEach(GraphPruning::removeBranchFromEndpoints);
        graph.getEdges().removeAll(parallelEdges);
    }

    /**
     * Converts the pruning counters into percentages of the total edge count.
     * <p>
     * When {@code total} is zero the floating point divisions yield {@code NaN}
     * (or an infinity for a non-zero counter) rather than throwing.
     * </p>
     *
     * @return {@code {loop %, dead end %, parallel %, other pruned %, total}}.
     */
    private static double[] calculateStats(int loops, int deadEnds, int parallel, int otherPruned, int total) {
        double deadEndPct = 100.0 * (double)deadEnds / (double)total;
        double parallelPct = 100.0 * (double)parallel / (double)total;
        double loopPct = 100.0 * (double)loops / (double)total;
        double otherPrunedPct = 100.0 * (double)otherPruned / (double)total;
        return new double[]{loopPct, deadEndPct, parallelPct, otherPrunedPct, total};
    }

    /**
     * Collapses every cluster of vertices connected by short edges into a single
     * centre vertex and returns the resulting new graph.
     */
    private static Graph clusteredPruning(Graph graph, double minDistance, double[] voxelSize) {
        List<Set<Vertex>> clusters = GraphPruning.findClusters(graph, minDistance);
        List<Vertex> clusterCentres = clusters.stream().map(GraphPruning::getClusterCentre).collect(Collectors.toList());
        Map<Edge, Edge> replacements = GraphPruning.mapReplacementEdges(clusters, clusterCentres);
        return GraphPruning.createCleanGraph(graph, clusters, clusterCentres, replacements, voxelSize);
    }

    /**
     * Computes a key for the unordered pair of end vertices of an edge:
     * {@code min(id) * vertexCount + max(id)}, so both directions give the same value.
     * <p>
     * Throws a {@link NullPointerException} if an end vertex has no entry in {@code idMap}.
     * </p>
     */
    private static long connectionHash(Edge e, Map<Vertex, Integer> idMap) {
        final long nVertices = idMap.size();
        final long a = idMap.get(e.getV1()).intValue();
        final long b = idMap.get(e.getV2()).intValue();
        return a < b ? a * nVertices + b : b * nVertices + a;
    }

    /**
     * Builds a new graph from the edges that are untouched by the clusters plus the
     * replacement edges that attach to the cluster centres.
     * <p>
     * Lengths of the replacement edges are recomputed, and branches that refer to
     * edges missing from the new graph are stripped from its vertices.
     * </p>
     */
    private static Graph createCleanGraph(Graph graph, Collection<Set<Vertex>> clusters, Collection<Vertex> clusterCentres, Map<Edge, Edge> replacements, double[] voxelSize) {
        List<Edge> clusterEdges = new ArrayList<>(replacements.values());
        clusterEdges.forEach(e -> GraphPruning.euclideanDistance(e, voxelSize));
        List<Edge> nonClusterEdges = graph.getEdges().stream()
            .filter(e -> !replacements.containsKey(e) && GraphPruning.isNotInClusters(e, clusters))
            .collect(Collectors.toList());
        Graph cleanGraph = new Graph();
        Set<Edge> cleanEdges = new HashSet<>();
        cleanEdges.addAll(nonClusterEdges);
        cleanEdges.addAll(clusterEdges);
        cleanGraph.getEdges().addAll(cleanEdges);
        clusterCentres.forEach(cleanGraph::addVertex);
        GraphPruning.endpoints(nonClusterEdges).forEach(cleanGraph::addVertex);
        GraphPruning.endpoints(clusterEdges).forEach(cleanGraph::addVertex);
        GraphPruning.getUnconnectedVertices(graph).forEach(cleanGraph::addVertex);
        GraphPruning.removeDanglingBranches(cleanGraph);
        return cleanGraph;
    }

    /**
     * Collapses short edges that are not dead ends one at a time, treating the two
     * end vertices of each such edge as a cluster of its own.
     * <p>
     * The list of edges to collapse is determined once up front; entries are swapped
     * for their replacements as earlier collapses rewrite them.
     * </p>
     */
    private static Graph edgewisePruning(Graph graph, double minDistance, double[] voxelSize) {
        Graph cleanGraph = graph.clone();
        List<Edge> innerEdges = cleanGraph.getEdges().stream()
            .filter(e -> GraphPruning.isShort(e, minDistance) && !GraphPruning.isDeadEnd(e))
            .collect(Collectors.toList());
        // updateInnerEdges only calls List.set, which is not a structural modification,
        // so it is safe while this loop iterates the same list.
        for (Edge innerEdge : innerEdges) {
            List<Set<Vertex>> pairs = new ArrayList<>();
            pairs.add(GraphPruning.getEndpoints(innerEdge));
            List<Vertex> centroids = pairs.stream().map(GraphPruning::getClusterCentre).collect(Collectors.toList());
            Map<Edge, Edge> replacements = GraphPruning.mapReplacementEdges(pairs, centroids);
            Graph tmp = GraphPruning.createCleanGraph(cleanGraph, pairs, centroids, replacements, voxelSize);
            GraphPruning.updateInnerEdges(innerEdges, replacements);
            cleanGraph = tmp;
        }
        return cleanGraph;
    }

    /** Streams the distinct end vertices of the given edges. */
    private static Stream<Vertex> endpoints(Collection<Edge> edges) {
        return edges.stream().flatMap(e -> Stream.of(e.getV1(), e.getV2())).distinct();
    }

    /**
     * Sets the length of the edge to the voxel-scaled straight-line distance between
     * the centroids of the points of its two end vertices.
     */
    private static void euclideanDistance(Edge e, double[] voxelSize) {
        Vector3d difference = PointUtils.centroid(e.getV1().getPoints());
        Vector3d other = PointUtils.centroid(e.getV2().getPoints());
        difference.sub(other);
        e.setLength(GraphPruning.length(difference, voxelSize));
    }

    /**
     * Collects every vertex reachable from {@code start} by following only short edges.
     * The start vertex is always part of the result.
     */
    private static Set<Vertex> fillCluster(Vertex start, double minDistance) {
        Set<Vertex> cluster = new HashSet<>();
        Deque<Vertex> stack = new ArrayDeque<>();
        stack.push(start);
        while (!stack.isEmpty()) {
            Vertex vertex = stack.pop();
            cluster.add(vertex);
            vertex.getBranches().stream()
                .filter(e -> GraphPruning.isShort(e, minDistance))
                .map(e -> e.getOppositeVertex(vertex))
                .filter(v -> !cluster.contains(v))
                .distinct()
                .forEach(stack::push);
        }
        return cluster;
    }

    /** Lists the distinct end vertices of all short edges in the graph. */
    private static List<Vertex> findClusterVertices(Graph graph, double minDistance) {
        return graph.getEdges().stream()
            .filter(e -> GraphPruning.isShort(e, minDistance))
            .flatMap(e -> Stream.of(e.getV1(), e.getV2()))
            .distinct()
            .collect(Collectors.toList());
    }

    /** Returns the two end vertices of the edge as a set. */
    private static Set<Vertex> getEndpoints(Edge edge) {
        Set<Vertex> pair = new HashSet<>();
        pair.add(edge.getV1());
        pair.add(edge.getV2());
        return pair;
    }

    /** Streams the distinct vertices of the graph that have no branches. */
    private static Stream<Vertex> getUnconnectedVertices(Graph graph) {
        return graph.getVertices().stream().filter(v -> v.getBranches().isEmpty()).distinct();
    }

    /**
     * An edge is a dead end when exactly one of its end vertices has a single branch.
     * An edge whose both ends have a single branch is not reported as a dead end.
     */
    private static boolean isDeadEnd(Edge e) {
        return Stream.of(e.getV1(), e.getV2()).filter(v -> v.getBranches().size() == 1).count() == 1L;
    }

    /** An edge is a loop when both ends are the same non-null vertex object. */
    private static boolean isLoop(Edge edge) {
        return edge.getV1() != null && edge.getV1() == edge.getV2();
    }

    /** Returns {@code true} if no cluster contains both end vertices of the edge. */
    private static boolean isNotInClusters(Edge e, Collection<Set<Vertex>> clusters) {
        return clusters.stream().noneMatch(c -> c.contains(e.getV1()) && c.contains(e.getV2()));
    }

    /** Euclidean norm of {@code v} after scaling each component by the matching voxel size. */
    private static double length(Vector3d v, double[] voxelSize) {
        double x = v.x * voxelSize[0];
        double y = v.y * voxelSize[1];
        double z = v.z * voxelSize[2];
        double sqSum = DoubleStream.of(x, y, z).map(d -> d * d).sum();
        return Math.sqrt(sqSum);
    }

    /**
     * Maps each edge that leaves a cluster to a new edge that ends at the cluster's centre.
     * <p>
     * An edge that leaves two clusters is rewritten twice: the second rewrite starts from
     * the replacement produced by the first, so both ends finish at cluster centres.
     * </p>
     *
     * @param clusters       the vertex clusters.
     * @param clusterCentres centre vertices, index-aligned with {@code clusters}.
     */
    private static Map<Edge, Edge> mapReplacementEdges(List<Set<Vertex>> clusters, List<Vertex> clusterCentres) {
        Map<Edge, Edge> replacements = new HashMap<>();
        for (int i = 0; i < clusters.size(); ++i) {
            Set<Vertex> cluster = clusters.get(i);
            Vertex centre = clusterCentres.get(i);
            Set<Edge> outerEdges = GraphPruning.findEdgesWithOneEndInCluster(cluster);
            for (Edge outerEdge : outerEdges) {
                Edge oldEdge = replacements.getOrDefault(outerEdge, outerEdge);
                Edge replacement = GraphPruning.replaceEdge(oldEdge, cluster, centre);
                replacements.put(outerEdge, replacement);
            }
        }
        return replacements;
    }

    /**
     * Maps each vertex to its index in the list.
     * <p>
     * {@code Collectors.toMap} throws an {@link IllegalStateException} if the list
     * contains two equal vertices.
     * </p>
     */
    private static Map<Vertex, Integer> mapVertexIds(List<Vertex> vertices) {
        return IntStream.range(0, vertices.size()).boxed().collect(Collectors.toMap(vertices::get, Function.identity()));
    }

    /**
     * Removes short dead-end edges together with their terminal vertices, i.e. the
     * end vertices that had exactly one branch before removal.
     */
    private static void pruneDeadEnds(Graph graph, double minDistance) {
        List<Edge> deadEnds = graph.getEdges().stream()
            .filter(e -> GraphPruning.isDeadEnd(e) && GraphPruning.isShort(e, minDistance))
            .collect(Collectors.toList());
        // Terminals must be collected before the branches are detached below.
        List<Vertex> terminals = deadEnds.stream()
            .flatMap(e -> Stream.of(e.getV1(), e.getV2()))
            .filter(v -> v.getBranches().size() == 1)
            .collect(Collectors.toList());
        graph.getVertices().removeAll(terminals);
        deadEnds.forEach(GraphPruning::removeBranchFromEndpoints);
        graph.getEdges().removeAll(deadEnds);
    }

    /**
     * Rounds each component to the nearest integer; a {@code NaN} component becomes
     * {@link Integer#MAX_VALUE}.
     */
    private static int[] realToIntegerCoordinate(Vector3d v) {
        return DoubleStream.of(v.x, v.y, v.z)
            .mapToInt(d -> Double.isNaN(d) ? Integer.MAX_VALUE : (int)Math.round(d))
            .toArray();
    }

    /** Detaches the edge from the branch lists of both of its end vertices. */
    private static void removeBranchFromEndpoints(Edge branch) {
        branch.getV1().getBranches().remove(branch);
        branch.getV2().getBranches().remove(branch);
    }

    /** Strips from every vertex the branches that are not in the graph's edge list. */
    private static void removeDanglingBranches(Graph graph) {
        graph.getVertices().stream().map(Vertex::getBranches).forEach(b -> b.removeIf(e -> !graph.getEdges().contains(e)));
    }

    /**
     * Creates a copy of the edge in which the end that lies in the cluster is moved to
     * the cluster centre, and registers the new edge as a branch of both of its ends.
     * When both ends lie in the cluster only the first end is moved.
     *
     * @return the new edge, or {@code null} if neither end of the edge is in the cluster.
     */
    private static Edge replaceEdge(Edge edge, Collection<Vertex> cluster, Vertex centre) {
        final Vertex v1 = edge.getV1();
        final Vertex v2 = edge.getV2();
        final Edge replacement;
        if (cluster.contains(v1)) {
            replacement = new Edge(centre, v2, null, 0.0);
        } else if (cluster.contains(v2)) {
            replacement = new Edge(v1, centre, null, 0.0);
        } else {
            return null;
        }
        replacement.getV1().setBranch(replacement);
        replacement.getV2().setBranch(replacement);
        return replacement;
    }

    /**
     * Swaps every edge in the list that has a replacement for that replacement.
     * Only {@link List#set} is used, so the size of the list never changes.
     */
    private static void updateInnerEdges(List<Edge> innerEdges, Map<Edge, Edge> replacements) {
        for (int i = 0; i < innerEdges.size(); i++) {
            Edge edge = innerEdges.get(i);
            if (replacements.containsKey(edge)) {
                innerEdges.set(i, replacements.get(edge));
            }
        }
    }

    /**
     * Partitions the end vertices of short edges into clusters, where two vertices
     * share a cluster if a path of short edges connects them.
     *
     * @param graph       the graph to search.
     * @param minDistance edges whose length is strictly below this value count as short.
     * @return the clusters; empty if the graph has no short edges.
     */
    static List<Set<Vertex>> findClusters(Graph graph, double minDistance) {
        List<Set<Vertex>> clusters = new ArrayList<>();
        List<Vertex> clusterVertices = GraphPruning.findClusterVertices(graph, minDistance);
        while (!clusterVertices.isEmpty()) {
            Vertex start = clusterVertices.get(0);
            // The cluster always contains start, so the list shrinks on every pass.
            Set<Vertex> cluster = GraphPruning.fillCluster(start, minDistance);
            clusters.add(cluster);
            clusterVertices.removeAll(cluster);
        }
        return clusters;
    }

    /**
     * Finds the edges that occur exactly once among the branches of the cluster's
     * vertices, i.e. edges with only one end attached to the cluster.
     *
     * @param cluster the vertices of one cluster.
     * @return the edges leaving the cluster.
     */
    static Set<Edge> findEdgesWithOneEndInCluster(Collection<Vertex> cluster) {
        Map<Edge, Long> edgeCounts = cluster.stream()
            .flatMap(v -> v.getBranches().stream())
            .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        return edgeCounts.entrySet().stream()
            .filter(entry -> entry.getValue() == 1L)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
    }

    /**
     * Creates a new single-point vertex at the centroid of all points of the cluster's
     * vertices, rounded to integer coordinates.
     *
     * @param cluster the vertices of one cluster.
     * @return a new vertex with no branches.
     */
    static Vertex getClusterCentre(Set<Vertex> cluster) {
        Collection<Point> points = cluster.stream().flatMap(c -> c.getPoints().stream()).collect(Collectors.toList());
        Vector3d centroid = PointUtils.centroid(points);
        int[] coordinates = GraphPruning.realToIntegerCoordinate(centroid);
        Vertex vertex = new Vertex();
        vertex.addPoint(new Point(coordinates[0], coordinates[1], coordinates[2]));
        return vertex;
    }

    /**
     * Checks whether the edge's length is strictly below the given minimum.
     *
     * @param e         the edge to test.
     * @param minLength the exclusive lower bound for a non-short edge.
     * @return {@code true} if {@code e.getLength() < minLength}.
     */
    static boolean isShort(Edge e, double minLength) {
        return e.getLength() < minLength;
    }
}

