/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.recommendation.util;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.recommendation.ObjectSubjectRatingTriplet;
import org.apache.ignite.ml.recommendation.util.MatrixFactorizationGradient;
import org.apache.ignite.ml.util.Utils;

/**
 * Partition-local data of a recommendation dataset: a read-only view over the
 * object-subject-rating triplets held by one partition, plus helpers to compute
 * a mini-batch gradient of the matrix factorization loss over them.
 *
 * @param <O> Type of an object.
 * @param <S> Type of a subject.
 */
public class RecommendationDatasetData<O extends Serializable, S extends Serializable>
implements AutoCloseable {
    /** Ratings of this partition (unmodifiable view over the list passed to the constructor). */
    private final List<? extends ObjectSubjectRatingTriplet<O, S>> ratings;

    /**
     * Constructs a new instance of recommendation dataset data.
     *
     * @param ratings Ratings of the partition. The list is wrapped in an unmodifiable
     *     view, not copied, so later changes to it by the owner remain visible.
     */
    public RecommendationDatasetData(List<? extends ObjectSubjectRatingTriplet<O, S>> ratings) {
        this.ratings = Collections.unmodifiableList(ratings);
    }

    /**
     * Calculates the gradient of the regularized squared-error factorization loss over a
     * mini-batch of {@code min(batchSize, ratings.size())} rows, selected with a
     * {@link Random} seeded by {@code seed}.
     *
     * <p>For each selected rating {@code r} of object {@code o} and subject {@code s},
     * with {@code err = w_o . h_s - r}, the contributions are
     * {@code (err * h_s + regParam * w_o) * learningRate} for the object and
     * {@code (err * w_o + regParam * h_s) * learningRate} for the subject. Contributions of
     * an object (or subject) that appears in several selected rows are summed.
     *
     * @param objMatrix Factor vectors of objects; must contain every object of the batch.
     * @param subjMatrix Factor vectors of subjects; must contain every subject of the batch.
     * @param batchSize Requested batch size (capped by the number of ratings).
     * @param seed Seed of the random generator used to select the batch rows.
     * @param regParam Regularization parameter.
     * @param learningRate Learning rate the gradient is scaled by.
     * @return Gradient per object and per subject, together with the number of rows used.
     * @throws IllegalArgumentException If a factor vector is missing for an object or
     *     subject of the batch.
     */
    public MatrixFactorizationGradient<O, S> calculateGradient(Map<O, Vector> objMatrix, Map<S, Vector> subjMatrix,
        int batchSize, int seed, double regParam, double learningRate) {
        Map<O, Vector> objGrads = new HashMap<>();
        Map<S, Vector> subjGrads = new HashMap<>();

        int[] rows = getRows(batchSize, seed);

        for (int row : rows) {
            ObjectSubjectRatingTriplet<O, S> triplet = ratings.get(row);

            O obj = triplet.getObj();
            S subj = triplet.getSubj();

            Vector objVector = objMatrix.get(obj);
            Vector subjVector = subjMatrix.get(subj);

            // Previously a missing vector only surfaced as a bare NPE in the gradient
            // arithmetic below; fail fast with the offending key instead.
            if (objVector == null) {
                throw new IllegalArgumentException("Object matrix has no vector for object: " + obj);
            }
            if (subjVector == null) {
                throw new IllegalArgumentException("Subject matrix has no vector for subject: " + subj);
            }

            double error = calculateError(objVector, subjVector, triplet.getRating());

            Vector objGrad = subjVector.times(error).plus(objVector.times(regParam)).times(learningRate);
            Vector subjGrad = objVector.times(error).plus(subjVector.times(regParam)).times(learningRate);

            // The same object/subject can occur in several rows of one batch: accumulate the
            // per-rating contributions rather than keeping only the last one.
            objGrads.merge(obj, objGrad, (acc, grad) -> acc.plus(grad));
            subjGrads.merge(subj, subjGrad, (acc, grad) -> acc.plus(grad));
        }

        return new MatrixFactorizationGradient<>(objGrads, subjGrads, rows.length);
    }

    /**
     * Returns the distinct objects that occur in the ratings of this partition.
     *
     * @return New mutable set of objects.
     */
    public Set<O> getObjects() {
        Set<O> res = new HashSet<>();
        for (ObjectSubjectRatingTriplet<O, S> triplet : ratings) {
            res.add(triplet.getObj());
        }
        return res;
    }

    /**
     * Returns the distinct subjects that occur in the ratings of this partition.
     *
     * @return New mutable set of subjects.
     */
    public Set<S> getSubjects() {
        Set<S> res = new HashSet<>();
        for (ObjectSubjectRatingTriplet<O, S> triplet : ratings) {
            res.add(triplet.getSubj());
        }
        return res;
    }

    /**
     * Calculates the prediction error {@code objVector . subjVector - rating}.
     *
     * @param objVector Factor vector of the object (non-null).
     * @param subjVector Factor vector of the subject (non-null).
     * @param rating Actual rating.
     * @return Predicted rating minus actual rating.
     */
    private double calculateError(Vector objVector, Vector subjVector, double rating) {
        return objVector.dot(subjVector) - rating;
    }

    /**
     * Selects the row indices of a mini-batch of {@code min(batchSize, ratings.size())} rows
     * using a {@link Random} seeded with {@code seed}.
     *
     * @param batchSize Requested batch size.
     * @param seed Random seed.
     * @return Selected row indices (presumably distinct indices in {@code [0, ratings.size())},
     *     per {@code Utils.selectKDistinct} — NOTE(review): confirm against its contract).
     */
    private int[] getRows(int batchSize, int seed) {
        int size = ratings.size();
        return Utils.selectKDistinct(size, Math.min(batchSize, size), new Random(seed));
    }

    /** {@inheritDoc} No resources are held, so this is a no-op. */
    @Override
    public void close() {
        // No-op.
    }
}

