/*
 * Copyright 2019 GridGain Systems, Inc. and Contributors.
 *
 * Licensed under the GridGain Community Edition License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.gridgain.com/products/software/community-edition/gridgain-community-edition-license
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.dataset.primitive.builder.data;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * A partition {@code data} builder that makes {@link SimpleLabeledDatasetData}.
 * <p>
 * Features and labels produced by the {@code vectorizer} for every upstream entry are packed into two flat
 * arrays in column-major order: the value of column {@code j} for row {@code r} is stored at index
 * {@code j * rows + r}, where {@code rows} is the {@code upstreamDataSize} passed to {@code build}.
 *
 * @param <K> Type of a key in <tt>upstream</tt> data.
 * @param <V> Type of a value in <tt>upstream</tt> data.
 * @param <C> type of a partition <tt>context</tt>.
 */
public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
    implements PartitionDataBuilder<K, V, C, SimpleLabeledDatasetData> {
    /** */
    private static final long serialVersionUID = 3678784980215216039L;

    /** Function that extracts labeled vectors from an {@code upstream} data. */
    private final Preprocessor<K, V> vectorizer;

    /**
     * Constructs a new instance of partition {@code data} builder that makes {@link SimpleLabeledDatasetData}.
     *
     * @param vectorizer Function that extracts labeled vectors from an {@code upstream} data.
     */
    public SimpleLabeledDatasetDataBuilder(Preprocessor<K, V> vectorizer) {
        this.vectorizer = vectorizer;
    }

    /**
     * {@inheritDoc}
     * <p>
     * The feature and label widths are taken from the first entry; every subsequent entry must have the same
     * widths. If {@code upstreamData} is empty, the features and labels arrays passed to
     * {@link SimpleLabeledDatasetData} are {@code null}. If the iterator yields fewer than
     * {@code upstreamDataSize} entries, the remaining rows stay zero-filled.
     *
     * @throws ArithmeticException If {@code upstreamDataSize} or the size of a flat array does not fit into
     *      an {@code int}.
     * @throws IllegalStateException If {@code upstreamData} yields more than {@code upstreamDataSize} entries,
     *      or if an entry's feature or label count differs from that of the first entry.
     */
    @Override public SimpleLabeledDatasetData build(
        LearningEnvironment env,
        Iterator<UpstreamEntry<K, V>> upstreamData,
        long upstreamDataSize, C ctx) {
        // Flat arrays are int-indexed, so the row count must fit into an int; convert once up front.
        int rows = Math.toIntExact(upstreamDataSize);

        // Prepares the matrix of features in flat column-major format.
        int featureCols = -1;
        int lbCols = -1;
        double[] features = null;
        double[] labels = null;

        int ptr = 0;
        while (upstreamData.hasNext()) {
            // A row index past the declared row count would overwrite the first rows of the next column.
            if (ptr >= rows)
                throw new IllegalStateException("Upstream data contains more than " + rows + " entries");

            UpstreamEntry<K, V> entry = upstreamData.next();

            LabeledVector<double[]> labeledVector = vectorizer.apply(entry.getKey(), entry.getValue());
            Vector featureRow = labeledVector.features();

            if (featureCols < 0) {
                featureCols = featureRow.size();
                features = new double[Math.multiplyExact(rows, featureCols)];
            }
            else if (featureRow.size() != featureCols) {
                // Explicit check: an assert would be skipped without -ea and wider rows silently truncated.
                throw new IllegalStateException("Feature extractor must return exactly " + featureCols
                    + " features");
            }

            // i * rows + ptr < featureCols * rows, which fits into an int (checked by multiplyExact above).
            for (int i = 0; i < featureCols; i++)
                features[i * rows + ptr] = featureRow.get(i);

            double[] lbRow = labeledVector.label();

            if (lbCols < 0) {
                lbCols = lbRow.length;
                labels = new double[Math.multiplyExact(rows, lbCols)];
            }
            else if (lbRow.length != lbCols)
                throw new IllegalStateException("Label extractor must return exactly " + lbCols + " labels");

            for (int i = 0; i < lbCols; i++)
                labels[i * rows + ptr] = lbRow[i];

            ptr++;
        }

        return new SimpleLabeledDatasetData(features, labels, rows);
    }
}
