package org.apache.ignite.ml.regressions.linear;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap;
import org.apache.ignite.ml.math.isolve.lsqr.LSQRResult;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer that fits a {@link LinearRegressionModel} using the {@link LSQROnHeap} solver.
 *
 * <p>The intercept is learned as an ordinary weight: every feature vector is extended with a
 * trailing constant {@code 1.0} before it reaches the solver, and the last component of the
 * solution is split off again as the intercept of the resulting model.
 */
public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<LinearRegressionModel> {
    /**
     * Trains a new model from scratch by delegating to
     * {@link #updateModel(LinearRegressionModel, DatasetBuilder, Preprocessor)} without a previous model.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Preprocessor producing labeled vectors from upstream entries.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Trained model.
     */
    @Override
    public <K, V> LinearRegressionModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        return updateModel(null, datasetBuilder, preprocessor);
    }

    /**
     * Appends a constant {@code 1.0} feature to the given vector and wraps its label into a
     * single-element array.
     *
     * @param labeledVector Source vector with a scalar label.
     * @return New labeled vector with {@code size + 1} features (the last one is {@code 1.0}) and a
     *      {@code double[1]} label.
     */
    private static LabeledVector<double[]> extendLabeledVector(LabeledVector<Double> labeledVector) {
        int featuresCnt = labeledVector.features().size();

        double[] extended = new double[featuresCnt + 1];
        System.arraycopy(labeledVector.features().asArray(), 0, extended, 0, featuresCnt);
        // Bias column: its weight in the solution becomes the intercept.
        extended[featuresCnt] = 1.0;

        return VectorUtils.of(extended).labeled(new double[] {labeledVector.label()});
    }

    /**
     * Fits the model on the given dataset, optionally starting from the weights and intercept of a
     * previously trained model.
     *
     * <p>If the solver returns no result, the outcome of
     * {@code getLastTrainedModelOrThrowEmptyDatasetException(linearRegressionModel)} is returned instead.
     *
     * @param linearRegressionModel Previously trained model used as the initial approximation, or
     *      {@code null} to start without one.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Preprocessor producing labeled vectors from upstream entries.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Model built from the solver's solution.
     * @throws RuntimeException Wrapping any exception raised while building, solving or closing the solver.
     */
    @Override
    public <K, V> LinearRegressionModel updateModel(LinearRegressionModel linearRegressionModel, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        PatchedPreprocessor<K, V, Double, double[]> biasedPreprocessor =
            new PatchedPreprocessor<>(LinearRegressionLSQRTrainer::extendLabeledVector, preprocessor);

        try {
            LSQRResult result;

            // try-with-resources guarantees the solver is closed on every path, including failures.
            try (LSQROnHeap<K, V> solver = new LSQROnHeap<>(
                datasetBuilder,
                this.envBuilder,
                new SimpleLabeledDatasetDataBuilder<>(biasedPreprocessor),
                learningEnvironment())) {

                double[] initialGuess = null;

                if (linearRegressionModel != null) {
                    // Seed the solver with the previous weights followed by the previous intercept,
                    // mirroring the layout produced by extendLabeledVector.
                    Vector prevWeights = linearRegressionModel.getWeights();
                    Vector seed = prevWeights.like(prevWeights.size() + 1);

                    prevWeights.nonZeroes().forEach(element -> seed.set(element.index(), element.get()));
                    seed.set(seed.size() - 1, linearRegressionModel.getIntercept());

                    initialGuess = seed.asArray();
                }

                // NOTE(review): arguments are presumably damping, two tolerances, condition-number limit,
                // iteration limit, variance flag and initial approximation -- confirm against LSQROnHeap.solve.
                result = solver.solve(0.0, 1.0E-12, 1.0E-12, 1.0E8, -1.0, false, initialGuess);

                if (result == null) {
                    return getLastTrainedModelOrThrowEmptyDatasetException(linearRegressionModel);
                }
            }

            // The last component corresponds to the appended constant feature, i.e. the intercept.
            double[] x = result.getX();
            int weightsCnt = x.length - 1;

            return new LinearRegressionModel(new DenseVector(Arrays.copyOfRange(x, 0, weightsCnt)), x[weightsCnt]);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reports whether the given model can be used as a starting point for further training.
     *
     * @param linearRegressionModel Model to check.
     * @return Always {@code true}.
     */
    @Override
    public boolean isUpdateable(LinearRegressionModel linearRegressionModel) {
        return true;
    }
}
