package org.apache.ignite.ml.knn.regression;

import org.apache.ignite.ml.knn.models.KNNModel;
import org.apache.ignite.ml.knn.models.KNNStrategy;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;

/* loaded from: input_file:org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.class */
/**
 * k-nearest-neighbors regression model: predicts a numeric label for a vector from the labels of
 * its nearest neighbors in the training dataset, combining them according to the configured
 * {@link KNNStrategy} ({@code SIMPLE} = plain mean, {@code WEIGHTED} = distance-weighted mean).
 */
public class KNNMultipleLinearRegression extends KNNModel {
    /**
     * Creates the model; all arguments are forwarded unchanged to {@link KNNModel}.
     *
     * @param k Number of neighbors (presumably stored as {@code KNNModel.k} — TODO confirm).
     * @param distanceMeasure Distance measure used to compare vectors.
     * @param stgy Strategy that selects how neighbor labels are combined.
     * @param training Labeled training dataset.
     */
    public KNNMultipleLinearRegression(int k, DistanceMeasure distanceMeasure, KNNStrategy stgy,
        LabeledDataset training) {
        super(k, distanceMeasure, stgy, training);
    }

    /**
     * Predicts the label of the given vector from its nearest neighbors.
     *
     * @param v Vector to predict the label for.
     * @return Predicted label; {@code NaN} if no neighbors are found.
     * @throws UnsupportedOperationException If the configured strategy is not supported.
     */
    @Override // org.apache.ignite.ml.knn.models.KNNModel, java.util.function.Function
    public Double apply(Vector v) {
        // The 'true' flag is KNNModel's cached-distances switch; presumably it makes
        // findKNearestNeighbors populate cachedDistances — TODO confirm in KNNModel.
        return Double.valueOf(predictYBasedOn(findKNearestNeighbors(v, true), v));
    }

    /**
     * Dispatches to the regression routine that matches the configured strategy.
     *
     * @param neighbors Nearest neighbors of {@code v}.
     * @param v Vector being predicted.
     * @return Predicted label.
     * @throws UnsupportedOperationException If the strategy is neither SIMPLE nor WEIGHTED.
     */
    @SuppressWarnings("unchecked") // findKNearestNeighbors hands back a raw LabeledVector[].
    private double predictYBasedOn(LabeledVector[] neighbors, Vector v) {
        switch (this.stgy) {
            case SIMPLE:
                return simpleRegression(neighbors);
            case WEIGHTED:
                return weightedRegression(neighbors, v);
            default:
                throw new UnsupportedOperationException("Strategy " + this.stgy.name() + " is not supported");
        }
    }

    /**
     * Computes the distance-weighted mean of the neighbor labels.
     *
     * <p>If the total weight is zero (every neighbor at distance 0, or no neighbors), the plain
     * mean is returned instead of {@code 0.0 / 0.0 = NaN}. Equal distances imply equal weights, so
     * this is the continuous limit of the weighted formula.
     *
     * @param neighbors Nearest neighbors of {@code v}.
     * @param v Vector being predicted.
     * @return Weighted mean of the neighbor labels.
     */
    private double weightedRegression(LabeledVector<Vector, Double>[] neighbors, Vector v) {
        double weightedSum = 0.0;
        double totalWeight = 0.0;

        for (int i = 0; i < neighbors.length; i++) {
            // NOTE(review): assumes cachedDistances[i] is the distance from v to neighbors[i]
            // — confirm the ordering in KNNModel.
            double distance = this.cachedDistances != null
                ? this.cachedDistances[i]
                : this.distanceMeasure.compute(v, neighbors[i].features());

            // NOTE(review): the distance itself is used as the weight, so farther neighbors count
            // MORE; conventional weighted kNN uses 1 / distance. Kept for compatibility — confirm intent.
            weightedSum += neighbors[i].label().doubleValue() * distance;
            totalWeight += distance;
        }

        if (totalWeight == 0.0) {
            return simpleRegression(neighbors);
        }

        return weightedSum / totalWeight;
    }

    /**
     * Computes the arithmetic mean of the neighbor labels.
     *
     * @param neighbors Nearest neighbors.
     * @return Mean label; {@code NaN} if {@code neighbors} is empty.
     */
    private double simpleRegression(LabeledVector<Vector, Double>[] neighbors) {
        double sum = 0.0;

        for (LabeledVector<Vector, Double> neighbor : neighbors) {
            sum += neighbor.label().doubleValue();
        }

        // Divide by the neighbors actually supplied rather than by k: identical when k neighbors
        // are returned, but dividing by k would bias the mean toward zero if fewer are returned.
        return sum / neighbors.length;
    }
}
