package org.apache.ignite.ml.regressions.linear;

import java.util.Objects;
import org.apache.ignite.ml.Trainer;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.optimization.BarzilaiBorweinUpdater;
import org.apache.ignite.ml.optimization.GradientDescent;
import org.apache.ignite.ml.optimization.LeastSquaresGradientFunction;
import org.apache.ignite.ml.optimization.SimpleUpdater;

/* loaded from: input_file:org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.class */
/**
 * Trains a {@link LinearRegressionModel} by running a {@link GradientDescent} optimizer over a data matrix.
 *
 * <p>The optimizer is either supplied directly, or built internally with a {@link LeastSquaresGradientFunction}
 * loss gradient plus a {@link BarzilaiBorweinUpdater} (two-argument constructor) or a {@link SimpleUpdater}
 * (three-argument constructor).
 */
public class LinearRegressionSGDTrainer implements Trainer<LinearRegressionModel, Matrix> {
    /** Optimizer run by {@link #train(Matrix)}; never {@code null}. */
    private final GradientDescent gradientDescent;

    /**
     * Creates a trainer that delegates to a caller-configured optimizer.
     *
     * @param gradientDescent optimizer to run during training; must not be {@code null}
     * @throws NullPointerException if {@code gradientDescent} is {@code null}
     */
    public LinearRegressionSGDTrainer(GradientDescent gradientDescent) {
        // Fail fast at construction instead of deferring the NPE to the first train() call.
        this.gradientDescent = Objects.requireNonNull(gradientDescent, "gradientDescent");
    }

    /**
     * Creates a trainer whose optimizer combines a least-squares gradient with a Barzilai-Borwein updater.
     *
     * @param maxIterations value passed to {@link GradientDescent#withMaxIterations}
     * @param convergenceTol value passed to {@link GradientDescent#withConvergenceTol}
     */
    public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol) {
        this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new BarzilaiBorweinUpdater())
            .withMaxIterations(maxIterations)
            .withConvergenceTol(convergenceTol);
    }

    /**
     * Creates a trainer whose optimizer combines a least-squares gradient with a {@link SimpleUpdater}.
     *
     * @param maxIterations value passed to {@link GradientDescent#withMaxIterations}
     * @param convergenceTol value passed to {@link GradientDescent#withConvergenceTol}
     * @param learningRate value passed to the {@link SimpleUpdater} constructor
     *     (presumably a fixed step size — confirm against SimpleUpdater)
     */
    public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol, double learningRate) {
        this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new SimpleUpdater(learningRate))
            .withMaxIterations(maxIterations)
            .withConvergenceTol(convergenceTol);
    }

    /**
     * Fits a linear regression model to {@code matrix}.
     *
     * <p>The optimizer starts from {@code matrix.likeVector(matrix.columnSize())}, a vector with one entry per
     * matrix column. Element 0 of the optimized vector becomes the model's scalar term (presumably the
     * intercept). Elements 1..size-1 become the model's coefficient vector.
     *
     * @param matrix training data, passed unchanged to the optimizer; its column layout is whatever the
     *     configured gradient function expects (not verified here — TODO confirm, e.g. target in column 0)
     * @return model built from the optimized parameter vector
     */
    @Override
    public LinearRegressionModel train(Matrix matrix) {
        // NOTE(review): the start vector is assumed zero-initialized by likeVector — confirm.
        Vector start = matrix.likeVector(matrix.columnSize());
        Vector solution = gradientDescent.optimize(matrix, start);

        // Split the solution: everything after index 0 forms the weights, index 0 is the scalar term.
        Vector weights = solution.viewPart(1, solution.size() - 1);
        double intercept = solution.get(0);

        return new LinearRegressionModel(weights, intercept);
    }
}
