package org.apache.ignite.ml.inference.parser;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.ignite.ml.inference.Model;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/* loaded from: input_file:org/apache/ignite/ml/inference/parser/TensorFlowBaseModelParser.class */
public abstract class TensorFlowBaseModelParser<I, O> implements ModelParser<I, O, Model<I, O>> {
    private static final long serialVersionUID = 5574259553625871456L;
    private final Map<String, InputTransformer<I>> inputs = new HashMap();
    private List<String> outputNames;
    private OutputTransformer<O> outputTransformer;

    @FunctionalInterface
    /* loaded from: input_file:org/apache/ignite/ml/inference/parser/TensorFlowBaseModelParser$InputTransformer.class */
    public interface InputTransformer<I> extends Serializable {
        Tensor<?> transform(I i);
    }

    @FunctionalInterface
    /* loaded from: input_file:org/apache/ignite/ml/inference/parser/TensorFlowBaseModelParser$OutputTransformer.class */
    public interface OutputTransformer<O> extends Serializable {
        O transform(Map<String, Tensor<?>> map);
    }

    /* loaded from: input_file:org/apache/ignite/ml/inference/parser/TensorFlowBaseModelParser$TensorFlowInfModel.class */
    private class TensorFlowInfModel implements Model<I, O> {
        private final Session ses;

        TensorFlowInfModel(Session session) {
            this.ses = session;
        }

        public O predict(I i) {
            return (O) TensorFlowBaseModelParser.this.outputTransformer.transform(indexTensors(fetchAll(feedAll(this.ses.runner(), i)).run()));
        }

        private Session.Runner feedAll(Session.Runner runner, I i) {
            for (Map.Entry entry : TensorFlowBaseModelParser.this.inputs.entrySet()) {
                runner = runner.feed((String) entry.getKey(), ((InputTransformer) entry.getValue()).transform(i));
            }
            return runner;
        }

        private Session.Runner fetchAll(Session.Runner runner) {
            Iterator it = TensorFlowBaseModelParser.this.outputNames.iterator();
            while (it.hasNext()) {
                runner.fetch((String) it.next());
            }
            return runner;
        }

        private Map<String, Tensor<?>> indexTensors(List<Tensor<?>> list) {
            HashMap hashMap = new HashMap();
            Iterator it = TensorFlowBaseModelParser.this.outputNames.iterator();
            Iterator<Tensor<?>> it2 = list.iterator();
            while (it.hasNext() && it2.hasNext()) {
                hashMap.put(it.next(), it2.next());
            }
            if (it.hasNext() || it2.hasNext()) {
                throw new IllegalStateException("Outputs are incorrect");
            }
            return hashMap;
        }

        public void close() {
            this.ses.close();
        }
    }

    public Model<I, O> parse(byte[] bArr) {
        return new TensorFlowInfModel(parseModel(bArr));
    }

    public abstract Session parseModel(byte[] bArr);

    public TensorFlowBaseModelParser<I, O> withInput(String str, InputTransformer<I> inputTransformer) {
        if (this.inputs.containsKey(str)) {
            throw new IllegalArgumentException("Inputs already contains specified name [name=" + str + "]");
        }
        this.inputs.put(str, inputTransformer);
        return this;
    }

    public TensorFlowBaseModelParser<I, O> withOutput(List<String> list, OutputTransformer<O> outputTransformer) {
        if (this.outputNames != null || this.outputTransformer != null) {
            throw new IllegalArgumentException("Outputs already specified");
        }
        this.outputNames = list;
        this.outputTransformer = outputTransformer;
        return this;
    }
}
