package org.apache.ignite.ml.nn;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CachePeekMode;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.nn.initializers.MLPInitializer;
import org.apache.ignite.ml.nn.trainers.distributed.AbstractMLPGroupUpdateTrainerInput;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.util.Utils;

/**
 * Group-update trainer input that samples mini-batches of labeled vectors from an Ignite cache.
 * <p>
 * Each invocation of the supplier returned by {@link #batchSupplier()} runs on the node where it is
 * executed. It selects up to {@code batchSize} distinct entries among the cache keys that affinity
 * maps to that local node, and packs them column-wise into a pair of matrices
 * {@code (inputs, groundTruth)}.
 */
public class MLPGroupUpdateTrainerCacheInput extends AbstractMLPGroupUpdateTrainerInput {
    /** Cache with labeled samples; only its name is shipped to the batch supplier. */
    private final IgniteCache<Integer, LabeledVector<Vector, Vector>> cache;

    /** Maximum number of samples (matrix columns) per batch; always positive. */
    private final int batchSize;

    /** Model returned by {@link #mdl()}. */
    private final MultilayerPerceptron mlp;

    /** Random source for sample selection; may be {@code null}. */
    private final Random rand;

    /**
     * Creates an input backed by the given cache.
     *
     * @param arch Architecture of the multilayer perceptron.
     * @param init Initializer passed to the {@link MultilayerPerceptron} constructor; may be {@code null}
     *     (the 4-argument constructor passes {@code null} — presumably a default initializer is then used).
     * @param networksCnt Value forwarded to the superclass constructor (presumably the number of networks
     *     trained in the group — TODO confirm against {@link AbstractMLPGroupUpdateTrainerInput}).
     * @param cache Cache with labeled samples. NOTE(review): the supplier assumes its keys are exactly
     *     {@code 0 .. cache.size() - 1}.
     * @param batchSize Maximum number of samples per batch.
     * @param rand Random source handed to {@link Utils#selectKDistinct}; may be {@code null}.
     * @throws IllegalArgumentException If {@code batchSize} is not positive.
     */
    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init, int networksCnt,
        IgniteCache<Integer, LabeledVector<Vector, Vector>> cache, int batchSize, Random rand) {
        super(networksCnt);

        // Fail fast: a non-positive batch size would otherwise only surface later as an
        // ArrayIndexOutOfBoundsException inside the remote batch supplier.
        if (batchSize <= 0)
            throw new IllegalArgumentException("Batch size must be positive [batchSize=" + batchSize + ']');

        this.batchSize = batchSize;
        this.cache = cache;
        this.mlp = new MultilayerPerceptron(arch, init);
        this.rand = rand;
    }

    /**
     * Creates an input backed by the given cache, with a {@code null} random source.
     *
     * @param arch Architecture of the multilayer perceptron.
     * @param init Initializer of the multilayer perceptron; may be {@code null}.
     * @param networksCnt Value forwarded to the superclass constructor.
     * @param cache Cache with labeled samples.
     * @param batchSize Maximum number of samples per batch.
     */
    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, MLPInitializer init, int networksCnt,
        IgniteCache<Integer, LabeledVector<Vector, Vector>> cache, int batchSize) {
        this(arch, init, networksCnt, cache, batchSize, null);
    }

    /**
     * Creates an input backed by the given cache, with a {@code null} initializer and random source.
     *
     * @param arch Architecture of the multilayer perceptron.
     * @param networksCnt Value forwarded to the superclass constructor.
     * @param cache Cache with labeled samples.
     * @param batchSize Maximum number of samples per batch.
     */
    public MLPGroupUpdateTrainerCacheInput(MLPArchitecture arch, int networksCnt,
        IgniteCache<Integer, LabeledVector<Vector, Vector>> cache, int batchSize) {
        this(arch, null, networksCnt, cache, batchSize);
    }

    /**
     * Returns a supplier of {@code (inputs, groundTruth)} batches drawn from the node-local part of the cache.
     * <p>
     * Matrix column {@code j} holds the features (respectively label) of the {@code j}-th sampled entry.
     * The number of columns is {@code min(batchSize, number of locally mapped keys)}.
     *
     * @return Batch supplier.
     */
    @Override
    public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
        // Copy fields into locals so the serializable lambda captures only these three small values
        // and not 'this' (which would drag the cache proxy and the model along).
        String cName = cache.getName();
        int bs = batchSize;
        Random r = rand;

        return () -> sampleLocalBatch(cName, bs, r);
    }

    /**
     * Samples one batch from the keys of cache {@code cName} that are mapped to the local node.
     *
     * @param cName Cache name.
     * @param bs Maximum batch size.
     * @param r Random source for {@link Utils#selectKDistinct}; may be {@code null}.
     * @return Pair {@code (inputs, groundTruth)} with one column per sampled entry.
     * @throws IllegalStateException If no keys are mapped to the local node or a mapped key is absent.
     */
    private static IgniteBiTuple<Matrix, Matrix> sampleLocalBatch(String cName, int bs, Random r) {
        Ignite ignite = Ignition.localIgnite();
        IgniteCache<Integer, LabeledVector<Vector, Vector>> dataCache = ignite.getOrCreateCache(cName);

        // NOTE(review): assumes the cache keys are exactly 0 .. size() - 1.
        List<Integer> allKeys = IntStream.range(0, dataCache.size()).boxed().collect(Collectors.toList());

        Collection<Integer> mapped = ignite.<Integer>affinity(cName)
            .mapKeysToNodes(allKeys)
            .get(ignite.cluster().localNode());

        // mapKeysToNodes omits nodes that received no keys, so 'mapped' may be null.
        if (mapped == null || mapped.isEmpty())
            throw new IllegalStateException("No cache keys are mapped to the local node [cache=" + cName + ']');

        List<Integer> locKeys = new ArrayList<>(mapped);
        int locKeysCnt = locKeys.size();

        int[] selected = Utils.selectKDistinct(locKeysCnt, Math.min(bs, locKeysCnt), r);

        // Size the batch by the number of samples actually selected, so a node with fewer than 'bs'
        // local entries does not produce all-zero padding columns.
        int cols = selected.length;

        // Dimensions are taken from the first sample; every sample is assumed to share them.
        LabeledVector<Vector, Vector> first = entry(dataCache, cName, locKeys.get(selected[0]));

        Matrix inputs = new DenseLocalOnHeapMatrix(first.features().size(), cols);
        Matrix groundTruth = new DenseLocalOnHeapMatrix(first.label().size(), cols);

        for (int col = 0; col < cols; col++) {
            // Reuse the already fetched first entry instead of reading it from the cache twice.
            LabeledVector<Vector, Vector> sample = col == 0 ? first : entry(dataCache, cName, locKeys.get(selected[col]));

            inputs.assignColumn(col, sample.features());
            groundTruth.assignColumn(col, sample.label());
        }

        return new IgniteBiTuple<>(inputs, groundTruth);
    }

    /**
     * Reads a cache entry, failing with a descriptive message instead of a later NPE when it is absent.
     *
     * @param dataCache Cache to read from.
     * @param cName Cache name (for the error message).
     * @param key Key to read.
     * @return Non-null entry.
     * @throws IllegalStateException If there is no entry for {@code key}.
     */
    private static LabeledVector<Vector, Vector> entry(IgniteCache<Integer, LabeledVector<Vector, Vector>> dataCache,
        String cName, Integer key) {
        LabeledVector<Vector, Vector> res = dataCache.get(key);

        if (res == null)
            throw new IllegalStateException("Missing cache entry [cache=" + cName + ", key=" + key + ']');

        return res;
    }

    /**
     * Returns the multilayer perceptron created in the constructor.
     *
     * @return Model.
     */
    @Override
    public MultilayerPerceptron mdl() {
        return mlp;
    }
}
