/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.mleap;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import ml.combust.mleap.core.types.DataType;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.StructField;
import ml.combust.mleap.core.types.StructType;
import ml.combust.mleap.runtime.frame.DefaultLeapFrame;
import ml.combust.mleap.runtime.frame.FrameBuilder;
import ml.combust.mleap.runtime.frame.Row;
import ml.combust.mleap.runtime.frame.Transformer;
import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.util.Try;

public class MLeapModel
implements Model<NamedVector, Double> {
    private final Transformer transformer;
    private final List<String> schema;
    private final String outputFieldName;

    public MLeapModel(Transformer transformer, List<String> schema, String outputFieldName) {
        this.transformer = transformer;
        this.schema = new ArrayList<String>(schema);
        this.outputFieldName = outputFieldName;
    }

    public Double predict(NamedVector input) {
        LeapFrameBuilder builder = new LeapFrameBuilder();
        ArrayList<StructField> structFields = new ArrayList<StructField>();
        ArrayList<Double> values = new ArrayList<Double>();
        for (String fieldName : input.getKeys()) {
            structFields.add(new StructField(fieldName, (DataType)ScalarType.Double()));
            values.add(input.get(fieldName));
        }
        StructType schema = builder.createSchema(structFields);
        ArrayList<Row> rows = new ArrayList<Row>();
        rows.add(builder.createRowFromIterable(values));
        DefaultLeapFrame inputFrame = builder.createFrame(schema, rows);
        return this.predict(inputFrame);
    }

    public double predict(Double[] input) {
        if (input.length != this.schema.size()) {
            throw new IllegalArgumentException("Input size is not equal to schema size");
        }
        Map<String, Double> vec = IntStream.range(0, input.length).boxed().collect(Collectors.toMap(this.schema::get, i -> input[i]));
        return this.predict(VectorUtils.of(vec));
    }

    public double predict(DefaultLeapFrame inputFrame) {
        DefaultLeapFrame outputFrame = (DefaultLeapFrame)this.transformer.transform((FrameBuilder)inputFrame).get();
        Try resFrame = outputFrame.select(new Set.Set1((Object)this.outputFieldName).toSeq());
        DefaultLeapFrame frame = (DefaultLeapFrame)resFrame.get();
        Stream stream = (Stream)frame.productElement(1);
        Row row = (Row)stream.head();
        return (Double)row.get(0);
    }

    public void close() {
        this.transformer.close();
    }
}

