package org.apache.ignite.ml.clustering.gmm;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/**
 * Accumulator for a single mixture component. It keeps three running values:
 * <ul>
 *     <li>the sum of row vectors, each scaled by its {@code pcxi} weight;</li>
 *     <li>the sum of the {@code pcxi} weights;</li>
 *     <li>the number of rows seen.</li>
 * </ul>
 * The weight is a value in {@code [0, 1]}, presumably {@code p(c | x_i)}. From these values it derives
 * the weighted mean ({@link #mean()}) and the average weight per row ({@link #clusterProb()}).
 * Instances are built per partition by {@link #map} and merged with {@link #reduce}.
 */
class MeanWithClusterProbAggregator implements Serializable {
    /** Serial version UID. */
    private static final long serialVersionUID = 2700985110021774629L;

    /** Sum of row vectors scaled by their pcxi weights; {@code null} until the first row is added. */
    private Vector weightedXsSum;

    /** Sum of the pcxi weights of all rows seen so far. */
    private double pcxiSum;

    /** Number of rows seen so far. */
    private int rowCount;

    /**
     * Mean vectors and cluster probabilities of all components, computed from a list of aggregators
     * (one aggregator per component, in component order).
     */
    public static class AggregatedStats {
        /** Element {@code c} is {@link MeanWithClusterProbAggregator#clusterProb()} of component {@code c}. */
        private final Vector clusterProbs;

        /** Element {@code c} is {@link MeanWithClusterProbAggregator#mean()} of component {@code c}. */
        private final List<Vector> means;

        /**
         * @param stats Per-component aggregators, indexed by component.
         */
        private AggregatedStats(List<MeanWithClusterProbAggregator> stats) {
            clusterProbs = VectorUtils.of(stats.stream()
                .mapToDouble(MeanWithClusterProbAggregator::clusterProb)
                .toArray());
            means = stats.stream()
                .map(MeanWithClusterProbAggregator::mean)
                .collect(Collectors.toList());
        }

        /**
         * @return Vector whose element {@code c} is the cluster probability of component {@code c}.
         */
        public Vector clusterProbabilities() {
            return clusterProbs;
        }

        /**
         * @return List whose element {@code c} is the weighted mean of component {@code c}.
         */
        public List<Vector> means() {
            return means;
        }
    }

    /** Creates an empty aggregator with no rows seen. */
    MeanWithClusterProbAggregator() {
    }

    /**
     * Creates an aggregator from already accumulated statistics.
     *
     * @param weightedXsSum Sum of pcxi-weighted row vectors ({@code null} if no rows were seen).
     * @param pcxiSum Sum of pcxi weights.
     * @param rowCount Number of rows.
     */
    MeanWithClusterProbAggregator(Vector weightedXsSum, double pcxiSum, int rowCount) {
        this.weightedXsSum = weightedXsSum;
        this.pcxiSum = pcxiSum;
        this.rowCount = rowCount;
    }

    /**
     * Computes the pcxi-weighted mean of the rows seen, i.e. {@code weightedXsSum / pcxiSum}.
     * If {@code pcxiSum == 0}, the result depends on {@code Vector.divide} and presumably has
     * non-finite components.
     *
     * @return Weighted mean vector.
     * @throws NullPointerException If no rows were ever added, so no weighted sum exists.
     */
    public Vector mean() {
        return weightedXsSum.divide(pcxiSum);
    }

    /**
     * Computes the average pcxi weight per row, i.e. {@code pcxiSum / rowCount}.
     *
     * @return Cluster probability; {@code NaN} if no rows were seen.
     */
    public double clusterProb() {
        return pcxiSum / rowCount;
    }

    /**
     * Aggregates per-component statistics over all partitions of the dataset.
     *
     * <p>The map lambda is a serializable lambda (an {@code IgniteFunction}). javac synthesizes the
     * {@code $deserializeLambda$} hook it needs, so that method must not be declared by hand.
     *
     * @param dataset Dataset with partition data containing rows and their pcxi values.
     * @param countOfComponents Number of mixture components.
     * @return Aggregated means and cluster probabilities, one entry per component.
     */
    public static AggregatedStats aggreateStats(Dataset<EmptyContext, GmmPartitionData> dataset,
        int countOfComponents) {
        List<MeanWithClusterProbAggregator> perComponent = dataset.compute(
            data -> map(data, countOfComponents),
            MeanWithClusterProbAggregator::reduce
        );

        return new AggregatedStats(perComponent);
    }

    /**
     * Adds one row with the given weight.
     *
     * @param x Row vector.
     * @param pcxi Weight of the row for this component; must lie in {@code [0, 1]}.
     */
    void add(Vector x, double pcxi) {
        A.ensure(pcxi >= 0.0 && pcxi <= 1.0, "pcxi >= 0 && pcxi <= 1.");

        Vector weightedX = x.times(pcxi);
        weightedXsSum = weightedXsSum == null ? weightedX : weightedXsSum.plus(weightedX);
        pcxiSum += pcxi;
        rowCount++;
    }

    /**
     * Merges this aggregator with another one without modifying either.
     *
     * <p>Either side may have no rows. This happens when a partition is empty: {@link #map} still
     * creates one aggregator per component, but never sets its weighted sum. A {@code null} sum is
     * therefore treated as the neutral element instead of being dereferenced.
     *
     * @param other Aggregator to merge with.
     * @return New aggregator holding the combined statistics.
     */
    MeanWithClusterProbAggregator plus(MeanWithClusterProbAggregator other) {
        Vector mergedSum;
        if (weightedXsSum == null) {
            mergedSum = other.weightedXsSum;
        }
        else if (other.weightedXsSum == null) {
            mergedSum = weightedXsSum;
        }
        else {
            mergedSum = weightedXsSum.plus(other.weightedXsSum);
        }

        return new MeanWithClusterProbAggregator(mergedSum, pcxiSum + other.pcxiSum, rowCount + other.rowCount);
    }

    /**
     * Builds one aggregator per component from a single partition.
     * Every row is added to every component with that component's pcxi weight.
     *
     * @param data Partition data.
     * @param countOfComponents Number of mixture components.
     * @return List of {@code countOfComponents} aggregators, indexed by component.
     */
    public static List<MeanWithClusterProbAggregator> map(GmmPartitionData data, int countOfComponents) {
        List<MeanWithClusterProbAggregator> aggregators = new ArrayList<>();
        for (int c = 0; c < countOfComponents; c++) {
            aggregators.add(new MeanWithClusterProbAggregator());
        }

        for (int row = 0; row < data.size(); row++) {
            for (int c = 0; c < countOfComponents; c++) {
                aggregators.get(c).add(data.getX(row), data.pcxi(c, row));
            }
        }

        return aggregators;
    }

    /**
     * Merges per-component aggregators coming from two partitions.
     *
     * @param l Aggregators from the left partition (may be {@code null} or empty).
     * @param r Aggregators from the right partition (may be {@code null} or empty).
     * @return The non-empty side if the other is {@code null} or empty; otherwise a new list whose
     *      element {@code c} is {@code l[c].plus(r[c])}.
     */
    static List<MeanWithClusterProbAggregator> reduce(List<MeanWithClusterProbAggregator> l,
        List<MeanWithClusterProbAggregator> r) {
        A.ensure(l != null || r != null, "Both partitions cannot equal to null");

        if (l == null || l.isEmpty()) {
            return r;
        }
        if (r == null || r.isEmpty()) {
            return l;
        }

        A.ensure(l.size() == r.size(), "l.size() == r.size()");

        List<MeanWithClusterProbAggregator> merged = new ArrayList<>(l.size());
        for (int c = 0; c < l.size(); c++) {
            merged.add(l.get(c).plus(r.get(c)));
        }

        return merged;
    }
}
