/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.multiclass;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.multiclass.MultiClassModel;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * One-vs-rest multi-class trainer.
 *
 * <p>For every distinct class label found in the training data, this trainer fits (or updates) one
 * binary model with the wrapped {@code classifier}. That model is trained on labels rewritten to
 * {@code 1.0} for the current class and {@code 0.0} for every other class. The per-class models are
 * collected into a {@link MultiClassModel} keyed by class label.
 *
 * @param <M> Type of the binary model produced by the wrapped classifier.
 */
public class OneVsRestTrainer<M extends IgniteModel<Vector, Double>>
extends SingleLabelDatasetTrainer<MultiClassModel<M>> {
    /** Trainer used to fit or update each per-class binary model. */
    private final SingleLabelDatasetTrainer<M> classifier;

    /**
     * Constructs a new one-vs-rest trainer.
     *
     * @param classifier Trainer for the per-class binary models; must not be {@code null}.
     * @throws NullPointerException If {@code classifier} is {@code null}.
     */
    public OneVsRestTrainer(SingleLabelDatasetTrainer<M> classifier) {
        this.classifier = Objects.requireNonNull(classifier, "classifier");
    }

    /** {@inheritDoc} */
    @Override
    public <K, V> MultiClassModel<M> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> extractor) {
        // A fresh fit is simply an update with no previously trained model.
        return updateModel(null, datasetBuilder, extractor);
    }

    /**
     * Trains one binary model per class label, reusing models from {@code newMdl} where present.
     *
     * <p>For each label returned by {@link #extractClassLabels}, the labels produced by
     * {@code extractor} are patched to {@code 1.0} for that class and {@code 0.0} otherwise. If
     * {@code newMdl} already contains a model for the label, that model is passed to
     * {@code classifier.update(...)}. Otherwise a new model is obtained from
     * {@code classifier.fit(...)}.
     *
     * @param newMdl Previously trained model to update, or {@code null} to train from scratch.
     * @param datasetBuilder Dataset builder.
     * @param extractor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return A multi-class model with one binary model per class label. When no labels are found,
     *     returns the result of {@code getLastTrainedModelOrThrowEmptyDatasetException(newMdl)}.
     */
    @Override
    protected <K, V> MultiClassModel<M> updateModel(MultiClassModel<M> newMdl,
        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        List<Double> classes = extractClassLabels(datasetBuilder, extractor);

        if (classes.isEmpty()) {
            return getLastTrainedModelOrThrowEmptyDatasetException(newMdl);
        }

        MultiClassModel<M> multiClsMdl = new MultiClassModel<>();

        for (Double clsLb : classes) {
            // Binarize labels: 1.0 for the current class, 0.0 for all others.
            // clsLb is never null here, so comparing from its side is null-safe.
            IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> lbPatcher =
                lv -> new LabeledVector<>(lv.features(), clsLb.equals(lv.label()) ? 1.0 : 0.0);

            PatchedPreprocessor<K, V, Double, Double> patchedPreprocessor =
                new PatchedPreprocessor<>(lbPatcher, extractor);

            M mdl = Optional.ofNullable(newMdl)
                .flatMap(prevMdl -> prevMdl.getModel(clsLb))
                .map(learnedMdl -> classifier.update(learnedMdl, datasetBuilder, patchedPreprocessor))
                .orElseGet(() -> classifier.fit(datasetBuilder, patchedPreprocessor));

            multiClsMdl.add(clsLb, mdl);
        }

        return multiClsMdl;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Always returns {@code true}: any previous one-vs-rest model can be updated.
     */
    @Override
    public boolean isUpdateable(MultiClassModel<M> mdl) {
        return true;
    }

    /**
     * Collects the distinct class labels present in the dataset.
     *
     * @param datasetBuilder Dataset builder; must not be {@code null}.
     * @param preprocessor Upstream preprocessor producing labeled vectors.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Distinct class labels in ascending order; empty if none were found.
     * @throws RuntimeException Wrapping any checked exception raised while building, computing
     *     over, or closing the dataset. Unchecked exceptions propagate unchanged.
     */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        assert datasetBuilder != null;

        LabelPartitionDataBuilderOnHeap<K, V, EmptyContext> partDataBuilder =
            new LabelPartitionDataBuilderOnHeap<>(preprocessor);

        List<Double> res = new ArrayList<>();

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder,
            learningEnvironment()
        )) {
            Set<Double> clsLabels = dataset.compute(data -> {
                // Distinct labels of a single partition.
                Set<Double> locClsLabels = new HashSet<>();

                for (double lb : data.getY()) {
                    locClsLabels.add(lb);
                }

                return locClsLabels;
            }, (a, b) -> {
                // Partial results may be null, so merge defensively.
                if (a == null) {
                    return b == null ? new HashSet<>() : b;
                }
                if (b == null) {
                    return a;
                }
                return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
            });

            if (clsLabels != null) {
                res.addAll(clsLabels);
            }
        }
        catch (RuntimeException e) {
            // Do not double-wrap unchecked exceptions.
            throw e;
        }
        catch (Exception e) {
            // Closing the dataset may throw a checked exception; surface it unchecked.
            throw new RuntimeException(e);
        }

        // The merged label set has no defined order. Sort so that per-class models are
        // trained in a deterministic order.
        Collections.sort(res);

        return res;
    }
}

