/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite3.internal.compute.executor.wasm;

import com.dylibso.chicory.compiler.MachineFactoryCompiler;
import com.dylibso.chicory.runtime.ByteBufferMemory;
import com.dylibso.chicory.runtime.ExportFunction;
import com.dylibso.chicory.runtime.ImportFunction;
import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.runtime.Machine;
import com.dylibso.chicory.runtime.Store;
import com.dylibso.chicory.wasm.ChicoryException;
import com.dylibso.chicory.wasm.InvalidException;
import com.dylibso.chicory.wasm.MalformedException;
import com.dylibso.chicory.wasm.Parser;
import com.dylibso.chicory.wasm.UnlinkableException;
import com.dylibso.chicory.wasm.WasmModule;
import com.dylibso.chicory.wasm.types.MemoryLimits;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import org.apache.ignite3.compute.JobExecutionContext;
import org.apache.ignite3.internal.binarytuple.BinaryTupleReader;
import org.apache.ignite3.internal.client.proto.ColumnTypeConverter;
import org.apache.ignite3.internal.compute.ComputeJobDataHolder;
import org.apache.ignite3.internal.compute.ComputeJobDataType;
import org.apache.ignite3.internal.compute.SharedComputeUtils;
import org.apache.ignite3.internal.compute.configuration.WasmConfiguration;
import org.apache.ignite3.internal.compute.configuration.WasmConfigurationUtils;
import org.apache.ignite3.internal.compute.executor.wasm.GoWasiUtils;
import org.apache.ignite3.internal.compute.executor.wasm.RustWasmBindgenUtils;
import org.apache.ignite3.internal.compute.executor.wasm.WasmCallConv;
import org.apache.ignite3.internal.compute.executor.wasm.WasmCommon;
import org.apache.ignite3.internal.compute.executor.wasm.WasmFunctionDef;
import org.apache.ignite3.internal.compute.executor.wasm.WasmType;
import org.apache.ignite3.internal.logger.IgniteLogger;
import org.apache.ignite3.internal.logger.Loggers;
import org.apache.ignite3.internal.util.StringUtils;
import org.apache.ignite3.sql.ColumnType;
import org.jetbrains.annotations.Nullable;

/**
 * Executes compute jobs implemented as exported WebAssembly functions, using the Chicory runtime.
 *
 * <p>Each job invocation parses the module file and creates a fresh module instance. If the module exports the
 * calling convention's init function, that function runs first. Then the requested export is called with the single
 * job argument. The result is converted back into a {@link ComputeJobDataHolder}.
 *
 * <p>Two settings are kept up to date through configuration listeners:
 * <ul>
 *     <li>Module memory is capped by {@code compute.wasm.moduleMaxMemory}.</li>
 *     <li>{@code enableCompiler} switches between the interpreter and {@link MachineFactoryCompiler}.</li>
 * </ul>
 */
public class ChicoryWasmComputeExecutor {
    private static final IgniteLogger LOG = Loggers.forClass(ChicoryWasmComputeExecutor.class);

    /** Size of one WebAssembly linear-memory page, in bytes. Kept as {@code long} so byte counts cannot overflow. */
    private static final long WASM_PAGE_SIZE = 65536L;

    /** Maximum module memory, in 64 KiB pages; refreshed by a configuration listener. */
    private volatile int moduleMaxMemoryPages = 65536;

    /** Whether modules are run through {@link MachineFactoryCompiler}; refreshed by a configuration listener. */
    private volatile boolean enableCompiler;

    /**
     * Creates the executor and subscribes to WASM configuration changes.
     *
     * @param configuration WASM compute configuration; its current values are applied immediately.
     */
    public ChicoryWasmComputeExecutor(WasmConfiguration configuration) {
        configuration.moduleMaxMemory().listen(ctx -> {
            this.moduleMaxMemoryPages = ChicoryWasmComputeExecutor.getMemoryPages((String)ctx.newValue());
            return CompletableFuture.completedFuture(null);
        });
        configuration.enableCompiler().listen(ctx -> {
            this.enableCompiler = ChicoryWasmComputeExecutor.getEnableCompiler((Boolean)ctx.newValue());
            return CompletableFuture.completedFuture(null);
        });
        this.moduleMaxMemoryPages = ChicoryWasmComputeExecutor.getMemoryPages((String)configuration.moduleMaxMemory().value());
        this.enableCompiler = ChicoryWasmComputeExecutor.getEnableCompiler((Boolean)configuration.enableCompiler().value());
    }

    /**
     * Returns a callable that runs the WASM job when invoked. No work is done until {@code call()}.
     *
     * @param jobClassName Job name, resolved to a module file and function by {@code WasmCommon.getFileAndFunction}.
     * @param input Job argument in native binary-tuple form; {@code null} is rejected at execution time.
     * @param context Job execution context.
     * @return Callable producing an already-completed future with the job result.
     */
    public Callable<CompletableFuture<ComputeJobDataHolder>> getJobCallable(String jobClassName, @Nullable ComputeJobDataHolder input, JobExecutionContext context) {
        return () -> this.executeJob(context, jobClassName, input);
    }

    private CompletableFuture<ComputeJobDataHolder> executeJob(JobExecutionContext context, String jobClassName, @Nullable ComputeJobDataHolder input) {
        WasmFunctionDef funcDef = WasmCommon.getFileAndFunction(context, jobClassName);
        try {
            WasmInstance inst = this.getWasmInstance(funcDef);
            ComputeJobDataHolder resultHolder = ChicoryWasmComputeExecutor.apply(inst, input);
            return CompletableFuture.completedFuture(resultHolder);
        }
        catch (MalformedException e) {
            throw WasmCommon.wasmError("WASM module '" + funcDef.file().getName() + "' is malformed: " + e.getMessage(), e);
        }
        catch (ChicoryException e) {
            throw WasmCommon.wasmError("Failed to execute WASM function '" + jobClassName + "': " + e.getMessage(), e);
        }
    }

    /**
     * Decodes the job argument, invokes the WASM function and converts its result.
     *
     * <p>The argument is read as a 3-element binary tuple. Element 0 is the column type id and element 2 is the
     * value. Element 1 is not read here; presumably it is the scale slot of the native-value packing (verify).
     */
    private static ComputeJobDataHolder apply(WasmInstance func, @Nullable ComputeJobDataHolder input) {
        if (input == null || input.data() == null) {
            throw WasmCommon.wasmError("WASM jobs do not support null arguments");
        }
        if (input.type() != ComputeJobDataType.NATIVE) {
            throw WasmCommon.wasmError("Unsupported arg type: " + input.type());
        }
        BinaryTupleReader reader = new BinaryTupleReader(3, input.data());
        long[] res = ChicoryWasmComputeExecutor.applyFunc(func, reader);
        return ChicoryWasmComputeExecutor.convertResult(func, res);
    }

    /**
     * Converts raw WASM return values into a result holder according to the function's declared return type.
     * Floats arrive as raw IEEE-754 bits. Timestamps arrive as epoch seconds. Strings and bytes arrive as a
     * pointer/length into guest memory and are released after being read.
     */
    private static ComputeJobDataHolder convertResult(WasmInstance func, long[] res) {
        switch (func.def.retType()) {
            case INT8: {
                return SharedComputeUtils.marshalArgOrResult((byte)res[0], null);
            }
            case INT16: {
                return SharedComputeUtils.marshalArgOrResult((short)res[0], null);
            }
            case INT32: {
                return SharedComputeUtils.marshalArgOrResult((int)res[0], null);
            }
            case INT64: {
                return SharedComputeUtils.marshalArgOrResult(res[0], null);
            }
            case FLOAT32: {
                return SharedComputeUtils.marshalArgOrResult(Float.valueOf(Float.intBitsToFloat((int)res[0])), null);
            }
            case FLOAT64: {
                return SharedComputeUtils.marshalArgOrResult(Double.longBitsToDouble(res[0]), null);
            }
            case TIMESTAMP: {
                return SharedComputeUtils.marshalArgOrResult(Instant.ofEpochSecond(res[0]), null);
            }
            case STRING:
            case BYTES: {
                byte[] bytes = ChicoryWasmComputeExecutor.readBytesAndDealloc(func, res);
                if (func.def.retType() == WasmType.BYTES) {
                    return SharedComputeUtils.marshalArgOrResult(bytes, null);
                }
                return SharedComputeUtils.marshalArgOrResult(new String(bytes, StandardCharsets.UTF_8), null);
            }
        }
        throw WasmCommon.wasmError("Unsupported return type: " + func.def.retType());
    }

    /**
     * Invokes the WASM function with the single argument stored in {@code reader}.
     *
     * <p>Argument encoding:
     * <ul>
     *     <li>Integers are passed as {@code long}.</li>
     *     <li>Floats are passed as raw IEEE-754 bits.</li>
     *     <li>Timestamps are passed as epoch seconds; nanoseconds are dropped.</li>
     *     <li>Strings (UTF-8) and byte arrays are copied into guest memory.</li>
     * </ul>
     */
    private static long[] applyFunc(WasmInstance func, BinaryTupleReader reader) {
        if (reader.hasNullValue(0)) {
            throw WasmCommon.wasmError("WASM jobs do not support null arguments");
        }
        ColumnType type = ColumnTypeConverter.fromIdOrThrow(reader.intValue(0));
        final int valIdx = 2;
        switch (type) {
            case INT8:
            case INT16:
            case INT32:
            case INT64: {
                return func.func.apply(new long[]{reader.longValue(valIdx)});
            }
            case FLOAT: {
                int intBits = Float.floatToIntBits(reader.floatValue(valIdx));
                return func.func.apply(new long[]{intBits});
            }
            case DOUBLE: {
                long longBits = Double.doubleToLongBits(reader.doubleValue(valIdx));
                return func.func.apply(new long[]{longBits});
            }
            case TIMESTAMP: {
                Instant val = reader.timestampValue(valIdx);
                if (val == null) {
                    throw WasmCommon.wasmError("WASM jobs do not support null timestamps");
                }
                return func.func.apply(new long[]{val.getEpochSecond()});
            }
            case STRING: {
                String str = reader.stringValue(valIdx);
                if (str == null) {
                    throw WasmCommon.wasmError("WASM jobs do not support null strings");
                }
                return ChicoryWasmComputeExecutor.applyWithBytesArg(func, str.getBytes(StandardCharsets.UTF_8));
            }
            case BYTE_ARRAY: {
                byte[] bytes = reader.bytesValue(valIdx);
                if (bytes == null) {
                    throw WasmCommon.wasmError("WASM jobs do not support null byte arrays");
                }
                return ChicoryWasmComputeExecutor.applyWithBytesArg(func, bytes);
            }
            default: {
                throw WasmCommon.wasmError("Unsupported column type: " + type);
            }
        }
    }

    /**
     * Copies {@code bytes} into a guest-allocated buffer, calls the function with it and frees the buffer.
     *
     * <p>Depending on the calling convention, the buffer is passed in one of two ways. It is either a
     * {@code (ptr, len)} pair, or a single i64 with the length in the high 32 bits and the pointer in the low 32 bits.
     * Empty input is passed as a zero pointer/length and allocates nothing.
     */
    private static long[] applyWithBytesArg(WasmInstance func, byte[] bytes) {
        boolean pointerAsPair = func.def.callConv().pointerAsPair();
        if (bytes.length == 0) {
            return pointerAsPair ? func.func.apply(new long[]{0L, 0L}) : func.func.apply(new long[]{0L});
        }
        int ptr = func.alloc(bytes.length);
        long[] result;
        try {
            func.instance.memory().write(ptr, bytes);
            if (pointerAsPair) {
                result = func.func.apply(new long[]{ptr, bytes.length});
            } else {
                // Zero-extend the pointer: a sign-extended address >= 2 GiB would overwrite the length bits.
                result = func.func.apply(new long[]{(long)bytes.length << 32 | Integer.toUnsignedLong(ptr)});
            }
        }
        catch (RuntimeException | Error e) {
            // Still try to release the buffer, but never let a secondary dealloc failure
            // (e.g. a trap on a heap corrupted by the failed call) hide the original error.
            try {
                func.dealloc(ptr, bytes.length);
            }
            catch (RuntimeException deallocError) {
                e.addSuppressed(deallocError);
            }
            throw e;
        }
        func.dealloc(ptr, bytes.length);
        return result;
    }

    /**
     * Reads a byte buffer returned by the WASM function and frees it in guest memory.
     *
     * <p>The result is either a {@code (ptr, len)} pair or a single i64 with the length in the high 32 bits and
     * the pointer in the low 32 bits, depending on the calling convention.
     */
    private static byte[] readBytesAndDealloc(WasmInstance func, long[] res) {
        int ptr;
        int len;
        if (func.def.callConv().pointerAsPair()) {
            if (res.length != 2) {
                throw WasmCommon.wasmError("Expected two i32 return values [call conv: " + func.def.callConv() + "]");
            }
            ptr = (int)res[0];
            len = (int)res[1];
        } else {
            if (res.length != 1) {
                throw WasmCommon.wasmError("Expected single i64 return value [call conv: " + func.def.callConv() + "]");
            }
            long packed = res[0];
            ptr = (int)packed;
            len = (int)(packed >>> 32);
        }
        if (len < 0) {
            // Otherwise the read would fail with an unrelated JVM exception that escapes the WASM error handling.
            throw WasmCommon.wasmError("WASM function returned invalid byte length: " + Integer.toUnsignedLong(len)
                    + " [call conv: " + func.def.callConv() + "]");
        }
        byte[] bytes = func.instance.memory().readBytes(ptr, len);
        func.dealloc(ptr, len);
        return bytes;
    }

    /**
     * Parses the module and instantiates it with the imports its calling convention needs. It then runs the module's
     * init export, if any.
     */
    private WasmInstance getWasmInstance(WasmFunctionDef funcDef) {
        WasmModule module = Parser.parse((File)funcDef.file());
        Store store = new Store();
        if (funcDef.callConv() == WasmCallConv.RUST_WASM_BINDGEN) {
            RustWasmBindgenUtils.stubImports(module, store);
        } else if (funcDef.callConv() == WasmCallConv.GO_WASI_REACTOR) {
            store.addFunction((ImportFunction[])GoWasiUtils.MINIMAL_GO_WASI_FUNCS);
        }
        // A null machine factory keeps Chicory's default (interpreted) execution.
        Function<Instance, Machine> machineFactory = this.enableCompiler ? MachineFactoryCompiler::compile : null;
        Instance instance = this.instantiate(store, module, machineFactory, funcDef);
        ExportFunction func = ChicoryWasmComputeExecutor.getExport(instance, funcDef);
        ChicoryWasmComputeExecutor.initModule(instance, funcDef);
        return new WasmInstance(instance, funcDef, func);
    }

    private static ExportFunction getExport(Instance instance, WasmFunctionDef funcDef) {
        try {
            return instance.export(funcDef.functionName());
        }
        catch (ChicoryException e) {
            throw WasmCommon.wasmError("WASM module '" + funcDef.file().getName() + "' does not have exported function '" + funcDef.functionName() + "'", e);
        }
    }

    private Instance instantiate(Store store, WasmModule module, @Nullable Function<Instance, Machine> machineFactory, WasmFunctionDef funcDef) {
        try {
            return store.instantiate("ignite-compute", imports -> Instance.builder(module)
                    .withMachineFactory(machineFactory)
                    .withImportValues(imports)
                    .withMemoryFactory(moduleRequest -> this.initMemory((MemoryLimits)moduleRequest, funcDef.file().getName()))
                    .build());
        }
        catch (UnlinkableException ue) {
            throw WasmCommon.wasmError("Failed to link WASM module, make sure to specify correct calling convention: " + ue.getMessage(), ue);
        }
    }

    /**
     * Creates the module memory. The module's initial size is kept; the maximum is capped by
     * {@code compute.wasm.moduleMaxMemory}.
     */
    private ByteBufferMemory initMemory(MemoryLimits moduleRequest, String moduleName) {
        // Read the volatile once so the check, the message and the limits all use the same value.
        int maxPages = this.moduleMaxMemoryPages;
        int reqInitial = moduleRequest.initialPages();
        if (reqInitial > maxPages) {
            // long arithmetic: pages * 65536 overflows int from 32768 pages upward.
            throw WasmCommon.wasmError("WebAssembly module '" + moduleName + "' requests too much initial memory: " + reqInitial + " pages (" + reqInitial * WASM_PAGE_SIZE + " bytes). Maximum allowed by compute.wasm.moduleMaxMemory: " + maxPages + " pages (" + maxPages * WASM_PAGE_SIZE + " bytes)");
        }
        // Never let memory.grow exceed the maximum the module itself declares.
        int effectiveMax = Math.min(maxPages, moduleRequest.maximumPages());
        return new ByteBufferMemory(new MemoryLimits(reqInitial, effectiveMax));
    }

    private static void initModule(Instance instance, WasmFunctionDef funcDef) {
        ExportFunction initFunc = ChicoryWasmComputeExecutor.getInitFunction(instance, funcDef);
        if (initFunc != null) {
            initFunc.apply(new long[0]);
        }
    }

    /** Returns the calling convention's init export, or {@code null} when the module does not export it. */
    @Nullable
    private static ExportFunction getInitFunction(Instance instance, WasmFunctionDef funcDef) {
        try {
            return instance.export(funcDef.callConv().init());
        }
        catch (InvalidException e) {
            LOG.debug("WASM module '{}' does not have init function '{}'", funcDef.file(), funcDef.callConv().init());
            return null;
        }
    }

    /** Converts a memory-size string (blank means "64m") into a validated WASM page count. */
    private static int getMemoryPages(@Nullable String maxMemory) {
        String effective = StringUtils.nullOrBlank(maxMemory) ? "64m" : maxMemory;
        long maxMemoryBytes = WasmConfigurationUtils.parseMemorySize(effective);
        return WasmConfigurationUtils.memorySizeToPagesValidated(maxMemoryBytes);
    }

    /** The compiler is enabled unless explicitly configured off. */
    private static boolean getEnableCompiler(@Nullable Boolean newValue) {
        return newValue == null || newValue;
    }

    /** A module instance bound to the job function, plus lazily resolved guest allocator exports. */
    private static class WasmInstance {
        private final Instance instance;
        private final ExportFunction func;
        private final WasmFunctionDef def;
        private ExportFunction alloc;
        private ExportFunction dealloc;

        private WasmInstance(Instance instance, WasmFunctionDef funcDef, ExportFunction func) {
            this.instance = instance;
            this.func = func;
            this.def = funcDef;
        }

        /** Allocates {@code size} bytes in guest memory via the calling convention's alloc export. */
        private int alloc(int size) {
            if (size <= 0) {
                throw WasmCommon.wasmError("Cannot allocate non-positive size: " + size);
            }
            if (this.alloc == null) {
                String name = this.def.callConv().alloc();
                this.alloc = this.requireExport(name, "WASM module must export '" + name + "' function for allocating memory");
            }
            // For wasm-bindgen the extra argument 1 is presumably the alignment (verify against RustWasmBindgenUtils).
            long[] args = this.isBindgen() ? new long[]{size, 1L} : new long[]{size};
            return (int)this.alloc.apply(args)[0];
        }

        /** Frees a guest buffer via the calling convention's dealloc export. */
        private void dealloc(int ptr, int len) {
            if (this.dealloc == null) {
                String name = this.def.callConv().dealloc();
                this.dealloc = this.requireExport(name, "WASM module must export '" + name + "' function for deallocating memory");
            }
            long[] args = this.isBindgen() ? new long[]{ptr, len, 1L} : new long[]{ptr, len};
            this.dealloc.apply(args);
        }

        /**
         * Looks up an export, converting a missing export into a descriptive error.
         * {@code Instance.export} throws {@link InvalidException} rather than returning null for missing exports,
         * so a plain null check alone would never fire.
         */
        private ExportFunction requireExport(String name, String missingMessage) {
            ExportFunction export;
            try {
                export = this.instance.export(name);
            }
            catch (InvalidException e) {
                throw WasmCommon.wasmError(missingMessage, e);
            }
            if (export == null) {
                throw WasmCommon.wasmError(missingMessage);
            }
            return export;
        }

        private boolean isBindgen() {
            return this.def.callConv() == WasmCallConv.RUST_WASM_BINDGEN;
        }
    }
}

