/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.svm.enterprise.hosted.onnx;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtSession;
import com.oracle.svm.enterprise.hosted.ml.inferencer.Inferencer;
import java.nio.file.Path;
import java.util.Collections;
import java.util.HashMap;

/**
 * {@link Inferencer} implementation backed by the ONNX Runtime Java API.
 *
 * <p>The native objects created for a single inference call (input tensors and the session
 * result) are always closed before the call returns, including when inference fails.
 */
public class ONNXInferencer
implements Inferencer {
    /** Environment wrapper handed out by {@link #getEnvironment()}; created with this inferencer. */
    private final Env env = new Env();

    /**
     * Returns the ONNX Runtime environment wrapper owned by this inferencer.
     *
     * @return the environment; callers pass it back into the {@code runInference} methods
     */
    public Inferencer.Env getEnvironment() {
        return this.env;
    }

    /**
     * Runs a single-input model on {@code features} and returns the model's first output.
     *
     * @param infEnv environment from {@link #getEnvironment()}; must be an {@link Env}
     * @param infSession session from {@link Env#getSession(Path)}; must be a {@link Session}
     * @param inputName name of the model input that {@code features} is bound to
     * @param features input rows converted into a 2-D float tensor
     * @return the first model output, cast to {@code float[][]}; with assertions enabled it is
     *     checked to have as many rows as {@code features}
     * @throws Inferencer.InferenceException if ONNX Runtime fails to create the tensor, run the
     *     session, or extract the output value
     */
    public float[][] runInference(Inferencer.Env infEnv, Inferencer.Session infSession, String inputName, float[][] features) throws Inferencer.InferenceException {
        OrtEnvironment ortEnv = ((Env)infEnv).env;
        OrtSession ortSession = ((Session)infSession).session;
        // try-with-resources releases the native tensor and result even when run(), getValue(),
        // the cast, or the assertion below throws.
        try (OnnxTensor tensor = OnnxTensor.createTensor(ortEnv, features);
             OrtSession.Result output = ortSession.run(Collections.singletonMap(inputName, tensor))) {
            // getValue() copies the native data into a Java array, so it stays valid after close.
            float[][] outputProbabilities = (float[][])output.get(0).getValue();
            assert (outputProbabilities.length == features.length) : "Fatal Error: can not infer the control split branches probabilities - invalid ONNX runtime lib.";
            return outputProbabilities;
        }
        catch (OrtException e) {
            throw new Inferencer.InferenceException("Error running ONNX inference", (Throwable)e);
        }
    }

    /**
     * Runs a graph model that takes the four named inputs {@code x}, {@code edge_index},
     * {@code is_control_split} and {@code branches_ord}, and returns its first output.
     *
     * @param infEnv environment from {@link #getEnvironment()}; must be an {@link Env}
     * @param infSession session from {@link Env#getSession(Path)}; must be a {@link Session}
     * @param x values bound to the {@code x} input
     * @param edgeIndex values bound to the {@code edge_index} input
     * @param isControlSplit values bound to the {@code is_control_split} input
     * @param branchesOrder values bound to the {@code branches_ord} input
     * @return the first model output, cast to {@code float[]}
     * @throws Inferencer.InferenceException if ONNX Runtime fails to create a tensor, run the
     *     session, or extract the output value
     */
    public float[] runInference(Inferencer.Env infEnv, Inferencer.Session infSession, float[][] x, long[][] edgeIndex, long[] isControlSplit, long[] branchesOrder) throws Inferencer.InferenceException {
        OrtEnvironment ortEnv = ((Env)infEnv).env;
        OrtSession ortSession = ((Session)infSession).session;
        // Each resource is closed even if a later tensor creation, run(), or getValue() fails.
        try (OnnxTensor xTensor = OnnxTensor.createTensor(ortEnv, x);
             OnnxTensor edgeIndexTensor = OnnxTensor.createTensor(ortEnv, edgeIndex);
             OnnxTensor isControlSplitTensor = OnnxTensor.createTensor(ortEnv, isControlSplit);
             OnnxTensor branchesOrderTensor = OnnxTensor.createTensor(ortEnv, branchesOrder)) {
            // Keys must match the input names declared by the ONNX model.
            HashMap<String, OnnxTensor> input = new HashMap<>();
            input.put("x", xTensor);
            input.put("edge_index", edgeIndexTensor);
            input.put("is_control_split", isControlSplitTensor);
            input.put("branches_ord", branchesOrderTensor);
            try (OrtSession.Result output = ortSession.run(input)) {
                return (float[])output.get(0).getValue();
            }
        }
        catch (OrtException e) {
            throw new Inferencer.InferenceException("Error running ONNX inference", (Throwable)e);
        }
    }

    /** {@link Inferencer.Env} wrapper around the ONNX Runtime {@link OrtEnvironment}. */
    static class Env
    implements Inferencer.Env {
        /** Environment obtained from {@link OrtEnvironment#getEnvironment()}. */
        final OrtEnvironment env = OrtEnvironment.getEnvironment();

        Env() {
        }

        /**
         * Creates an inference session for the ONNX model at {@code path}.
         *
         * <p>The session is configured with all graph optimizations, no CPU arena allocator,
         * one inter-op thread per available processor, a single intra-op thread, and
         * fatal-only session logging.
         *
         * @param path location of the ONNX model file
         * @return a {@link Session} wrapping the created {@link OrtSession}
         * @throws Inferencer.InferenceException if ONNX Runtime rejects an option or cannot
         *     create the session
         */
        public Inferencer.Session getSession(Path path) throws Inferencer.InferenceException {
            // NOTE(review): opts is never closed here; confirm with the ONNX Runtime version in use
            // whether it can be closed once the session has been created.
            OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            try {
                opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
                opts.setCPUArenaAllocator(false);
                opts.setInterOpNumThreads(Runtime.getRuntime().availableProcessors());
                opts.setIntraOpNumThreads(1);
                opts.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
                return new Session(this.env.createSession(String.valueOf(path), opts));
            }
            catch (OrtException e) {
                throw new Inferencer.InferenceException("Error initializing ONNX inference session", (Throwable)e);
            }
        }

        /** Closes the wrapped {@link OrtEnvironment}. */
        public void close() {
            this.env.close();
        }
    }

    /** {@link Inferencer.Session} wrapper around an ONNX Runtime {@link OrtSession}. */
    static class Session
    implements Inferencer.Session {
        /** The wrapped ONNX Runtime session. */
        final OrtSession session;

        Session(OrtSession session) {
            this.session = session;
        }

        /**
         * Returns the first name produced by iterating the session's input names.
         *
         * @return an input name of the model
         * @throws java.util.NoSuchElementException if the model declares no inputs
         */
        public String getInputName() {
            return this.session.getInputNames().iterator().next();
        }

        /**
         * Closes the wrapped {@link OrtSession}.
         *
         * @throws Inferencer.InferenceException if ONNX Runtime fails to close the session
         */
        public void close() throws Inferencer.InferenceException {
            try {
                this.session.close();
            }
            catch (OrtException e) {
                throw new Inferencer.InferenceException("Error closing ONNX session", (Throwable)e);
            }
        }
    }
}

