/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.gmm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.clustering.gmm.GmmModel;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Partition-local data used while training a Gaussian mixture model (GMM).
 *
 * <p>Holds the feature vectors of one dataset partition together with a {@code pcxi} matrix indexed as
 * {@code pcxi[row][cluster]}. {@link #updatePcxi} fills it with each cluster's weighted component probability
 * normalized over all clusters; {@link #estimateLikelihoodClusters} fills it with a 0/1 nearest-mean assignment.
 */
class GmmPartitionData
implements AutoCloseable {
    /** Feature vectors of this partition, one per row. */
    private final List<LabeledVector<Double>> xs;

    /** Per-row, per-cluster values; {@code pcxi[i][c]} is the value of cluster {@code c} for row {@code i}. */
    private final double[][] pcxi;

    /**
     * Creates partition data.
     *
     * @param xs Feature vectors of the partition.
     * @param pcxi Per-row, per-cluster matrix; must have exactly {@code xs.size()} rows (checked via {@code A.ensure}).
     */
    GmmPartitionData(List<LabeledVector<Double>> xs, double[][] pcxi) {
        A.ensure(xs.size() == pcxi.length, "xs.size() == pcxi.length");
        this.xs = xs;
        this.pcxi = pcxi;
    }

    /**
     * @param i Row index.
     * @return Feature vector of row {@code i}.
     */
    public Vector getX(int i) {
        return this.xs.get(i).features();
    }

    /**
     * Recomputes {@code pcxi} on every partition and aggregates a convergence value across partitions.
     *
     * <p>The returned value is the sum, over partitions, of each partition's result of {@link #updatePcxi}
     * (that partition's maximum mixture probability). A {@code null} partition or final result counts as 0.
     *
     * @param dataset Dataset whose partitions hold {@link GmmPartitionData}.
     * @param clusterProbs Prior probability of each cluster.
     * @param components Gaussian component of each cluster.
     * @return Summed per-partition maximum mixture probability, or 0 if the dataset produced no result.
     */
    static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset, Vector clusterProbs, List<MultivariateGaussianDistribution> components) {
        Double likelihood = dataset.compute(
            data -> GmmPartitionData.updatePcxi(data, clusterProbs, components),
            (left, right) -> GmmPartitionData.asPrimitive(left) + GmmPartitionData.asPrimitive(right)
        );

        // The reducer already treats partition results as nullable; treat the final result the same way
        // instead of risking a NullPointerException on unboxing.
        return GmmPartitionData.asPrimitive(likelihood);
    }

    /**
     * @param cluster Cluster index.
     * @param i Row index.
     * @return Stored value of {@code cluster} for row {@code i}.
     */
    public double pcxi(int cluster, int i) {
        return this.pcxi[i][cluster];
    }

    /**
     * Stores the value of {@code cluster} for row {@code i}.
     *
     * @param cluster Cluster index.
     * @param i Row index.
     * @param value New value.
     */
    public void setPcxi(int cluster, int i, double value) {
        this.pcxi[i][cluster] = value;
    }

    /** @return Read-only view of all rows of this partition. */
    public List<LabeledVector<Double>> getAllXs() {
        return Collections.unmodifiableList(this.xs);
    }

    /** No-op: this class only holds an in-memory list and array. */
    @Override
    public void close() throws Exception {
    }

    /**
     * Initializes {@code pcxi} with a hard nearest-mean assignment: for every row, the cluster whose mean has the
     * smallest squared distance gets 1.0 and every other cluster gets 0.0. Ties go to the lowest cluster index.
     *
     * @param data Partition data to update.
     * @param initMeans Initial cluster means; must not be empty.
     */
    static void estimateLikelihoodClusters(GmmPartitionData data, Vector[] initMeans) {
        A.ensure(initMeans.length > 0, "initMeans.length > 0");

        for (int i = 0; i < data.size(); ++i) {
            Vector x = data.getX(i);

            // Default to cluster 0 so a row whose distances are all NaN or infinite still gets a valid
            // assignment instead of an out-of-bounds index of -1.
            int closestClusterId = 0;
            double minSquaredDist = Double.POSITIVE_INFINITY;

            for (int c = 0; c < initMeans.length; ++c) {
                data.setPcxi(c, i, 0.0);
                double distance = initMeans[c].getDistanceSquared(x);
                if (distance < minSquaredDist) {
                    closestClusterId = c;
                    minSquaredDist = distance;
                }
            }

            data.setPcxi(closestClusterId, i, 1.0);
        }
    }

    /** @return Number of rows in this partition. */
    public int size() {
        return this.pcxi.length;
    }

    /**
     * Overwrites {@code pcxi} with, for each row {@code x} and cluster {@code c},
     * {@code components[c].prob(x) * clusterProbs[c]} divided by the sum of that product over all clusters.
     *
     * @param data Partition data to update.
     * @param clusterProbs Prior probability of each cluster; its size defines the number of clusters.
     * @param components Gaussian component of each cluster.
     * @return Maximum of {@code GmmModel.prob(x)} over the partition's rows, or negative infinity if the
     *     partition is empty.
     */
    static double updatePcxi(GmmPartitionData data, Vector clusterProbs, List<MultivariateGaussianDistribution> components) {
        GmmModel model = new GmmModel(clusterProbs, components);
        int countOfClusters = clusterProbs.size();
        double[] weightedProbs = new double[countOfClusters];
        double maxProb = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < data.size(); ++i) {
            Vector x = data.getX(i);

            double xProb = model.prob(x);
            if (xProb > maxProb) {
                maxProb = xProb;
            }

            // Evaluate each component density once per row and reuse it for the normalizer and the
            // per-cluster values (same arithmetic order as computing it twice).
            double normalizer = 0.0;
            for (int c = 0; c < countOfClusters; ++c) {
                weightedProbs[c] = components.get(c).prob(x) * clusterProbs.get(c);
                normalizer += weightedProbs[c];
            }

            // NOTE(review): if every weighted probability underflows to 0, normalizer is 0 and the values
            // written below become NaN; confirm whether callers need a fallback here.
            for (int c = 0; c < countOfClusters; ++c) {
                data.pcxi[i][c] = weightedProbs[c] / normalizer;
            }
        }

        return maxProb;
    }

    /**
     * @param val Possibly-null boxed value.
     * @return {@code val}, or 0.0 when it is {@code null}.
     */
    private static double asPrimitive(Double val) {
        return val == null ? 0.0 : val;
    }

    /**
     * Builds {@link GmmPartitionData} from upstream entries by applying a preprocessor to every entry.
     *
     * @param <K> Upstream key type.
     * @param <V> Upstream value type.
     */
    static class Builder<K, V>
    implements PartitionDataBuilder<K, V, EmptyContext, GmmPartitionData> {
        private static final long serialVersionUID = 1847063348042022561L;

        /** Converts an upstream entry into a labeled feature vector. */
        private final Preprocessor<K, V> preprocessor;

        /** Number of mixture components, i.e. the width of each {@code pcxi} row. */
        private final int countOfComponents;

        /**
         * @param preprocessor Converts upstream entries into labeled vectors.
         * @param countOfComponents Number of mixture components.
         */
        public Builder(Preprocessor<K, V> preprocessor, int countOfComponents) {
            this.preprocessor = preprocessor;
            this.countOfComponents = countOfComponents;
        }

        /**
         * Reads every upstream entry, preprocesses it and allocates a zero-filled {@code pcxi} matrix with one row
         * per entry actually read.
         */
        @Override
        public GmmPartitionData build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, EmptyContext ctx) {
            // upstreamDataSize is only a capacity hint here.
            List<LabeledVector<Double>> xs = new ArrayList<>(Math.toIntExact(upstreamDataSize));

            while (upstreamData.hasNext()) {
                UpstreamEntry<K, V> entry = upstreamData.next();
                @SuppressWarnings("unchecked") // The preprocessor is expected to yield Double-labeled vectors.
                LabeledVector<Double> x = (LabeledVector<Double>) this.preprocessor.apply(entry.getKey(), entry.getValue());
                xs.add(x);
            }

            // Size pcxi from the rows actually read so it always matches xs, even if upstreamDataSize
            // disagrees with the iterator's length.
            double[][] pcxi = new double[xs.size()][this.countOfComponents];
            return new GmmPartitionData(xs, pcxi);
        }
    }
}

