/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.ValueInfo;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;

/**
 * Describes a tensor value: its shape, the Java type used to carry its elements, and the
 * corresponding ONNX element type.
 */
public class TensorInfo implements ValueInfo {
    /** Maximum number of dimensions accepted by {@link #constructFromJavaArray(Object)}. */
    public static final int MAX_DIMENSIONS = 8;

    /** Dimension sizes. {@link #getShape()} hands out copies of this array. */
    final long[] shape;

    /** The Java type used to carry this tensor's elements. */
    public final OnnxJavaType type;

    /** The ONNX element type of this tensor. */
    public final OnnxTensorType onnxType;

    /**
     * Creates a TensorInfo. The shape array is stored as-is (not copied), so callers must
     * not mutate it afterwards.
     *
     * @param shape the dimension sizes (empty for a scalar)
     * @param type the Java carrier type
     * @param onnxType the ONNX element type
     */
    TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
        this.shape = shape;
        this.type = type;
        this.onnxType = onnxType;
    }

    /**
     * Returns a copy of this tensor's shape, so callers cannot mutate the stored dimensions.
     *
     * @return a fresh array of dimension sizes; empty for a scalar
     */
    public long[] getShape() {
        return Arrays.copyOf(this.shape, this.shape.length);
    }

    @Override
    public String toString() {
        return "TensorInfo(javaType=" + this.type.toString()
                + ",onnxType=" + this.onnxType.toString()
                + ",shape=" + Arrays.toString(this.shape) + ")";
    }

    /**
     * Returns true if this tensor has no dimensions.
     *
     * @return whether the shape is empty
     */
    public boolean isScalar() {
        return this.shape.length == 0;
    }

    /** Delegates to {@link OrtUtil#validateShape(long[])} for this tensor's shape. */
    private boolean validateShape() {
        return OrtUtil.validateShape(this.shape);
    }

    /**
     * Allocates an empty Java array suitable for holding this tensor's values.
     *
     * <p>Numeric and boolean types use the matching {@code OrtUtil.new*Array(shape)} factory.
     * Both INT8 and UINT8 are carried in byte arrays. STRING uses a flat {@code String[]}
     * with {@code OrtUtil.elementCount(shape)} entries, whatever the rank.
     *
     * @return the freshly allocated carrier array
     * @throws OrtException if the shape fails {@code OrtUtil.validateShape}, or the type is
     *     UNKNOWN or otherwise unsupported
     */
    public Object makeCarrier() throws OrtException {
        if (!this.validateShape()) {
            throw new OrtException("This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(this.shape));
        }
        switch (this.type) {
            case FLOAT:
                return OrtUtil.newFloatArray(this.shape);
            case DOUBLE:
                return OrtUtil.newDoubleArray(this.shape);
            case INT8:
            case UINT8:
                // Java has no unsigned byte type, so both 8-bit types share a byte carrier.
                return OrtUtil.newByteArray(this.shape);
            case INT16:
                return OrtUtil.newShortArray(this.shape);
            case INT32:
                return OrtUtil.newIntArray(this.shape);
            case INT64:
                return OrtUtil.newLongArray(this.shape);
            case BOOL:
                return OrtUtil.newBooleanArray(this.shape);
            case STRING:
                // NOTE(review): assumes validateShape bounds the element count so the int cast
                // cannot truncate — confirm in OrtUtil.
                return new String[(int) OrtUtil.elementCount(this.shape)];
            case UNKNOWN:
                throw new OrtException("Can't construct a carrier for an invalid type.");
            default:
                throw new OrtException("Unsupported type - " + this.type);
        }
    }

    /**
     * Builds a TensorInfo describing a Java scalar or (possibly multidimensional) array.
     *
     * @param object a boxed/primitive-wrapper scalar, a String, or a rectangular array whose
     *     base component type is primitive or String
     * @return the inferred shape and element types
     * @throws OrtException if the type cannot be mapped, the array has more than
     *     {@link #MAX_DIMENSIONS} dimensions, or the array is ragged, has a zero-length
     *     dimension, or contains a null sub-array
     * @throws NullPointerException if {@code object} is null
     */
    public static TensorInfo constructFromJavaArray(Object object) throws OrtException {
        Class<?> clazz = object.getClass();
        if (!clazz.isArray()) {
            OnnxJavaType scalarType = OnnxJavaType.mapFromClass(clazz);
            if (scalarType == OnnxJavaType.UNKNOWN) {
                throw new OrtException("Cannot convert " + clazz + " to a OnnxTensor.");
            }
            return new TensorInfo(new long[0], scalarType, OnnxTensorType.mapFromJavaType(scalarType));
        }
        // Peel off array levels to find the base component type and the rank.
        int dimensions = 0;
        Class<?> baseClass = clazz;
        while (baseClass.isArray()) {
            baseClass = baseClass.getComponentType();
            ++dimensions;
        }
        if (!baseClass.isPrimitive() && !baseClass.equals(String.class)) {
            throw new OrtException("Cannot create an OnnxTensor from a base type of " + baseClass);
        }
        if (dimensions > MAX_DIMENSIONS) {
            throw new OrtException("Cannot create an OnnxTensor with more than " + MAX_DIMENSIONS + " dimensions. Found " + dimensions + " dimensions.");
        }
        OnnxJavaType javaType = OnnxJavaType.mapFromClass(baseClass);
        if (javaType == OnnxJavaType.UNKNOWN) {
            // A primitive that mapFromClass does not recognise (presumably char). Reject it here,
            // as the scalar path does, instead of returning an UNKNOWN/UNDEFINED TensorInfo.
            throw new OrtException("Cannot create an OnnxTensor from a base type of " + baseClass);
        }
        long[] shape = new long[dimensions];
        extractShape(shape, 0, object);
        return new TensorInfo(shape, javaType, OnnxTensorType.mapFromJavaType(javaType));
    }

    /**
     * Builds a TensorInfo for a buffer whose remaining elements must exactly fill {@code shape}.
     *
     * @param buffer the data buffer; only {@code remaining()} is inspected
     * @param shape the requested shape; copied defensively
     * @param type the Java element type of the buffer
     * @return the tensor description
     * @throws OrtException if {@code type} is STRING or UNKNOWN, or the element count of
     *     {@code shape} differs from {@code buffer.remaining()}
     */
    public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJavaType type) throws OrtException {
        if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) {
            throw new OrtException("Cannot create a tensor from a string or unknown buffer.");
        }
        long expected = OrtUtil.elementCount(shape);
        long available = buffer.remaining();
        if (expected != available) {
            throw new OrtException("Shape " + Arrays.toString(shape) + ", requires " + expected + " elements but the buffer has " + available + " elements.");
        }
        return new TensorInfo(Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type));
    }

    /**
     * Recursively records the length of each array level into {@code shape}, checking that
     * the array is rectangular.
     *
     * <p>A zero in {@code shape[depth]} means "not yet seen". This is unambiguous because
     * zero-length dimensions are rejected.
     *
     * @param shape output dimension sizes, one slot per array level, initially all zero
     * @param depth the array level {@code array} sits at
     * @param array the (sub-)array at this level
     * @throws OrtException on a null sub-array, a zero-length dimension, or a ragged array
     */
    private static void extractShape(long[] shape, int depth, Object array) throws OrtException {
        if (depth == shape.length) {
            return;
        }
        if (array == null) {
            throw new OrtException("Supplied array has a null sub-array at dimension " + depth);
        }
        int length = Array.getLength(array);
        if (length == 0) {
            throw new OrtException("Supplied array has a zero dimension at " + depth + ", all dimensions must be positive");
        }
        if (shape[depth] == 0L) {
            shape[depth] = length;
        } else if (shape[depth] != (long) length) {
            throw new OrtException("Supplied array is ragged, expected " + shape[depth] + ", found " + length);
        }
        // Elements of the innermost level are scalars. Recursing into them would only box each
        // one through Array.get and return immediately, so stop one level early.
        if (depth + 1 < shape.length) {
            for (int i = 0; i < length; ++i) {
                extractShape(shape, depth + 1, Array.get(array, i));
            }
        }
    }

    /**
     * ONNX tensor element types and their numeric codes.
     *
     * <p>NOTE(review): the codes appear to mirror ONNX Runtime's native
     * {@code ONNXTensorElementDataType} enum — keep them in sync with it.
     */
    public static enum OnnxTensorType {
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED(0),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8(1),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8(2),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16(3),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16(4),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32(5),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32(6),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64(7),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64(8),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16(9),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT(10),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE(11),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING(12),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL(13),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64(14),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128(15),
        ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16(16);

        /** Numeric code of this element type. */
        public final int value;

        /** Lookup table indexed by {@link #value}; slots with no constant are null. */
        private static final OnnxTensorType[] BY_VALUE;

        OnnxTensorType(int value) {
            this.value = value;
        }

        /**
         * Returns the constant whose {@link #value} equals {@code code}.
         *
         * @param code a numeric element-type code
         * @return the matching constant, or UNDEFINED if no constant has that code
         */
        public static OnnxTensorType mapFromInt(int code) {
            if (code >= 0 && code < BY_VALUE.length && BY_VALUE[code] != null) {
                return BY_VALUE[code];
            }
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
        }

        /**
         * Maps a Java carrier type to its ONNX element type.
         *
         * @param onnxJavaType the Java type
         * @return the ONNX element type, or UNDEFINED for types with no mapping (e.g. UNKNOWN)
         */
        public static OnnxTensorType mapFromJavaType(OnnxJavaType onnxJavaType) {
            switch (onnxJavaType) {
                case FLOAT:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
                case DOUBLE:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
                case INT8:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
                case UINT8:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
                case INT16:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
                case INT32:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
                case INT64:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
                case BOOL:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
                case STRING:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
                default:
                    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
            }
        }

        static {
            // Size the table from the largest code, so adding a constant cannot overflow it.
            OnnxTensorType[] all = OnnxTensorType.values();
            int maxValue = 0;
            for (OnnxTensorType t : all) {
                maxValue = Math.max(maxValue, t.value);
            }
            BY_VALUE = new OnnxTensorType[maxValue + 1];
            for (OnnxTensorType t : all) {
                BY_VALUE[t.value] = t;
            }
        }
    }
}

