/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.mleap;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import ml.combust.mleap.core.types.ScalarType;
import ml.combust.mleap.core.types.StructField;
import ml.combust.mleap.core.types.StructType;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.frame.Transformer;
import ml.combust.mleap.runtime.javadsl.BundleBuilder;
import ml.combust.mleap.runtime.javadsl.ContextBuilder;
import ml.combust.mleap.runtime.transformer.PipelineModel;
import org.apache.ignite.ml.inference.parser.ModelParser;
import org.apache.ignite.ml.math.primitives.vector.NamedVector;
import org.apache.ignite.ml.mleap.MLeapModel;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class MLeapModelParser
implements ModelParser<NamedVector, Double, MLeapModel> {
    private static final long serialVersionUID = -370352744966205715L;
    private static final String TMP_FILE_PREFIX = "mleap_model";
    private static final String TMP_FILE_POSTFIX = ".zip";

    public MLeapModel parse(byte[] mdl) {
        MleapContext mleapCtx = new ContextBuilder().createMleapContext();
        BundleBuilder bundleBuilder = new BundleBuilder();
        File file = null;
        try {
            file = File.createTempFile(TMP_FILE_PREFIX, TMP_FILE_POSTFIX);
            try (FileOutputStream fos = new FileOutputStream(file);){
                fos.write(mdl);
                fos.flush();
            }
            Transformer transformer = (Transformer)bundleBuilder.load(file, mleapCtx).root();
            PipelineModel pipelineMdl = (PipelineModel)transformer.model();
            List<String> inputSchema = this.checkAndGetInputSchema(pipelineMdl);
            String outputSchema = this.checkAndGetOutputSchema(pipelineMdl);
            MLeapModel mLeapModel = new MLeapModel(transformer, inputSchema, outputSchema);
            return mLeapModel;
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        finally {
            if (file != null) {
                file.delete();
            }
        }
    }

    private String checkAndGetOutputSchema(PipelineModel mdl) {
        Transformer lastTransformer = (Transformer)mdl.transformers().last();
        StructType outputSchema = lastTransformer.outputSchema();
        ArrayList output = new ArrayList((Collection)JavaConverters.seqAsJavaListConverter((Seq)outputSchema.fields()).asJava());
        if (output.size() != 1) {
            throw new IllegalArgumentException("Parser supports only scalar outputs");
        }
        return ((StructField)output.get(0)).name();
    }

    private List<String> checkAndGetInputSchema(PipelineModel mdl) {
        Transformer firstTransformer = (Transformer)mdl.transformers().head();
        StructType inputSchema = firstTransformer.inputSchema();
        ArrayList input = new ArrayList((Collection)JavaConverters.seqAsJavaListConverter((Seq)inputSchema.fields()).asJava());
        ArrayList<String> schema = new ArrayList<String>();
        for (StructField field : input) {
            String fieldName = field.name();
            schema.add(field.name());
            if (ScalarType.Double().base().equals(field.dataType().base())) continue;
            throw new IllegalArgumentException("Parser supports only double types [name=" + fieldName + ",type=" + field.dataType() + "]");
        }
        return schema;
    }
}

