package org.apache.ignite.ml.recommendation;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.recommendation.util.MatrixFactorizationGradient;
import org.apache.ignite.ml.recommendation.util.RecommendationBinaryDatasetDataBuilder;
import org.apache.ignite.ml.recommendation.util.RecommendationDatasetData;
import org.apache.ignite.ml.recommendation.util.RecommendationDatasetDataBuilder;
import org.apache.ignite.ml.util.generators.DataStreamGenerator;

/* loaded from: input_file:org/apache/ignite/ml/recommendation/RecommendationTrainer.class */
/**
 * Trainer of a matrix-factorization recommendation model.
 * <p>
 * Every object and every subject found in the dataset is assigned a latent vector of length {@code k} filled with
 * random values. The vectors are then refined iteratively:
 * <ul>
 *     <li>each dataset partition computes a {@link MatrixFactorizationGradient};</li>
 *     <li>the partition gradients are summed;</li>
 *     <li>the sum is applied to the object/subject vectors.</li>
 * </ul>
 * The resulting vectors form the {@link RecommendationModel}.
 */
public class RecommendationTrainer {
    /** Environment builder passed to {@link DatasetBuilder} when the training dataset is built. */
    private LearningEnvironmentBuilder environmentBuilder = LearningEnvironmentBuilder.defaultBuilder();

    /**
     * Trainer-side learning environment.
     * <p>
     * Its random numbers generator initializes the latent vectors. It is also passed to {@link DatasetBuilder}.
     */
    private LearningEnvironment trainerEnvironment = this.environmentBuilder.buildForTrainer();

    /**
     * Batch size forwarded to the per-partition gradient computation.
     * <p>
     * NOTE(review): the decompiler substituted a coincidentally-equal compile-time constant from an unrelated
     * generator utility for what was presumably a literal default. Confirm the value and replace it with a literal.
     */
    private int batchSize = DataStreamGenerator.FILL_CACHE_BATCH_SIZE;

    /** Regularization parameter forwarded to the per-partition gradient computation. */
    private double regParam = 0.0d;

    /** Learning rate forwarded to the per-partition gradient computation. */
    private double learningRate = 10.0d;

    /**
     * Maximum number of training iterations; {@code -1} means "no limit".
     * <p>
     * NOTE(review): same decompiler-substituted constant as {@link #batchSize}. Confirm the value.
     */
    private int maxIterations = DataStreamGenerator.FILL_CACHE_BATCH_SIZE;

    /**
     * Training stops once the mean absolute gradient value drops below this threshold.
     * <p>
     * {@code 0} disables the check.
     */
    private double minMdlImprovement = 0.0d;

    /** Length of the latent vector generated for every object and every subject. */
    private int k = 10;

    /**
     * Trains a model on a dataset of binary objects.
     *
     * @param datasetBuilder Dataset builder over binary objects.
     * @param objFieldName First field name forwarded to {@link RecommendationBinaryDatasetDataBuilder} (presumably
     * the object identifier field; TODO confirm against that builder).
     * @param subjFieldName Second field name forwarded to {@link RecommendationBinaryDatasetDataBuilder} (presumably
     * the subject identifier field).
     * @param ratingFieldName Third field name forwarded to {@link RecommendationBinaryDatasetDataBuilder} (presumably
     * the rating field).
     * @return Trained model.
     * @throws RuntimeException Wrapping any exception raised while building, training on, or closing the dataset.
     */
    public RecommendationModel<Serializable, Serializable> fit(DatasetBuilder<Object, BinaryObject> datasetBuilder,
        String objFieldName, String subjFieldName, String ratingFieldName) {
        // try-with-resources guarantees the dataset is closed even if training fails.
        try (Dataset<EmptyContext, RecommendationDatasetData<Serializable, Serializable>> dataset = datasetBuilder.build(
            environmentBuilder,
            new EmptyContextBuilder<>(),
            new RecommendationBinaryDatasetDataBuilder(objFieldName, subjFieldName, ratingFieldName),
            trainerEnvironment
        )) {
            return train(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Trains a model on a dataset of {@link ObjectSubjectRatingTriplet}s.
     *
     * @param datasetBuilder Dataset builder whose values are object/subject/rating triplets.
     * @param <K> Type of a dataset key.
     * @param <O> Type of an object identifier.
     * @param <S> Type of a subject identifier.
     * @return Trained model.
     * @throws RuntimeException Wrapping any exception raised while building, training on, or closing the dataset.
     */
    public <K, O extends Serializable, S extends Serializable> RecommendationModel<O, S> fit(
        DatasetBuilder<K, ? extends ObjectSubjectRatingTriplet<O, S>> datasetBuilder) {
        // try-with-resources guarantees the dataset is closed even if training fails.
        try (Dataset<EmptyContext, RecommendationDatasetData<O, S>> dataset = datasetBuilder.build(
            environmentBuilder,
            new EmptyContextBuilder<>(),
            new RecommendationDatasetDataBuilder<>(),
            trainerEnvironment
        )) {
            return train(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the training loop on an already built dataset.
     *
     * @param dataset Dataset to train on; it is not closed by this method.
     * @param <O> Type of an object identifier.
     * @param <S> Type of a subject identifier.
     * @return Trained model built from the final object and subject vectors.
     */
    private <O extends Serializable, S extends Serializable> RecommendationModel<O, S> train(
        Dataset<EmptyContext, RecommendationDatasetData<O, S>> dataset) {
        // Collect all object and subject identifiers across partitions.
        Set<O> objects = dataset.compute(RecommendationDatasetData::getObjects, RecommendationTrainer::join);
        Set<S> subjects = dataset.compute(RecommendationDatasetData::getSubjects, RecommendationTrainer::join);

        // Initial model: a random vector for every object and every subject.
        Map<O, Vector> objMatrix = generateRandomVectorForEach(objects, trainerEnvironment.randomNumbersGenerator());
        Map<S, Vector> subjMatrix = generateRandomVectorForEach(subjects, trainerEnvironment.randomNumbersGenerator());

        // Snapshot hyper-parameters into locals so the serializable gradient lambda below captures plain values
        // rather than the whole (non-Serializable) trainer instance.
        final int batch = batchSize;
        final double reg = regParam;
        final double lr = learningRate;

        for (int iter = 0; maxIterations == -1 || iter < maxIterations; iter++) {
            // Per-iteration seed, mixed with the partition index inside the lambda.
            final int seed = iter;

            MatrixFactorizationGradient<O, S> grad = dataset.compute(
                (data, env) -> data.calculateGradient(objMatrix, subjMatrix, batch, seed ^ env.partition(), reg, lr),
                RecommendationTrainer::sum
            );

            if (minMdlImprovement != 0.0d && calculateImprovement(grad) < minMdlImprovement) {
                break;
            }

            grad.applyGradient(objMatrix, subjMatrix);
        }

        return new RecommendationModel<>(objMatrix, subjMatrix);
    }

    /**
     * Computes the mean, over all gradient vectors, of the sum of absolute component values.
     *
     * @param grad Aggregated gradient.
     * @param <O> Type of an object identifier.
     * @param <S> Type of a subject identifier.
     * @return Mean absolute gradient per vector, or {@code 0.0} when the gradient contains no vectors.
     */
    private <O extends Serializable, S extends Serializable> double calculateImprovement(
        MatrixFactorizationGradient<O, S> grad) {
        int vectors = grad.getSubjGrad().size() + grad.getObjGrad().size();

        // Without this guard the result would be 0/0 = NaN. NaN never compares below the threshold, so the
        // early-stop check would silently never fire.
        if (vectors == 0) {
            return 0.0d;
        }

        double total = sumOfAbsoluteValues(grad.getObjGrad().values())
            + sumOfAbsoluteValues(grad.getSubjGrad().values());

        return total / vectors;
    }

    /**
     * Sums the absolute values of all components of all given vectors.
     *
     * @param vectors Vectors to sum over.
     * @return Sum of absolute component values.
     */
    private static double sumOfAbsoluteValues(Collection<? extends Vector> vectors) {
        double total = 0.0d;
        for (Vector vector : vectors) {
            for (int i = 0; i < vector.size(); i++) {
                total += Math.abs(vector.get(i));
            }
        }
        return total;
    }

    /**
     * Assigns a fresh random vector of length {@link #k} to every element of the collection.
     *
     * @param keys Elements to generate vectors for; iteration order determines which random values each one gets.
     * @param rnd Source of random values.
     * @param <T> Type of an element.
     * @return Mutable map from element to its random vector.
     */
    private <T> Map<T, Vector> generateRandomVectorForEach(Collection<T> keys, Random rnd) {
        Map<T, Vector> res = new HashMap<>();
        for (T key : keys) {
            res.put(key, randomVector(k, rnd));
        }
        return res;
    }

    /**
     * Reducer that unions two sets; {@code null} is treated as an empty set.
     * <p>
     * Mutates and returns {@code a} when it is non-null.
     *
     * @param a First set (may be {@code null}).
     * @param b Second set (may be {@code null}).
     * @param <T> Type of an element.
     * @return Union of both sets.
     */
    private static <T> Set<T> join(Set<T> a, Set<T> b) {
        if (a == null) {
            return b;
        }
        if (b != null) {
            a.addAll(b);
        }
        return a;
    }

    /**
     * Sets the learning environment builder.
     * <p>
     * Note that this does not rebuild the trainer environment. Call
     * {@link #withTrainerEnvironment(LearningEnvironment)} as well if it should be derived from the new builder.
     *
     * @param learningEnvironmentBuilder Learning environment builder.
     * @return This trainer.
     */
    public RecommendationTrainer withLearningEnvironmentBuilder(LearningEnvironmentBuilder learningEnvironmentBuilder) {
        this.environmentBuilder = learningEnvironmentBuilder;
        return this;
    }

    /**
     * Sets the trainer-side learning environment.
     *
     * @param learningEnvironment Learning environment.
     * @return This trainer.
     */
    public RecommendationTrainer withTrainerEnvironment(LearningEnvironment learningEnvironment) {
        this.trainerEnvironment = learningEnvironment;
        return this;
    }

    /**
     * Sets the batch size forwarded to the gradient computation.
     *
     * @param batchSize Batch size.
     * @return This trainer.
     */
    public RecommendationTrainer withBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * Sets the regularization parameter.
     *
     * @param regParam Regularization parameter.
     * @return This trainer.
     */
    public RecommendationTrainer withRegularizer(double regParam) {
        this.regParam = regParam;
        return this;
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate Learning rate.
     * @return This trainer.
     */
    public RecommendationTrainer withLearningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIterations Maximum number of iterations; {@code -1} means "no limit".
     * @return This trainer.
     */
    public RecommendationTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * Sets the minimal mean absolute gradient required to continue training.
     *
     * @param minMdlImprovement Threshold; {@code 0} disables early stopping.
     * @return This trainer.
     */
    public RecommendationTrainer withMinMdlImprovement(double minMdlImprovement) {
        this.minMdlImprovement = minMdlImprovement;
        return this;
    }

    /**
     * Sets the length of the latent vectors.
     *
     * @param k Latent vector length.
     * @return This trainer.
     */
    public RecommendationTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Reducer that adds two gradients; {@code null} is treated as an empty gradient.
     *
     * @param a First gradient (may be {@code null}).
     * @param b Second gradient (may be {@code null}).
     * @param <O> Type of an object identifier.
     * @param <S> Type of a subject identifier.
     * @return Gradient whose vectors and row count are the sums of both inputs.
     */
    private static <O extends Serializable, S extends Serializable> MatrixFactorizationGradient<O, S> sum(
        MatrixFactorizationGradient<O, S> a, MatrixFactorizationGradient<O, S> b) {
        return new MatrixFactorizationGradient<>(
            sum(a == null ? null : a.getObjGrad(), b == null ? null : b.getObjGrad()),
            sum(a == null ? null : a.getSubjGrad(), b == null ? null : b.getSubjGrad()),
            (a == null ? 0 : a.getRows()) + (b == null ? 0 : b.getRows())
        );
    }

    /**
     * Adds two key-to-vector maps element-wise; keys present in only one map keep their vector as is.
     *
     * @param a First map (may be {@code null}, in which case {@code b} is returned as is).
     * @param b Second map (may be {@code null}, in which case {@code a} is returned as is).
     * @param <T> Type of a key.
     * @return Unmodifiable merged map when both inputs are non-null; otherwise the non-null input itself.
     */
    private static <T> Map<T, Vector> sum(Map<T, Vector> a, Map<T, Vector> b) {
        if (a == null) {
            return b;
        }
        if (b == null) {
            return a;
        }

        // Start from a copy of the first map, then fold in the second.
        Map<T, Vector> res = new HashMap<>(a);
        for (Map.Entry<T, Vector> entry : b.entrySet()) {
            Vector acc = res.get(entry.getKey());
            res.put(entry.getKey(), acc == null ? entry.getValue() : entry.getValue().plus(acc));
        }
        return Collections.unmodifiableMap(res);
    }

    /**
     * Creates a vector of the given length filled with {@link Random#nextDouble()} values.
     *
     * @param size Vector length.
     * @param rnd Source of random values.
     * @return Random vector.
     */
    private static Vector randomVector(int size, Random rnd) {
        double[] values = new double[size];
        for (int i = 0; i < values.length; i++) {
            values[i] = rnd.nextDouble();
        }
        return VectorUtils.of(values);
    }

    // NOTE: the decompiled synthetic `$deserializeLambda$` method was intentionally dropped. javac generates it
    // automatically for every class containing serializable lambdas or method references, and the decompiled form
    // was not valid Java.
}
