package org.apache.ignite.ml.knn.regression;

import java.util.Iterator;
import java.util.List;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.KNNModel;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;

/* loaded from: input_file:org/apache/ignite/ml/knn/regression/KNNRegressionModel.class */
/**
 * k-nearest-neighbours regression model.
 *
 * <p>A prediction is computed by looking up the closest labeled vectors with
 * {@code findKClosest(k, vector)} and aggregating their labels with a
 * {@link KNNRegressionPredictor}. The aggregation strategy (plain mean or
 * distance-weighted mean) is selected once, in the constructor.
 */
public class KNNRegressionModel extends KNNModel<Double> {
    /** Strategy that turns the found neighbours into a prediction; chosen in the constructor. */
    private final KNNRegressionPredictor predictor;

    /** Aggregates the labels of a list of neighbours into a single predicted value. */
    public interface KNNRegressionPredictor {
        /**
         * @param list neighbours whose labels are aggregated.
         * @param vector the vector the prediction is made for.
         * @return the predicted value; the implementations in this class return {@code null} for an empty list.
         */
        Double predict(List<LabeledVector<Double>> list, Vector vector);
    }

    /** Predicts the arithmetic mean of the neighbours' labels. */
    private class KNNRegressionSimplePredictor implements KNNRegressionPredictor {
        private KNNRegressionSimplePredictor() {
        }

        /**
         * @param list neighbours whose labels are averaged.
         * @param vector the query vector; not used by this predictor.
         * @return the mean label, or {@code null} when {@code list} is empty.
         */
        @Override
        public Double predict(List<LabeledVector<Double>> list, Vector vector) {
            if (list.isEmpty()) {
                return null;
            }
            double labelSum = 0.0d;
            for (LabeledVector<Double> neighbor : list) {
                labelSum += neighbor.label().doubleValue();
            }
            // Divide by the number of neighbours actually supplied rather than the configured k:
            // when fewer than k neighbours are available, dividing by k would bias the mean toward zero.
            return Double.valueOf(labelSum / list.size());
        }
    }

    /** Predicts a mean of the neighbours' labels weighted by their distance to the query vector. */
    private class KNNRegressionWeightedPredictor extends KNNRegressionSimplePredictor {
        private KNNRegressionWeightedPredictor() {
            super();
        }

        /**
         * @param list neighbours whose labels are aggregated.
         * @param vector the query vector; distances are measured from it to each neighbour's features.
         * @return the distance-weighted mean label, the plain mean when all distances sum to zero,
         *         or {@code null} when {@code list} is empty.
         */
        @Override
        public Double predict(List<LabeledVector<Double>> list, Vector vector) {
            if (list.isEmpty()) {
                return null;
            }
            double weightedLabelSum = 0.0d;
            double distanceSum = 0.0d;
            for (LabeledVector<Double> neighbor : list) {
                double distance = KNNRegressionModel.this.distanceMeasure.compute(vector, neighbor.features());
                // NOTE(review): each label is weighted by the distance itself, so farther neighbours
                // contribute more; inverse-distance weighting is the more common choice - confirm this is intended.
                weightedLabelSum += neighbor.label().doubleValue() * distance;
                distanceSum += distance;
            }
            if (distanceSum == 0.0d) {
                // The weights cannot be normalised when they sum to zero; fall back to the unweighted mean.
                return super.predict(list, vector);
            }
            return Double.valueOf(weightedLabelSum / distanceSum);
        }
    }

    /**
     * Creates the model and selects the prediction strategy.
     *
     * @param dataset dataset of spatial indices, passed to the superclass.
     * @param distanceMeasure distance measure, passed to the superclass.
     * @param i passed to the superclass; presumably the number of neighbours k - TODO confirm against KNNModel.
     * @param z {@code true} to use the distance-weighted predictor, {@code false} for the plain mean;
     *          also passed to the superclass.
     */
    public KNNRegressionModel(Dataset<EmptyContext, SpatialIndex<Double>> dataset, DistanceMeasure distanceMeasure, int i, boolean z) {
        super(dataset, distanceMeasure, i, z);
        this.predictor = z ? new KNNRegressionWeightedPredictor() : new KNNRegressionSimplePredictor();
    }

    /**
     * Predicts a value for the given vector from its closest neighbours.
     *
     * @param vector the vector to predict for.
     * @return the predicted value, or {@code null} when no neighbours are found.
     */
    @Override
    public Double predict(Vector vector) {
        List<LabeledVector<Double>> neighbors = findKClosest(this.k, vector);
        return this.predictor.predict(neighbors, vector);
    }
}
