package org.apache.ignite.ml.clustering.gmm;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Partition data for the Gaussian mixture model trainer.
 * <p>
 * Holds the partition's labeled feature vectors together with a matrix {@code pcxi}. Row {@code i} of the matrix
 * has one value per cluster for vector {@code i}. The matrix is filled either by
 * {@link #estimateLikelihoodClusters} (hard 0/1 assignment to the closest mean) or by {@link #updatePcxi}
 * (normalized weighted component probabilities).
 */
public class GmmPartitionData implements AutoCloseable {
    /** Feature vectors of this partition. */
    private final List<LabeledVector<Double>> xs;

    /** {@code pcxi[i][c]} is the value for vector {@code i} and cluster {@code c}; has {@code xs.size()} rows. */
    private final double[][] pcxi;

    /**
     * Builds {@link GmmPartitionData} by applying a preprocessor to every upstream entry.
     *
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     */
    static class Builder<K, V> implements PartitionDataBuilder<K, V, EmptyContext, GmmPartitionData> {
        private static final long serialVersionUID = 1847063348042022561L;

        /** Converts an upstream entry into a labeled vector. */
        private final Preprocessor<K, V> preprocessor;

        /** Number of mixture components, i.e. the width of every {@code pcxi} row. */
        private final int countOfComponents;

        /**
         * @param preprocessor Converts upstream entries into labeled vectors.
         * @param countOfComponents Number of mixture components.
         */
        public Builder(Preprocessor<K, V> preprocessor, int countOfComponents) {
            this.preprocessor = preprocessor;
            this.countOfComponents = countOfComponents;
        }

        /** {@inheritDoc} */
        @SuppressWarnings("unchecked")
        @Override
        public GmmPartitionData build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData,
            long upstreamDataSize, EmptyContext ctx) {
            // toIntExact fails fast if the reported size cannot be an array/list capacity.
            List<LabeledVector<Double>> rows = new ArrayList<>(Math.toIntExact(upstreamDataSize));
            while (upstreamData.hasNext()) {
                UpstreamEntry<K, V> entry = upstreamData.next();
                rows.add((LabeledVector<Double>) preprocessor.apply(entry.getKey(), entry.getValue()));
            }

            // Size pcxi by the rows actually read so it always matches the vector list.
            double[][] probs = new double[rows.size()][countOfComponents];
            return new GmmPartitionData(rows, probs);
        }
    }

    /**
     * @param xs Feature vectors of the partition.
     * @param pcxi Per-vector cluster values; must have exactly one row per vector.
     */
    GmmPartitionData(List<LabeledVector<Double>> xs, double[][] pcxi) {
        A.ensure(xs.size() == pcxi.length, "xs.size() == pcxi.length");
        this.xs = xs;
        this.pcxi = pcxi;
    }

    /**
     * @param i Index of the vector in this partition.
     * @return Features of the {@code i}-th vector.
     */
    public Vector getX(int i) {
        return xs.get(i).features();
    }

    /**
     * Recomputes {@code pcxi} in every partition of the dataset.
     * <p>
     * Returns the sum, over partitions, of the largest mixture probability seen in each partition (see
     * {@link #updatePcxi}). Empty partitions contribute nothing. Otherwise their
     * {@code Double.NEGATIVE_INFINITY} would turn the whole sum into {@code -Infinity}.
     *
     * @param dataset Dataset whose partitions are updated.
     * @param clusterProbs Mixture weights, one per component.
     * @param components Mixture components.
     * @return Sum of per-partition maxima, or {@code 0} if no partition contributed a value.
     */
    public static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset,
        Vector clusterProbs, List<MultivariateGaussianDistribution> components) {
        Double res = dataset.<Double>compute(
            // A null result is treated as 0 by the reducer below.
            data -> data.size() == 0 ? null : Double.valueOf(updatePcxi(data, clusterProbs, components)),
            (left, right) -> asPrimitive(left) + asPrimitive(right)
        );

        // The reduction yields null when there was nothing to reduce.
        return asPrimitive(res);
    }

    /**
     * @param cluster Cluster index.
     * @param i Vector index.
     * @return Stored value for vector {@code i} and cluster {@code cluster}.
     */
    public double pcxi(int cluster, int i) {
        return pcxi[i][cluster];
    }

    /**
     * @param cluster Cluster index.
     * @param i Vector index.
     * @param val New value for vector {@code i} and cluster {@code cluster}.
     */
    public void setPcxi(int cluster, int i, double val) {
        pcxi[i][cluster] = val;
    }

    /**
     * @return Read-only view of all labeled vectors in this partition.
     */
    public List<LabeledVector<Double>> getAllXs() {
        return Collections.unmodifiableList(xs);
    }

    /** {@inheritDoc} */
    @Override
    public void close() throws Exception {
        // Nothing to release: this object only holds a list and an array.
    }

    /**
     * Initializes {@code pcxi} with a hard assignment.
     * <p>
     * For every vector, the entry of the closest mean (by squared distance) is set to 1 and the entries of all
     * other means are set to 0. Ties go to the lowest mean index. If no distance is finite and comparable
     * (infinite or NaN), the first mean is chosen.
     *
     * @param data Partition data to update.
     * @param initMeans Initial cluster means.
     * @throws IllegalArgumentException If {@code initMeans} is empty while the partition holds vectors.
     */
    public static void estimateLikelihoodClusters(GmmPartitionData data, Vector[] initMeans) {
        for (int i = 0; i < data.size(); i++) {
            Vector x = data.getX(i);
            int closest = -1;
            double minSquaredDist = Double.POSITIVE_INFINITY;

            for (int c = 0; c < initMeans.length; c++) {
                data.setPcxi(c, i, 0.0);
                double dist = initMeans[c].getDistanceSquared(x);

                // "closest < 0" guarantees a candidate even when no distance compares below the current minimum.
                if (closest < 0 || dist < minSquaredDist) {
                    closest = c;
                    minSquaredDist = dist;
                }
            }

            A.ensure(closest >= 0, "initMeans.length > 0");
            data.setPcxi(closest, i, 1.0);
        }
    }

    /**
     * @return Number of vectors in this partition.
     */
    public int size() {
        return pcxi.length;
    }

    /**
     * Recomputes {@code pcxi} for every vector of the partition.
     * <p>
     * For vector {@code x} and component {@code c}, the new value is
     * {@code w_c * p_c(x) / sum_k(w_k * p_k(x))}, where {@code w} is {@code clusterProbs} and {@code p_c(x)} is
     * {@code components.get(c).prob(x)}.
     *
     * @param data Partition data to update.
     * @param clusterProbs Mixture weights, one per component.
     * @param components Mixture components.
     * @return Largest {@code GmmModel.prob(x)} over the partition, or {@code Double.NEGATIVE_INFINITY} if it is
     *     empty.
     */
    static double updatePcxi(GmmPartitionData data, Vector clusterProbs,
        List<MultivariateGaussianDistribution> components) {
        GmmModel model = new GmmModel(clusterProbs, components);
        int cnt = clusterProbs.size();

        // Weighted component probabilities of the current vector, cached so each one is evaluated only once.
        double[] weighted = new double[cnt];
        double maxProb = Double.NEGATIVE_INFINITY;

        for (int i = 0; i < data.size(); i++) {
            Vector x = data.getX(i);
            double xProb = model.prob(x);
            if (xProb > maxProb) {
                maxProb = xProb;
            }

            double normalizer = 0.0;
            for (int c = 0; c < cnt; c++) {
                weighted[c] = components.get(c).prob(x) * clusterProbs.get(c);
                normalizer += weighted[c];
            }

            // NOTE(review): a zero normalizer (x improbable under every component) makes every entry NaN.
            for (int c = 0; c < cnt; c++) {
                data.pcxi[i][c] = weighted[c] / normalizer;
            }
        }

        return maxProb;
    }

    /**
     * @param d Boxed value, possibly {@code null}.
     * @return {@code d} unboxed, or {@code 0} when {@code d} is {@code null}.
     */
    private static double asPrimitive(Double d) {
        return d == null ? 0.0 : d;
    }
}
