/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.inference.parser;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.inference.parser.ModelParser;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public abstract class TensorFlowBaseModelParser<I, O>
implements ModelParser<I, O, Model<I, O>> {
    private static final long serialVersionUID = 5574259553625871456L;
    private final Map<String, InputTransformer<I>> inputs = new HashMap<String, InputTransformer<I>>();
    private List<String> outputNames;
    private OutputTransformer<O> outputTransformer;

    public Model<I, O> parse(byte[] mdl) {
        return new TensorFlowInfModel(this.parseModel(mdl));
    }

    public abstract Session parseModel(byte[] var1);

    public TensorFlowBaseModelParser<I, O> withInput(String name, InputTransformer<I> transformer) {
        if (this.inputs.containsKey(name)) {
            throw new IllegalArgumentException("Inputs already contains specified name [name=" + name + "]");
        }
        this.inputs.put(name, transformer);
        return this;
    }

    public TensorFlowBaseModelParser<I, O> withOutput(List<String> names, OutputTransformer<O> transformer) {
        if (this.outputNames != null || this.outputTransformer != null) {
            throw new IllegalArgumentException("Outputs already specified");
        }
        this.outputNames = names;
        this.outputTransformer = transformer;
        return this;
    }

    private class TensorFlowInfModel
    implements Model<I, O> {
        private final Session ses;

        TensorFlowInfModel(Session ses) {
            this.ses = ses;
        }

        public O predict(I input) {
            Session.Runner runner = this.ses.runner();
            runner = this.feedAll(runner, input);
            runner = this.fetchAll(runner);
            List prediction = runner.run();
            Map<String, Tensor<?>> collectedPredictionTensors = this.indexTensors(prediction);
            return TensorFlowBaseModelParser.this.outputTransformer.transform(collectedPredictionTensors);
        }

        private Session.Runner feedAll(Session.Runner runner, I input) {
            for (Map.Entry e : TensorFlowBaseModelParser.this.inputs.entrySet()) {
                String opName = (String)e.getKey();
                InputTransformer transformer = (InputTransformer)e.getValue();
                runner = runner.feed(opName, transformer.transform(input));
            }
            return runner;
        }

        private Session.Runner fetchAll(Session.Runner runner) {
            for (String e : TensorFlowBaseModelParser.this.outputNames) {
                runner.fetch(e);
            }
            return runner;
        }

        private Map<String, Tensor<?>> indexTensors(List<Tensor<?>> tensors) {
            HashMap collectedTensors = new HashMap();
            Iterator outputNamesIter = TensorFlowBaseModelParser.this.outputNames.iterator();
            Iterator<Tensor<?>> tensorsIter = tensors.iterator();
            while (outputNamesIter.hasNext() && tensorsIter.hasNext()) {
                collectedTensors.put((String)outputNamesIter.next(), tensorsIter.next());
            }
            if (outputNamesIter.hasNext() || tensorsIter.hasNext()) {
                throw new IllegalStateException("Outputs are incorrect");
            }
            return collectedTensors;
        }

        public void close() {
            this.ses.close();
        }
    }

    @FunctionalInterface
    public static interface OutputTransformer<O>
    extends Serializable {
        public O transform(Map<String, Tensor<?>> var1);
    }

    @FunctionalInterface
    public static interface InputTransformer<I>
    extends Serializable {
        public Tensor<?> transform(I var1);
    }
}

