package org.apache.ignite.ml.knn.regression;

import java.util.Iterator;
import java.util.List;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;

/* loaded from: input_file:org/apache/ignite/ml/knn/regression/KNNRegressionModel.class */
/**
 * k-nearest-neighbors regression model.
 *
 * <p>Reuses the neighbor search of {@link KNNClassificationModel}. It predicts a {@code Double} by
 * averaging the labels of the nearest neighbors, using either a plain mean ({@code SIMPLE}
 * strategy) or an inverse-distance weighted mean ({@code WEIGHTED} strategy).
 */
public class KNNRegressionModel extends KNNClassificationModel {
    private static final long serialVersionUID = -721836321291120543L;

    /**
     * Creates a regression model over the given dataset.
     *
     * @param dataset dataset of labeled vectors whose labels are {@code Double} values
     */
    public KNNRegressionModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        super(dataset);
    }

    /**
     * Predicts the target value for {@code vector} from its nearest neighbors.
     *
     * @param vector feature vector to predict for
     * @return predicted value, or {@code NaN} if no neighbors were found
     * @throws UnsupportedOperationException if the configured strategy is not supported
     */
    @Override
    public Double predict(Vector vector) {
        return Double.valueOf(predictYBasedOn(findKNearestNeighbors(vector), vector));
    }

    /** Dispatches to the regression routine matching the configured strategy. */
    private double predictYBasedOn(List<LabeledVector> neighbors, Vector vector) {
        switch (this.stgy) {
            case SIMPLE:
                return simpleRegression(neighbors);
            case WEIGHTED:
                return weightedRegression(neighbors, vector);
            default:
                throw new UnsupportedOperationException("Strategy " + this.stgy.name() + " is not supported");
        }
    }

    /**
     * Computes the inverse-distance weighted mean of the neighbor labels.
     *
     * <p>Closer neighbors get larger weights ({@code 1 / distance}), which matches the WEIGHTED
     * vote used by the classification model. If one or more neighbors coincide with
     * {@code vector} (distance 0), their weights would be infinite, so the mean label of those
     * exact matches is returned instead. If the weights sum to zero, for example because every
     * distance is infinite, the plain mean is used.
     *
     * @param neighbors nearest neighbors of {@code vector}
     * @param vector query vector
     * @return weighted mean label, or {@code NaN} if {@code neighbors} is empty
     */
    private double weightedRegression(List<LabeledVector> neighbors, Vector vector) {
        if (neighbors.isEmpty()) {
            return Double.NaN;
        }

        double weightedSum = 0.0d;
        double weightTotal = 0.0d;
        double exactSum = 0.0d;
        int exactCnt = 0;

        for (LabeledVector neighbor : neighbors) {
            double distance = this.distanceMeasure.compute(vector, neighbor.features());
            double label = labelOf(neighbor);

            if (distance == 0.0d) {
                // Exact match: an inverse-distance weight would be infinite.
                exactSum += label;
                exactCnt++;
            } else {
                double weight = 1.0d / distance;
                weightedSum += label * weight;
                weightTotal += weight;
            }
        }

        if (exactCnt > 0) {
            return exactSum / exactCnt;
        }

        return weightTotal == 0.0d ? simpleRegression(neighbors) : weightedSum / weightTotal;
    }

    /**
     * Computes the arithmetic mean of the neighbor labels.
     *
     * <p>Divides by the number of neighbors actually found rather than by {@code k}, so the mean
     * stays unbiased when fewer than {@code k} neighbors are available.
     *
     * @param neighbors nearest neighbors
     * @return mean label, or {@code NaN} if {@code neighbors} is empty
     */
    private double simpleRegression(List<LabeledVector> neighbors) {
        if (neighbors.isEmpty()) {
            return Double.NaN;
        }

        double sum = 0.0d;
        for (LabeledVector neighbor : neighbors) {
            sum += labelOf(neighbor);
        }
        return sum / neighbors.size();
    }

    /** Extracts the {@code Double} label of a neighbor as a primitive. */
    private static double labelOf(LabeledVector neighbor) {
        return ((Double) neighbor.label()).doubleValue();
    }

    /** {@inheritDoc} */
    @Override
    public String toString() {
        return toString(false);
    }

    /**
     * Renders the model configuration (k, distance measure, strategy).
     *
     * @param pretty whether to use the pretty (multi-line) format
     * @return textual representation of the model
     */
    @Override
    public String toString(boolean pretty) {
        return ModelTrace.builder("KNNRegressionModel", pretty)
            .addField("k", String.valueOf(this.k))
            .addField("measure", this.distanceMeasure.getClass().getSimpleName())
            .addField("strategy", this.stgy.name())
            .toString();
    }
}
