/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.recommendation;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.ignite.binary.BinaryObject;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.recommendation.ObjectSubjectRatingTriplet;
import org.apache.ignite.ml.recommendation.RecommendationModel;
import org.apache.ignite.ml.recommendation.util.MatrixFactorizationGradient;
import org.apache.ignite.ml.recommendation.util.RecommendationBinaryDatasetDataBuilder;
import org.apache.ignite.ml.recommendation.util.RecommendationDatasetData;
import org.apache.ignite.ml.recommendation.util.RecommendationDatasetDataBuilder;

/**
 * Trainer for a matrix-factorization {@link RecommendationModel}.
 * <p>
 * Every distinct object and subject found in the dataset is assigned a random latent vector of size {@code k}.
 * The vectors are then refined iteratively: each partition computes a partial gradient, the partial gradients are
 * summed across partitions, and the summed gradient is applied to the object/subject factor matrices.
 */
public class RecommendationTrainer {
    /** Builder used to create the learning environments handed to the dataset. */
    private LearningEnvironmentBuilder environmentBuilder = LearningEnvironmentBuilder.defaultBuilder();

    /** Trainer-side learning environment; its random generator initializes the factor vectors. */
    private LearningEnvironment trainerEnvironment = this.environmentBuilder.buildForTrainer();

    /** Batch size passed to each partition's gradient computation. */
    private int batchSize = 1000;

    /** Regularization parameter passed to the gradient computation. */
    private double regParam = 0.0;

    /** Learning rate passed to the gradient computation. */
    private double learningRate = 10.0;

    /** Maximum number of gradient steps; {@code -1} means "no limit". */
    private int maxIterations = 1000;

    /**
     * Early-stop threshold: training stops once the improvement metric (sum of absolute gradient components divided
     * by the number of gradient vectors) falls below this value. {@code 0} disables the check.
     */
    private double minMdlImprovement = 0.0;

    /** Dimension of the latent vectors. */
    private int k = 10;

    /**
     * Trains a model on a dataset of binary objects, reading object, subject and rating from the named fields.
     *
     * @param datasetBuilder Builder of the upstream binary-object dataset.
     * @param objFieldName Name of the field holding the object identifier.
     * @param subjFieldName Name of the field holding the subject identifier.
     * @param ratingFieldName Name of the field holding the rating.
     * @return Trained recommendation model.
     * @throws RuntimeException Wrapping any exception thrown while building, training on, or closing the dataset.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // Builders are instantiated raw, as in the original code.
    public RecommendationModel<Serializable, Serializable> fit(DatasetBuilder<Object, BinaryObject> datasetBuilder,
        String objFieldName, String subjFieldName, String ratingFieldName) {
        try (Dataset<EmptyContext, RecommendationDatasetData<Serializable, Serializable>> dataset = datasetBuilder.build(
            this.environmentBuilder,
            new EmptyContextBuilder(),
            new RecommendationBinaryDatasetDataBuilder(objFieldName, subjFieldName, ratingFieldName),
            this.trainerEnvironment)) {
            return train(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Trains a model on a dataset whose values are object/subject/rating triplets.
     *
     * @param datasetBuilder Builder of the upstream triplet dataset.
     * @param <K> Type of the upstream keys.
     * @param <O> Type of object identifiers.
     * @param <S> Type of subject identifiers.
     * @return Trained recommendation model.
     * @throws RuntimeException Wrapping any exception thrown while building, training on, or closing the dataset.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // Builders are instantiated raw, as in the original code.
    public <K, O extends Serializable, S extends Serializable> RecommendationModel<O, S> fit(
        DatasetBuilder<K, ? extends ObjectSubjectRatingTriplet<O, S>> datasetBuilder) {
        try (Dataset<EmptyContext, RecommendationDatasetData<O, S>> dataset = datasetBuilder.build(
            this.environmentBuilder,
            new EmptyContextBuilder(),
            new RecommendationDatasetDataBuilder(),
            this.trainerEnvironment)) {
            return train(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the gradient-descent loop on an already built dataset.
     *
     * @param dataset Dataset of rating triplets.
     * @return Model holding the learned object and subject factor matrices.
     */
    private <O extends Serializable, S extends Serializable> RecommendationModel<O, S> train(
        Dataset<EmptyContext, RecommendationDatasetData<O, S>> dataset) {
        // Collect every object/subject id across partitions. A null result (no partitions) is treated as empty.
        Set<O> objects = dataset.compute(RecommendationDatasetData::getObjects, RecommendationTrainer::join);
        Set<S> subjects = dataset.compute(RecommendationDatasetData::getSubjects, RecommendationTrainer::join);

        Map<O, Vector> objMatrix = generateRandomVectorForEach(
            objects == null ? Collections.<O>emptySet() : objects, this.trainerEnvironment.randomNumbersGenerator());
        Map<S, Vector> subjMatrix = generateRandomVectorForEach(
            subjects == null ? Collections.<S>emptySet() : subjects, this.trainerEnvironment.randomNumbersGenerator());

        // Copy hyper-parameters into locals so the per-partition closure does not capture `this`
        // (the trainer itself is not Serializable, but the closure may be shipped to other nodes).
        final int batch = this.batchSize;
        final double reg = this.regParam;
        final double lr = this.learningRate;

        for (int i = 0; this.maxIterations == -1 || i < this.maxIterations; ++i) {
            final int seed = i;

            MatrixFactorizationGradient<O, S> grad = dataset.compute(
                (data, env) -> data.calculateGradient(objMatrix, subjMatrix, batch, seed ^ env.partition(), reg, lr),
                RecommendationTrainer::sum
            );

            // No partition produced a gradient, so there is nothing to apply.
            if (grad == null)
                break;

            if (this.minMdlImprovement != 0.0 && calculateImprovement(grad) < this.minMdlImprovement)
                break;

            grad.applyGradient(objMatrix, subjMatrix);
        }

        return new RecommendationModel<>(objMatrix, subjMatrix);
    }

    /**
     * Computes the improvement metric: the sum of absolute values of every gradient component, divided by the number
     * of gradient vectors (object plus subject).
     *
     * @param grad Aggregated gradient.
     * @return Improvement metric; {@code 0.0} when the gradient contains no vectors (instead of NaN, which would never
     * satisfy the early-stop comparison).
     */
    private <O extends Serializable, S extends Serializable> double calculateImprovement(
        MatrixFactorizationGradient<O, S> grad) {
        int vectorCnt = grad.getSubjGrad().size() + grad.getObjGrad().size();
        if (vectorCnt == 0)
            return 0.0;

        double total = 0.0;
        for (Vector vector : grad.getObjGrad().values()) {
            for (int i = 0; i < vector.size(); ++i)
                total += Math.abs(vector.get(i));
        }
        for (Vector vector : grad.getSubjGrad().values()) {
            for (int i = 0; i < vector.size(); ++i)
                total += Math.abs(vector.get(i));
        }

        return total / vectorCnt;
    }

    /**
     * Assigns a fresh random vector of size {@code k} to each element.
     *
     * @param objects Elements to assign vectors to.
     * @param rnd Random generator whose {@code nextDouble()} values fill the vectors.
     * @return Mutable map from element to its random vector.
     */
    private <T> Map<T, Vector> generateRandomVectorForEach(Collection<T> objects, Random rnd) {
        Map<T, Vector> res = new HashMap<>();
        for (T obj : objects)
            res.put(obj, randomVector(this.k, rnd));
        return res;
    }

    /**
     * Reducer returning the union of two sets. A {@code null} argument acts as the identity. Neither argument is
     * modified, so sets owned by partition data are never altered.
     *
     * @param a First set (may be {@code null}).
     * @param b Second set (may be {@code null}).
     * @return Union of both sets, or the non-null one when the other is {@code null}.
     */
    private static <T> Set<T> join(Set<T> a, Set<T> b) {
        if (a == null)
            return b;
        if (b == null)
            return a;

        Set<T> res = new HashSet<>(a);
        res.addAll(b);
        return res;
    }

    /**
     * Sets the learning environment builder. The trainer environment is also rebuilt from this builder, so its
     * settings take effect. A later call to {@link #withTrainerEnvironment} still overrides the environment.
     *
     * @param environmentBuilder Learning environment builder.
     * @return This trainer.
     */
    public RecommendationTrainer withLearningEnvironmentBuilder(LearningEnvironmentBuilder environmentBuilder) {
        this.environmentBuilder = environmentBuilder;
        this.trainerEnvironment = environmentBuilder.buildForTrainer();
        return this;
    }

    /**
     * Sets the trainer-side learning environment explicitly.
     *
     * @param trainerEnvironment Learning environment.
     * @return This trainer.
     */
    public RecommendationTrainer withTrainerEnvironment(LearningEnvironment trainerEnvironment) {
        this.trainerEnvironment = trainerEnvironment;
        return this;
    }

    /**
     * Sets the batch size passed to the gradient computation.
     *
     * @param batchSize Batch size.
     * @return This trainer.
     */
    public RecommendationTrainer withBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * Sets the regularization parameter.
     *
     * @param regParam Regularization parameter.
     * @return This trainer.
     */
    public RecommendationTrainer withRegularizer(double regParam) {
        this.regParam = regParam;
        return this;
    }

    /**
     * Sets the learning rate.
     *
     * @param learningRate Learning rate.
     * @return This trainer.
     */
    public RecommendationTrainer withLearningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIterations Maximum number of iterations; {@code -1} means "no limit".
     * @return This trainer.
     */
    public RecommendationTrainer withMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /**
     * Sets the early-stop threshold on the improvement metric.
     *
     * @param minMdlImprovement Threshold; {@code 0} disables early stopping.
     * @return This trainer.
     */
    public RecommendationTrainer withMinMdlImprovement(double minMdlImprovement) {
        this.minMdlImprovement = minMdlImprovement;
        return this;
    }

    /**
     * Sets the latent vector dimension.
     *
     * @param k Vector dimension.
     * @return This trainer.
     */
    public RecommendationTrainer withK(int k) {
        this.k = k;
        return this;
    }

    /**
     * Reducer summing two partial gradients. A {@code null} argument contributes nothing.
     *
     * @param a First gradient (may be {@code null}).
     * @param b Second gradient (may be {@code null}).
     * @return Gradient with summed object/subject components and the total row count.
     */
    private static <O extends Serializable, S extends Serializable> MatrixFactorizationGradient<O, S> sum(
        MatrixFactorizationGradient<O, S> a, MatrixFactorizationGradient<O, S> b) {
        Map<O, Vector> objGrad = sum(a == null ? null : a.getObjGrad(), b == null ? null : b.getObjGrad());
        Map<S, Vector> subjGrad = sum(a == null ? null : a.getSubjGrad(), b == null ? null : b.getSubjGrad());
        int rows = (a == null ? 0 : a.getRows()) + (b == null ? 0 : b.getRows());

        return new MatrixFactorizationGradient<>(objGrad, subjGrad, rows);
    }

    /**
     * Sums two sparse vector maps key by key. A {@code null} argument acts as the identity.
     *
     * @param a First map (may be {@code null}).
     * @param b Second map (may be {@code null}).
     * @return The non-null map when the other is {@code null}; otherwise an unmodifiable map of key-wise sums.
     */
    private static <T> Map<T, Vector> sum(Map<T, Vector> a, Map<T, Vector> b) {
        if (a == null)
            return b;
        if (b == null)
            return a;

        Map<T, Vector> res = new HashMap<>(a);
        // Same operand order as before: b's vector plus a's vector.
        b.forEach((key, vec) -> res.merge(key, vec, (prev, cur) -> cur.plus(prev)));

        return Collections.unmodifiableMap(res);
    }

    /**
     * Creates a vector of length {@code k} whose components are drawn from {@code rnd.nextDouble()}.
     *
     * @param k Vector length.
     * @param rnd Random generator.
     * @return Random vector.
     */
    private static Vector randomVector(int k, Random rnd) {
        double[] vector = new double[k];
        for (int i = 0; i < vector.length; ++i)
            vector[i] = rnd.nextDouble();
        return VectorUtils.of(vector);
    }
}

