package org.knime.knip.core.algorithm.extendedem;

import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:knip-core.jar:org/knime/knip/core/algorithm/extendedem/ExtendedEM.class */
/**
 * Expectation-maximization clustering that models every cluster as a product of independent
 * univariate normal distributions, one per attribute.
 *
 * <p>Unlike a randomly initialised EM, the initial cluster means and prior weights are supplied by
 * the caller. Before {@link #buildClusterer(InstancesTmp)} is called, the caller must set:
 * <ul>
 *   <li>{@link #setNumClusters(int)};</li>
 *   <li>{@link #setCenters(InstancesTmp)}, which must hold at least that many instances;</li>
 *   <li>{@link #setClusterSizes(int[])}, which must hold at least that many entries.</li>
 * </ul>
 * After fitting, the model can be queried with {@link #distributionForInstance(InstanceTmp)},
 * {@link #logDensityForInstance(InstanceTmp)} and related methods.
 *
 * <p>NOTE(review): the structure closely follows Weka's {@code EM} clusterer; confirm provenance
 * before relying on Weka documentation for edge-case behaviour. The class holds mutable state and
 * performs no synchronization.
 */
public class ExtendedEM {
    // NOTE(review): this class does not implement Serializable, so this field has no effect.
    private static final long serialVersionUID = 8348181483812829475L;

    /** {@code log(sqrt(2 * pi))}, the constant term of the normal log-density. */
    private static final double LOG_SQRT_TWO_PI = Math.log(Math.sqrt(2.0d * Math.PI));

    /** Global lower bound for a per-attribute standard deviation. */
    private static final double MIN_STD_DEV = 1.0E-6d;

    /** Minimum log-likelihood gain between two EM iterations required to keep iterating. */
    private static final double MIN_LOG_LIKELIHOOD_IMPROVEMENT = 1.0E-6d;

    /** One cluster is dropped once this many consecutive restarts have been exceeded. */
    private static final int MAX_RESTARTS_PER_CLUSTER_COUNT = 5;

    /** Number of values drawn and discarded from a freshly created random generator. */
    private static final int RANDOM_WARMUP_DRAWS = 10;

    /** Seed returned by {@link #getSeed()}. */
    private static final int DEFAULT_SEED = 1;

    /**
     * Per cluster and attribute: {@code {mean, standard deviation, weight sum}}.
     * Index 2 is 1.0 right after initialisation and the accumulated weight after each M step.
     */
    private double[][][] m_modelNormal;

    /**
     * Optional per-attribute lower bounds for the standard deviation.
     * NOTE(review): never assigned in this class, so {@link #MIN_STD_DEV} is always used.
     */
    private double[] m_minStdDevPerAtt;

    /** Membership probabilities, indexed {@code [instance][cluster]}. */
    private double[][] m_weights;

    /** Normalized cluster prior probabilities. */
    private double[] m_priors;

    private int m_numClusters;

    private int m_numAttribs;

    /** Per-attribute minimum over all non-missing training values (NaN if none). */
    private double[] m_minValues;

    /** Per-attribute maximum over all non-missing training values (NaN if none). */
    private double[] m_maxValues;

    /**
     * Random generator, re-seeded on every restart.
     * NOTE(review): it is never sampled for model fitting in this class.
     */
    private Random m_rr;

    /** When true, the log-likelihood of every iteration is printed to stdout. Never set here. */
    private boolean m_verbose;

    private final int m_seed;

    /** Initial cluster centres; instance {@code c} provides the initial means of cluster {@code c}. */
    private InstancesTmp m_centers;

    /** Initial cluster sizes, used (after normalization) as the initial priors. */
    private int[] m_clusterSizes;

    private InstancesTmp m_theInstances = null;

    private int m_maxIterations = 100;

    /**
     * Frequency-count estimator for a discrete attribute with a fixed number of values.
     *
     * <p>NOTE(review): not used anywhere in this class.
     */
    public class DiscreteEstimator {
        private final double[] m_Counts;
        private double m_SumOfCounts;

        /**
         * Creates an estimator.
         *
         * @param numSymbols number of distinct values the attribute can take
         * @param laplace if {@code true}, every count starts at 1 (add-one smoothing)
         */
        public DiscreteEstimator(int numSymbols, boolean laplace) {
            m_Counts = new double[numSymbols];
            m_SumOfCounts = 0.0d;
            if (laplace) {
                for (int i = 0; i < numSymbols; i++) {
                    m_Counts[i] = 1.0d;
                }
                m_SumOfCounts = numSymbols;
            }
        }

        /**
         * Adds an observation.
         *
         * @param data the observed value, truncated to an {@code int} index
         * @param weight the weight of the observation
         */
        public void addValue(double data, double weight) {
            m_Counts[(int) data] += weight;
            m_SumOfCounts += weight;
        }

        /**
         * Returns the relative frequency of a value.
         *
         * @param data the value, truncated to an {@code int} index
         * @return the count of {@code data} divided by the total count, or 0 if nothing was counted
         */
        public double getProbability(double data) {
            if (m_SumOfCounts == 0.0d) {
                return 0.0d;
            }
            return m_Counts[(int) data] / m_SumOfCounts;
        }
    }

    /** Divides every element of {@code values} by their sum, in place. */
    private void normalize(double[] values) {
        double sum = 0.0d;
        for (double v : values) {
            sum += v;
        }
        normalize(values, sum);
    }

    /**
     * Divides every element of {@code values} by {@code sum}, in place.
     *
     * @throws IllegalArgumentException if {@code sum} is NaN or zero
     */
    private void normalize(double[] values, double sum) {
        if (Double.isNaN(sum)) {
            throw new IllegalArgumentException("Can't normalize array. Sum is NaN.");
        }
        if (sum == 0.0d) {
            throw new IllegalArgumentException("Can't normalize array. Sum is zero.");
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= sum;
        }
    }

    /**
     * Sets the number of clusters.
     *
     * <p>Any negative value is stored as -1. NOTE(review): {@link #buildClusterer(InstancesTmp)}
     * does not support -1; it would fail when allocating its arrays.
     *
     * @param numClusters the number of clusters; must not be 0
     * @throws IllegalArgumentException if {@code numClusters} is 0
     */
    public void setNumClusters(int numClusters) {
        if (numClusters == 0) {
            throw new IllegalArgumentException(
                    "Number of clusters must be > 0. (or -1 to select by cross validation).");
        }
        m_numClusters = numClusters < 0 ? -1 : numClusters;
    }

    /**
     * Sets the initial cluster centres. Instance {@code c} supplies the initial means of cluster
     * {@code c}; missing values fall back to the attribute's mean or mode in the training data.
     *
     * @param centers the centres (stored by reference, not copied)
     */
    public void setCenters(InstancesTmp centers) {
        m_centers = centers;
    }

    /**
     * Sets the initial cluster sizes, which are normalized into the initial cluster priors.
     *
     * @param clusterSizes one size per cluster (defensively copied)
     */
    public void setClusterSizes(int[] clusterSizes) {
        m_clusterSizes = clusterSizes.clone();
    }

    /**
     * Sets the maximum number of EM iterations per (re)start.
     *
     * @param maxIterations the iteration limit
     */
    public void setMaxIterations(int maxIterations) {
        m_maxIterations = maxIterations;
    }

    /**
     * Misspelled alias of {@link #setMaxIterations(int)}, kept for backward compatibility.
     *
     * @param i the iteration limit
     */
    public void setMaxInterations(int i) {
        setMaxIterations(i);
    }

    /** Returns the lower bound for the standard deviation of attribute {@code attribute}. */
    private double minStdDevFor(int attribute) {
        return m_minStdDevPerAtt != null ? m_minStdDevPerAtt[attribute] : MIN_STD_DEV;
    }

    /**
     * Allocates the model arrays and initialises them as follows:
     * <ul>
     *   <li>means from {@link #m_centers};</li>
     *   <li>standard deviations from the attribute ranges;</li>
     *   <li>priors from {@link #m_clusterSizes}.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the cluster sizes sum to zero (for example with 0 clusters)
     */
    private void initModel(InstancesTmp instances) {
        m_weights = new double[instances.numInstances()][m_numClusters];
        m_modelNormal = new double[m_numClusters][m_numAttribs][3];
        m_priors = new double[m_numClusters];
        for (int c = 0; c < m_numClusters; c++) {
            InstanceTmp center = m_centers.instance(c);
            for (int a = 0; a < m_numAttribs; a++) {
                double minStdDev = minStdDevFor(a);
                double[] model = m_modelNormal[c][a];
                model[0] = center.isMissing(a) ? instances.meanOrMode(a) : center.value(a);
                // Initial spread: the attribute range divided by twice the number of clusters.
                double stdDev = (m_maxValues[a] - m_minValues[a]) / (2 * m_numClusters);
                if (stdDev < minStdDev) {
                    // Range too narrow: fall back to the attribute's global standard deviation.
                    stdDev = instances.attributeStats(a).m_numericStats.getStdDev();
                    if (Double.isInfinite(stdDev) || stdDev < minStdDev) {
                        stdDev = minStdDev;
                    }
                }
                if (stdDev <= 0.0d) {
                    stdDev = MIN_STD_DEV;
                }
                model[1] = stdDev;
                model[2] = 1.0d;
            }
        }
        for (int c = 0; c < m_numClusters; c++) {
            m_priors[c] = m_clusterSizes[c];
        }
        normalize(m_priors);
    }

    /**
     * Re-estimates the cluster priors as the instance-weighted sum of the membership probabilities,
     * normalized to 1.
     */
    private void estimatePriors(InstancesTmp instances) {
        for (int c = 0; c < m_numClusters; c++) {
            m_priors[c] = 0.0d;
        }
        for (int n = 0; n < instances.numInstances(); n++) {
            for (int c = 0; c < m_numClusters; c++) {
                m_priors[c] += instances.instance(n).weight() * m_weights[n][c];
            }
        }
        normalize(m_priors);
    }

    /** Clears mean, standard deviation and weight sum of every per-cluster attribute model. */
    private void resetEstimators() {
        for (int c = 0; c < m_numClusters; c++) {
            for (int a = 0; a < m_numAttribs; a++) {
                double[] model = m_modelNormal[c][a];
                model[0] = 0.0d;
                model[1] = 0.0d;
                model[2] = 0.0d;
            }
        }
    }

    /**
     * M step: re-estimates the priors and, for non-nominal attributes, the weighted mean and
     * standard deviation of every cluster from the current membership probabilities.
     */
    private void maximizationStep(InstancesTmp instances) {
        resetEstimators();
        estimatePriors(instances);
        // Accumulate [0] = sum(w * x), [1] = sum(w * x^2), [2] = sum(w) over non-missing values,
        // where w = instance weight * membership probability.
        for (int c = 0; c < m_numClusters; c++) {
            for (int a = 0; a < m_numAttribs; a++) {
                double[] model = m_modelNormal[c][a];
                for (int n = 0; n < instances.numInstances(); n++) {
                    InstanceTmp inst = instances.instance(n);
                    if (inst.isMissing(a)) {
                        continue;
                    }
                    double value = inst.value(a);
                    model[0] += value * inst.weight() * m_weights[n][c];
                    model[2] += inst.weight() * m_weights[n][c];
                    model[1] += value * value * inst.weight() * m_weights[n][c];
                }
            }
        }
        for (int a = 0; a < m_numAttribs; a++) {
            if (instances.attribute(a).isNominal()) {
                continue;
            }
            for (int c = 0; c < m_numClusters; c++) {
                double[] model = m_modelNormal[c][a];
                double weightSum = model[2];
                if (weightSum <= 0.0d) {
                    // No support for this cluster: use an extremely wide distribution.
                    // NOTE(review): the mean is set to the minimum std-dev constant, which looks
                    // arbitrary but is kept for compatibility.
                    model[1] = Double.MAX_VALUE;
                    model[0] = MIN_STD_DEV;
                    continue;
                }
                // Weighted variance: (sum(w x^2) - sum(w x)^2 / sum(w)) / sum(w).
                double variance = (model[1] - ((model[0] * model[0]) / weightSum)) / weightSum;
                if (variance < 0.0d) {
                    // Guard against round-off producing a tiny negative variance.
                    variance = 0.0d;
                }
                double minStdDev = minStdDevFor(a);
                double stdDev = Math.sqrt(variance);
                if (stdDev <= minStdDev) {
                    stdDev = instances.attributeStats(a).m_numericStats.getStdDev();
                    if (stdDev <= minStdDev) {
                        stdDev = minStdDev;
                    }
                }
                if (stdDev <= 0.0d || Double.isInfinite(stdDev)) {
                    stdDev = MIN_STD_DEV;
                }
                model[1] = stdDev;
                model[0] /= weightSum;
            }
        }
    }

    /**
     * E step: computes the weighted average log-density of the instances under the current model.
     *
     * @param instances the training instances
     * @param updateWeights if true, also stores each instance's cluster membership probabilities
     * @return sum(weight * logDensity) / sum(weight)
     */
    private double expectationStep(InstancesTmp instances, boolean updateWeights) {
        double weightedLogDensity = 0.0d;
        double totalWeight = 0.0d;
        for (int n = 0; n < instances.numInstances(); n++) {
            InstanceTmp inst = instances.instance(n);
            weightedLogDensity += inst.weight() * logDensityForInstance(inst);
            totalWeight += inst.weight();
            if (updateWeights) {
                m_weights[n] = distributionForInstance(inst);
            }
        }
        return weightedLogDensity / totalWeight;
    }

    /** Creates a clusterer that uses a fixed random seed (see {@link #getSeed()}). */
    public ExtendedEM() {
        m_seed = DEFAULT_SEED;
    }

    /**
     * Returns the fitted per-cluster attribute models, indexed
     * {@code [cluster][attribute][{mean, stdDev, weightSum}]}.
     *
     * <p>The internal array is returned without copying; modifying it changes the model.
     *
     * @return the model array, or {@code null} before the model has been initialised
     */
    public double[][][] getClusterModelsNumericAtts() {
        return m_modelNormal;
    }

    /** Widens the per-attribute min/max ranges to include the non-missing values of {@code inst}. */
    private void updateMinMax(InstanceTmp inst) {
        for (int a = 0; a < m_theInstances.numAttributes(); a++) {
            if (inst.isMissing(a)) {
                continue;
            }
            double value = inst.value(a);
            if (Double.isNaN(m_minValues[a])) {
                m_minValues[a] = value;
                m_maxValues[a] = value;
            } else if (value < m_minValues[a]) {
                m_minValues[a] = value;
            } else if (value > m_maxValues[a]) {
                m_maxValues[a] = value;
            }
        }
    }

    /**
     * Fits the mixture model to {@code instances}.
     *
     * <p>Requires {@link #setNumClusters(int)}, {@link #setCenters(InstancesTmp)} and
     * {@link #setClusterSizes(int[])} to have been called. After fitting, the stored training set is
     * replaced by {@code new InstancesTmp(instances, 0)}; presumably that is an empty copy keeping
     * only the header (verify against {@code InstancesTmp}).
     *
     * @param instances the training data
     * @throws IllegalArgumentException if fitting fails even after all clusters have been dropped
     */
    public void buildClusterer(InstancesTmp instances) {
        m_theInstances = instances;
        int numAttributes = m_theInstances.numAttributes();
        m_minValues = new double[numAttributes];
        m_maxValues = new double[numAttributes];
        for (int a = 0; a < numAttributes; a++) {
            m_maxValues[a] = Double.NaN;
            m_minValues[a] = Double.NaN;
        }
        for (int n = 0; n < m_theInstances.numInstances(); n++) {
            updateMinMax(m_theInstances.instance(n));
        }
        doEM();
        m_theInstances = new InstancesTmp(m_theInstances, 0);
    }

    /** Seeds the generator, initialises the model and runs EM until convergence. */
    private void doEM() {
        m_rr = new Random(getSeed());
        for (int k = 0; k < RANDOM_WARMUP_DRAWS; k++) {
            m_rr.nextDouble();
        }
        m_numAttribs = m_theInstances.numAttributes();
        initModel(m_theInstances);
        iterate(m_theInstances, m_verbose);
    }

    /**
     * Runs EM until it completes a full run without an exception.
     *
     * <p>When an iteration throws, the model is re-initialised and the iteration loop restarts
     * from scratch. A fresh start gets a fresh iteration budget and a fresh convergence baseline,
     * so a stale log-likelihood from the failed run cannot end the new run early. After more than
     * {@link #MAX_RESTARTS_PER_CLUSTER_COUNT} consecutive restarts, one cluster is dropped.
     * Re-initialisation with zero clusters throws, so the loop always terminates.
     *
     * @return the log-likelihood of the last E step
     */
    private double iterate(InstancesTmp instances, boolean verbose) {
        double logLikelihood = 0.0d;
        int seed = getSeed();
        int restartCount = 0;
        boolean finished = false;
        while (!finished) {
            try {
                logLikelihood = runIterations(instances, verbose);
                finished = true;
            } catch (Exception e) {
                // NOTE(review): no logger is available in this class; consider routing this
                // through the project's logging facility instead of stderr.
                e.printStackTrace();
                seed++;
                restartCount++;
                m_rr = new Random(seed);
                for (int k = 0; k < RANDOM_WARMUP_DRAWS; k++) {
                    m_rr.nextDouble();
                    m_rr.nextInt();
                }
                if (restartCount > MAX_RESTARTS_PER_CLUSTER_COUNT) {
                    m_numClusters--;
                    restartCount = 0;
                }
                initModel(m_theInstances);
            }
        }
        return logLikelihood;
    }

    /**
     * Alternates E and M steps, starting from the current model. Iteration stops after
     * {@link #m_maxIterations} steps, or earlier when the log-likelihood improves by less than
     * {@link #MIN_LOG_LIKELIHOOD_IMPROVEMENT}.
     *
     * @return the log-likelihood of the last E step (0 if no iteration ran)
     */
    private double runIterations(InstancesTmp instances, boolean verbose) {
        double logLikelihood = 0.0d;
        for (int iteration = 0; iteration < m_maxIterations; iteration++) {
            double previous = logLikelihood;
            logLikelihood = expectationStep(instances, true);
            if (verbose) {
                System.out.println("Loglikely: " + logLikelihood);
            }
            if (iteration > 0 && logLikelihood - previous < MIN_LOG_LIKELIHOOD_IMPROVEMENT) {
                break;
            }
            maximizationStep(instances);
        }
        return logLikelihood;
    }

    /**
     * Returns the random seed used by this clusterer.
     *
     * @return the seed, always 1
     */
    public int getSeed() {
        return m_seed;
    }

    /** Returns the index of the first maximum of {@code values}, or 0 for an empty array. */
    private int maxIndex(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Returns the log of the mixture density of {@code inst}.
     *
     * <p>The value is {@code log(sum_c exp(joint_c))}. It is computed stably by factoring out the
     * largest joint log-density.
     *
     * @param inst the instance
     * @return the log-density
     */
    public double logDensityForInstance(InstanceTmp inst) {
        double[] joint = logJointDensitiesForInstance(inst);
        double max = joint[maxIndex(joint)];
        double sum = 0.0d;
        for (double logDensity : joint) {
            sum += Math.exp(logDensity - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Converts log-values to probabilities that sum to 1, shifting by the maximum first for
     * numerical stability.
     *
     * @throws IllegalArgumentException if the exponentiated values sum to NaN or zero
     */
    private double[] logs2probs(double[] logValues) {
        double max = logValues[maxIndex(logValues)];
        double sum = 0.0d;
        double[] probs = new double[logValues.length];
        for (int i = 0; i < logValues.length; i++) {
            probs[i] = Math.exp(logValues[i] - max);
            sum += probs[i];
        }
        normalize(probs, sum);
        return probs;
    }

    /**
     * Returns, per cluster, the sum over non-missing attributes of the normal log-density of the
     * attribute value under that cluster's model.
     *
     * @param inst the instance
     * @return one log-density per cluster
     */
    public double[] logDensityPerClusterForInstance(InstanceTmp inst) {
        double[] result = new double[m_numClusters];
        for (int c = 0; c < m_numClusters; c++) {
            double logDensity = 0.0d;
            for (int a = 0; a < m_numAttribs; a++) {
                if (!inst.isMissing(a)) {
                    double[] model = m_modelNormal[c][a];
                    logDensity += logNormalDens(inst.value(a), model[0], model[1]);
                }
            }
            result[c] = logDensity;
        }
        return result;
    }

    /**
     * Returns a copy of the current cluster priors.
     *
     * @return the priors (a new array)
     */
    public double[] clusterPriors() {
        return m_priors.clone();
    }

    /** Log of the normal density with the given mean and standard deviation at {@code x}. */
    private double logNormalDens(double x, double mean, double stdDev) {
        double diff = x - mean;
        return -(diff * diff / (2.0d * stdDev * stdDev)) - LOG_SQRT_TWO_PI - Math.log(stdDev);
    }

    /**
     * Returns the posterior cluster-membership probabilities of {@code inst}.
     *
     * @param inst the instance
     * @return one probability per cluster, summing to 1
     * @throws IllegalArgumentException if the probabilities cannot be normalized
     */
    public double[] distributionForInstance(InstanceTmp inst) {
        return logs2probs(logJointDensitiesForInstance(inst));
    }

    /**
     * Returns, per cluster, the log-density of {@code inst} plus the log of the cluster prior.
     * Clusters with a non-positive prior get no prior term.
     *
     * @param inst the instance
     * @return one joint log-density per cluster
     */
    public double[] logJointDensitiesForInstance(InstanceTmp inst) {
        double[] joint = logDensityPerClusterForInstance(inst);
        for (int c = 0; c < joint.length; c++) {
            if (m_priors[c] > 0.0d) {
                joint[c] += Math.log(m_priors[c]);
            }
        }
        return joint;
    }
}
