/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.classification;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.KNNModel;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * k-nearest-neighbours classification model.
 *
 * <p>A prediction asks {@code findKClosest} (inherited from {@link KNNModel}) for the {@code k}
 * labeled points closest to the input, groups them by label, and returns the label with the
 * largest total vote. In unweighted mode every neighbour casts a vote of {@code 1.0}. In weighted
 * mode a neighbour casts {@code 1.0 / distance}, where the distance comes from the model's
 * {@link DistanceMeasure}.
 */
public class KNNClassificationModel extends KNNModel<Double> {
    /**
     * Creates a model over an already-built dataset.
     *
     * @param dataset Dataset of spatial indices holding the labeled training points.
     * @param distanceMeasure Distance measure. This class uses it to weight votes in weighted mode;
     *     it is also handed to {@link KNNModel}.
     * @param k Number of neighbours requested from {@code findKClosest} for each prediction.
     * @param weighted {@code true} to weight each vote by inverse distance, {@code false} to give
     *     every neighbour an equal vote.
     */
    KNNClassificationModel(Dataset<EmptyContext, SpatialIndex<Double>> dataset,
        DistanceMeasure distanceMeasure, int k, boolean weighted) {
        super(dataset, distanceMeasure, k, weighted);
    }

    /**
     * Predicts the label of {@code input} by a majority vote (optionally distance-weighted) of its
     * nearest neighbours.
     *
     * @param input Feature vector to classify.
     * @return The winning label, or {@code null} if no neighbour was found or no label collected a
     *     vote total greater than {@code 0.0}.
     */
    @Override
    public Double predict(Vector input) {
        List<LabeledVector<Double>> neighbours = this.findKClosest(this.k, input);
        return election(neighbours, input);
    }

    /**
     * Groups {@code neighbours} by label and returns the label with the highest vote total.
     *
     * <p>Groups are visited in the order in which their label first appears in
     * {@code neighbours}. Only a strictly greater total replaces the current leader, so on a tie
     * the group visited first wins. NOTE(review): if {@code findKClosest} returns neighbours
     * ordered by distance, ties therefore favour the label of the nearest neighbour; confirm that
     * ordering in {@link KNNModel}.
     *
     * @param neighbours Labeled neighbours of {@code pnt}.
     * @param pnt Query point the votes are computed for.
     * @return The winning label, or {@code null} when no group has a total greater than
     *     {@code 0.0} (this includes an empty neighbour list and {@code NaN} totals).
     */
    private Double election(List<LabeledVector<Double>> neighbours, Vector pnt) {
        Double winner = null;
        double bestVotes = 0.0;

        for (GroupedNeighbours grp : groupByLabel(neighbours)) {
            double grpVotes = calculateGroupVotes(grp, pnt);

            // Strict '>' means ties keep the earlier group and NaN totals are never selected.
            if (grpVotes > bestVotes) {
                bestVotes = grpVotes;
                winner = grp.getLb();
            }
        }

        return winner;
    }

    /**
     * Sums the votes cast by the members of {@code grp} for the query point {@code pnt}.
     *
     * @param grp Neighbours sharing one label.
     * @param pnt Query point.
     * @return The number of members in unweighted mode, or the sum of
     *     {@code 1.0 / distance(pnt, member)} in weighted mode. Under IEEE-754 division, a member
     *     at distance {@code 0.0} contributes {@code +Infinity}, so an exact match dominates.
     */
    private double calculateGroupVotes(GroupedNeighbours grp, Vector pnt) {
        double total = 0.0;

        for (Vector neighbour : grp) {
            // The distance is only needed for weighted voting, so skip computing it otherwise.
            total += this.weighted ? 1.0 / this.distanceMeasure.compute(pnt, neighbour) : 1.0;
        }

        return total;
    }

    /**
     * Partitions {@code neighbours} into one group per distinct label.
     *
     * <p>A {@link LinkedHashMap} is used so that groups come out in the order each label first
     * appears in {@code neighbours}. This makes the election's tie-breaking deterministic instead
     * of depending on hash-bucket order.
     *
     * @param neighbours Labeled neighbours to group.
     * @return Read-only view of the groups, in first-occurrence order.
     * @throws NullPointerException If a neighbour's label is {@code null} (the label is unboxed).
     */
    private Collection<GroupedNeighbours> groupByLabel(List<LabeledVector<Double>> neighbours) {
        Map<Double, GroupedNeighbours> groups = new LinkedHashMap<>();

        for (LabeledVector<Double> neighbour : neighbours) {
            // Unboxing to a primitive makes a null label fail fast here.
            double lb = neighbour.label();

            GroupedNeighbours grp = groups.get(lb);
            if (grp == null) {
                grp = new GroupedNeighbours(lb);
                groups.put(lb, grp);
            }

            grp.addNeighbour(neighbour.features());
        }

        return Collections.unmodifiableCollection(groups.values());
    }

    /** A label together with the feature vectors of the neighbours that carry it. */
    private static final class GroupedNeighbours implements Iterable<Vector> {
        /** Label shared by every vector in this group. */
        private final Double lb;

        /** Feature vectors of the neighbours labeled {@link #lb}, in insertion order. */
        private final List<Vector> neighbours = new ArrayList<>();

        /**
         * Creates an empty group.
         *
         * @param lb Label of the group.
         */
        GroupedNeighbours(Double lb) {
            this.lb = lb;
        }

        /**
         * Appends a neighbour's feature vector to this group.
         *
         * @param neighbour Feature vector to add.
         */
        void addNeighbour(Vector neighbour) {
            neighbours.add(neighbour);
        }

        /** @return Label of this group. */
        Double getLb() {
            return lb;
        }

        /** @return Iterator over the member feature vectors, in insertion order. */
        @Override
        public Iterator<Vector> iterator() {
            return neighbours.iterator();
        }
    }
}

