package org.apache.ignite.ml.composition.boosting.convergence;

import java.io.Serializable;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/* loaded from: input_file:org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.class */
/**
 * Base class for convergence checks used during gradient boosting.
 *
 * <p>A checker computes a mean error of the current model composition over a dataset
 * (via {@link #computeMeanErrorOnDataset}) and reports convergence when that error falls
 * below the configured precision.
 *
 * @param <K> Type of an upstream key.
 * @param <V> Type of an upstream value.
 */
public abstract class ConvergenceChecker<K, V> implements Serializable {
    private static final long serialVersionUID = 710762134746674105L;

    /** Sample size passed through to {@link Loss#gradient}. */
    private final long sampleSize;

    /** Maps an external label value to the internal label representation used by the loss. */
    private final IgniteFunction<Double, Double> externalLbToInternalMapping;

    /** Loss function whose gradient is used by {@link #computeError}. */
    private final Loss loss;

    /** Extracts a feature vector from an upstream (key, value) pair. */
    private final IgniteBiFunction<K, V, Vector> featureExtractor;

    /** Extracts a label from an upstream (key, value) pair. */
    private final IgniteBiFunction<K, V, Double> lbExtractor;

    /** Convergence threshold; expected to be in {@code [0, 1)}. */
    private final double precision;

    /**
     * Creates a new convergence checker.
     *
     * @param sampleSize Sample size passed to {@link Loss#gradient}.
     * @param externalLbToInternalMapping Mapping from external label values to internal ones.
     * @param loss Loss function.
     * @param datasetBuilder Dataset builder; not used by this constructor (kept for signature
     *     compatibility with existing callers and subclasses).
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param precision Convergence threshold; must satisfy {@code 0 <= precision < 1}
     *     (checked only when assertions are enabled).
     */
    public ConvergenceChecker(long sampleSize,
        IgniteFunction<Double, Double> externalLbToInternalMapping,
        Loss loss,
        DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor,
        double precision) {
        assert precision < 1.0 && precision >= 0.0 : "precision must be in [0, 1): " + precision;

        this.sampleSize = sampleSize;
        this.externalLbToInternalMapping = externalLbToInternalMapping;
        this.loss = loss;
        this.featureExtractor = featureExtractor;
        this.lbExtractor = lbExtractor;
        this.precision = precision;
    }

    /**
     * Builds a dataset from the given builder and checks whether the model composition has converged on it.
     *
     * <p>The built dataset is always closed, even if the convergence check throws.
     *
     * @param datasetBuilder Builder used to materialize the dataset.
     * @param currMdl Current model composition.
     * @return {@code true} if the composition is considered converged.
     * @throws RuntimeException Wrapping any exception raised while building, checking or closing the dataset.
     */
    public boolean isConverged(DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
        try (Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
            new EmptyContextBuilder(),
            new FeatureMatrixWithLabelsOnHeapDataBuilder(featureExtractor, lbExtractor))) {
            return isConverged(dataset, currMdl);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Checks whether the model composition has converged on an already-built dataset.
     *
     * @param dataset Dataset to evaluate on.
     * @param currMdl Current model composition.
     * @return {@code true} if the mean error is strictly below the precision, or if it is NaN.
     *     NOTE(review): treating NaN as converged presumably stops training when the error cannot
     *     be computed — confirm intent.
     */
    public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition currMdl) {
        Double error = computeMeanErrorOnDataset(dataset, currMdl);
        return error.doubleValue() < precision || error.isNaN();
    }

    /**
     * Computes the mean error of the model composition over the dataset.
     *
     * @param dataset Dataset to evaluate on.
     * @param mdl Model composition.
     * @return Mean error value; compared against the precision by {@link #isConverged}.
     */
    public abstract Double computeMeanErrorOnDataset(
        Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
        ModelsComposition mdl);

    /**
     * Computes the per-example error as the negated loss gradient at the model's prediction.
     *
     * @param features Feature vector of the example.
     * @param answer External label of the example; mapped to the internal representation first.
     * @param mdl Model composition producing the prediction.
     * @return Negated value of {@link Loss#gradient} for this example.
     */
    public double computeError(Vector features, Double answer, ModelsComposition mdl) {
        double internalLabel = externalLbToInternalMapping.apply(answer).doubleValue();
        double prediction = mdl.apply(features).doubleValue();
        return -loss.gradient(sampleSize, internalLabel, prediction);
    }
}
