/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxRuntime;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxTensorLike;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtAllocator;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.TensorInfo;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;

public final class OnnxSparseTensor
extends OnnxTensorLike {
    private final SparseTensorType sparseTensorType;
    private final Buffer indices;
    private final LongBuffer innerIndices;
    private final Buffer values;

    OnnxSparseTensor(long l2, long l3, int n2, TensorInfo tensorInfo) {
        this(l2, l3, SparseTensorType.mapFromInt(n2), tensorInfo, null, null, null);
    }

    OnnxSparseTensor(long l2, long l3, SparseTensorType sparseTensorType, TensorInfo tensorInfo, Buffer buffer, Buffer buffer2) {
        this(l2, l3, sparseTensorType, tensorInfo, buffer, null, buffer2);
    }

    OnnxSparseTensor(long l2, long l3, SparseTensorType sparseTensorType, TensorInfo tensorInfo, Buffer buffer, LongBuffer longBuffer, Buffer buffer2) {
        super(l2, l3, tensorInfo);
        this.sparseTensorType = sparseTensorType;
        this.indices = buffer;
        this.innerIndices = longBuffer;
        this.values = buffer2;
    }

    public static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, SparseTensor<T> sparseTensor) throws OrtException {
        return OnnxSparseTensor.createSparseTensor(ortEnvironment, ortEnvironment.defaultAllocator, sparseTensor);
    }

    static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, SparseTensor<T> sparseTensor) throws OrtException {
        if (!ortAllocator.isClosed()) {
            TensorInfo tensorInfo = TensorInfo.constructFromSparseTensor(sparseTensor);
            OnnxJavaType onnxJavaType = sparseTensor.getIndicesType();
            OrtUtil.BufferTuple bufferTuple = OrtUtil.prepareBuffer(sparseTensor.getIndices(), onnxJavaType);
            OrtUtil.BufferTuple bufferTuple2 = OrtUtil.prepareBuffer(sparseTensor.getValues(), tensorInfo.type);
            if (!(bufferTuple.data instanceof LongBuffer) && !(bufferTuple.data instanceof IntBuffer)) {
                throw new IllegalStateException("Unexpected type of indices buffer, found " + bufferTuple.data.getClass() + ", expected IntBuffer or LongBuffer");
            }
            switch (sparseTensor.getSparsityType()) {
                case COO: 
                case BLOCK_SPARSE: {
                    return new OnnxSparseTensor(OnnxSparseTensor.createSparseTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, bufferTuple.data, bufferTuple.pos, bufferTuple.size, bufferTuple2.data, bufferTuple2.pos, tensorInfo.shape, sparseTensor.getIndicesShape(), sparseTensor.getValuesShape(), tensorInfo.onnxType.value, sparseTensor.getSparsityType().value), ortAllocator.handle, sparseTensor.getSparsityType(), tensorInfo, bufferTuple.data, bufferTuple2.data);
                }
                case CSRC: {
                    OrtUtil.BufferTuple bufferTuple3 = OrtUtil.prepareBuffer(((CSRCTensor)sparseTensor).getInnerIndices(), onnxJavaType);
                    return new OnnxSparseTensor(OnnxSparseTensor.createCSRCSparseTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, bufferTuple.data, bufferTuple.pos, bufferTuple.size, bufferTuple3.data, bufferTuple3.pos, bufferTuple3.size, bufferTuple2.data, bufferTuple2.pos, tensorInfo.shape, sparseTensor.getValuesShape(), tensorInfo.onnxType.value), ortAllocator.handle, sparseTensor.getSparsityType(), tensorInfo, bufferTuple.data, (LongBuffer)bufferTuple3.data, bufferTuple2.data);
                }
            }
            throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
        }
        throw new IllegalStateException("Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
    }

    @Override
    public OnnxValue.OnnxValueType getType() {
        return OnnxValue.OnnxValueType.ONNX_TYPE_SPARSETENSOR;
    }

    @Override
    public SparseTensor<? extends Buffer> getValue() throws OrtException {
        Buffer buffer = this.getValuesBuffer();
        long[] lArray = this.getIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
        switch (this.sparseTensorType) {
            case COO: {
                return new COOTensor((LongBuffer)this.getIndicesBuffer(), lArray, buffer, this.info.shape, this.info.type, buffer.remaining());
            }
            case CSRC: {
                return new CSRCTensor((LongBuffer)this.getIndicesBuffer(), this.getInnerIndicesBuffer(), buffer, this.info.shape, this.info.type, buffer.remaining());
            }
            case BLOCK_SPARSE: {
                long[] lArray2 = this.getValuesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
                return new BlockSparseTensor((IntBuffer)this.getIndicesBuffer(), lArray, buffer, lArray2, this.info.shape, this.info.type, (long)buffer.remaining());
            }
        }
        throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
    }

    @Override
    public void close() {
        this.close(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    public SparseTensorType getSparseTensorType() {
        return this.sparseTensorType;
    }

    public Buffer getIndicesBuffer() {
        switch (this.sparseTensorType) {
            case COO: 
            case CSRC: {
                LongBuffer longBuffer = this.getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asLongBuffer();
                LongBuffer longBuffer2 = LongBuffer.allocate(longBuffer.capacity());
                longBuffer2.put(longBuffer);
                longBuffer2.rewind();
                return longBuffer2;
            }
            case BLOCK_SPARSE: {
                IntBuffer intBuffer = this.getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asIntBuffer();
                IntBuffer intBuffer2 = IntBuffer.allocate(intBuffer.capacity());
                intBuffer2.put(intBuffer);
                intBuffer2.rewind();
                return intBuffer2;
            }
        }
        throw new IllegalStateException("UNDEFINED sparse tensor type.");
    }

    public LongBuffer getInnerIndicesBuffer() {
        if (this.sparseTensorType == SparseTensorType.CSRC) {
            LongBuffer longBuffer = this.getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asLongBuffer();
            LongBuffer longBuffer2 = LongBuffer.allocate(longBuffer.capacity());
            longBuffer2.put(longBuffer);
            longBuffer2.rewind();
            return longBuffer2;
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + (Object)((Object)this.sparseTensorType));
    }

    public Buffer getValuesBuffer() {
        ByteBuffer byteBuffer = this.getValuesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder());
        switch (this.info.type) {
            case FLOAT: {
                if (this.info.onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
                    ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
                    int n2 = shortBuffer.capacity();
                    FloatBuffer floatBuffer = FloatBuffer.allocate(n2);
                    for (int i2 = 0; i2 < n2; ++i2) {
                        floatBuffer.put(OnnxTensor.fp16ToFloat(shortBuffer.get(i2)));
                    }
                    floatBuffer.rewind();
                    return floatBuffer;
                }
                if (this.info.onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
                    throw new IllegalArgumentException("BFloat16 is not supported.");
                }
                FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
                FloatBuffer floatBuffer2 = FloatBuffer.allocate(floatBuffer.capacity());
                floatBuffer2.put(floatBuffer);
                floatBuffer2.rewind();
                return floatBuffer2;
            }
            case DOUBLE: {
                DoubleBuffer doubleBuffer = byteBuffer.asDoubleBuffer();
                DoubleBuffer doubleBuffer2 = DoubleBuffer.allocate(doubleBuffer.capacity());
                doubleBuffer2.put(doubleBuffer);
                doubleBuffer2.rewind();
                return doubleBuffer2;
            }
            case INT16: {
                ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
                ShortBuffer shortBuffer2 = ShortBuffer.allocate(shortBuffer.capacity());
                shortBuffer2.put(shortBuffer);
                shortBuffer2.rewind();
                return shortBuffer2;
            }
            case INT32: {
                IntBuffer intBuffer = byteBuffer.asIntBuffer();
                IntBuffer intBuffer2 = IntBuffer.allocate(intBuffer.capacity());
                intBuffer2.put(intBuffer);
                intBuffer2.rewind();
                return intBuffer2;
            }
            case INT64: {
                LongBuffer longBuffer = byteBuffer.asLongBuffer();
                LongBuffer longBuffer2 = LongBuffer.allocate(longBuffer.capacity());
                longBuffer2.put(longBuffer);
                longBuffer2.rewind();
                return longBuffer2;
            }
            case BOOL: 
            case INT8: 
            case UINT8: {
                ByteBuffer byteBuffer2 = ByteBuffer.allocate(byteBuffer.capacity());
                byteBuffer2.put(byteBuffer);
                byteBuffer2.rewind();
                return byteBuffer2;
            }
            case STRING: {
                throw new IllegalStateException("Unsupported data type String");
            }
        }
        throw new IllegalStateException("Unsupported data type");
    }

    public long[] getIndicesShape() {
        return this.getIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    public long[] getInnerIndicesShape() {
        if (this.sparseTensorType == SparseTensorType.CSRC) {
            return this.getInnerIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + (Object)((Object)this.sparseTensorType));
    }

    public long[] getValuesShape() {
        return this.getValuesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    private native long[] getIndicesShape(long var1, long var3);

    private native long[] getInnerIndicesShape(long var1, long var3);

    private native long[] getValuesShape(long var1, long var3);

    private native ByteBuffer getIndicesBuffer(long var1, long var3);

    private native ByteBuffer getInnerIndicesBuffer(long var1, long var3);

    private native ByteBuffer getValuesBuffer(long var1, long var3);

    private native void close(long var1, long var3);

    private static native long createCSRCSparseTensorFromBuffer(long var0, long var2, Buffer var4, int var5, long var6, Buffer var8, int var9, long var10, Buffer var12, int var13, long[] var14, long[] var15, int var16) throws OrtException;

    private static native long createSparseTensorFromBuffer(long var0, long var2, Buffer var4, int var5, long var6, Buffer var8, int var9, long[] var10, long[] var11, long[] var12, int var13, int var14) throws OrtException;

    public static final class BlockSparseTensor
    extends SparseTensor<IntBuffer> {
        public BlockSparseTensor(IntBuffer intBuffer, long[] lArray, Buffer buffer, long[] lArray2, long[] lArray3, OnnxJavaType onnxJavaType, long l2) {
            super(intBuffer, lArray, buffer, lArray2, lArray3, onnxJavaType, l2);
            if (OrtUtil.elementCount(lArray2) != l2) {
                throw new IllegalArgumentException("Expected " + l2 + " entries in the data shape, found " + Arrays.toString(lArray2));
            }
            if (l2 != (long)buffer.remaining()) {
                throw new IllegalArgumentException("Expected " + l2 + " elements in the data buffer, found " + buffer.remaining());
            }
            if (OrtUtil.elementCount(lArray) != (long)intBuffer.remaining()) {
                throw new IllegalArgumentException("Expected " + OrtUtil.elementCount(lArray) + " elements in the indices buffer, found " + intBuffer.remaining());
            }
            if (lArray2.length < 3) {
                throw new IllegalArgumentException("Expected [numBlocks, blockSize, blockSize] or larger, but data shape was " + Arrays.toString(lArray2));
            }
            if (lArray.length < 2) {
                throw new IllegalArgumentException("Expected [numBlocks, co-ordinates] or larger, but indices shape was " + Arrays.toString(lArray));
            }
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT32;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.BLOCK_SPARSE;
        }
    }

    public static final class CSRCTensor
    extends SparseTensor<LongBuffer> {
        private final LongBuffer innerIndices;

        public CSRCTensor(LongBuffer longBuffer, LongBuffer longBuffer2, Buffer buffer, long[] lArray, OnnxJavaType onnxJavaType, long l2) {
            super(longBuffer, new long[]{longBuffer.remaining()}, buffer, new long[]{l2}, lArray, onnxJavaType, l2);
            this.innerIndices = longBuffer2;
            long l3 = lArray[0] + 1L;
            if ((long)longBuffer.remaining() != l3) {
                throw new IllegalArgumentException("Outer indices should be equal to the number of rows + 1 in the dense shape, found " + longBuffer.remaining() + ", expected " + l3);
            }
            if ((long)longBuffer2.remaining() != l2) {
                throw new IllegalArgumentException("Inner indices should be equal to the number of non-zero elements, found " + longBuffer2.remaining() + ", expected " + l2);
            }
        }

        public long[] getInnerIndicesShape() {
            return new long[]{this.innerIndices.remaining()};
        }

        public LongBuffer getInnerIndices() {
            return this.innerIndices;
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.CSRC;
        }
    }

    public static final class COOTensor
    extends SparseTensor<LongBuffer> {
        public COOTensor(LongBuffer longBuffer, long[] lArray, Buffer buffer, long[] lArray2, OnnxJavaType onnxJavaType, long l2) {
            super(longBuffer, lArray, buffer, new long[]{l2}, lArray2, onnxJavaType, l2);
            if (lArray.length > 2 || lArray.length == 0 || lArray[0] != l2) {
                throw new IllegalArgumentException("Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found " + Arrays.toString(lArray));
            }
            long l3 = OrtUtil.elementCount(lArray);
            if (l3 != (long)longBuffer.remaining()) {
                throw new IllegalArgumentException("Unexpected number of indices found in buffer, expected " + l3 + " found " + longBuffer.remaining());
            }
            if ((long)buffer.remaining() != l2) {
                throw new IllegalArgumentException("Expected data.remaining() - " + buffer.remaining() + " to equal numNonZero - " + l2);
            }
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.COO;
        }
    }

    public static abstract class SparseTensor<T extends Buffer> {
        private final long[] indicesShape;
        private final long[] valuesShape;
        private final long[] denseShape;
        private final OnnxJavaType type;
        private final long numNonZero;
        final T indices;
        final Buffer values;

        SparseTensor(T t2, long[] lArray, Buffer buffer, long[] lArray2, long[] lArray3, OnnxJavaType onnxJavaType, long l2) {
            this.indices = t2;
            this.indicesShape = lArray;
            this.values = buffer;
            this.valuesShape = lArray2;
            this.denseShape = lArray3;
            this.type = onnxJavaType;
            this.numNonZero = l2;
            if ((long)buffer.remaining() != l2) {
                throw new IllegalArgumentException("Expected numNonZero and data.remaining to be equal, found " + l2 + " and " + buffer.remaining() + " respectively");
            }
            if (onnxJavaType == OnnxJavaType.STRING) {
                throw new IllegalArgumentException("String SparseTensors are not supported.");
            }
        }

        public long[] getDenseShape() {
            return this.denseShape;
        }

        public OnnxJavaType getType() {
            return this.type;
        }

        public long getNumNonZeroElements() {
            return this.numNonZero;
        }

        public T getIndices() {
            return this.indices;
        }

        public Buffer getValues() {
            return this.values;
        }

        public long[] getValuesShape() {
            return this.valuesShape;
        }

        public long[] getIndicesShape() {
            return this.indicesShape;
        }

        public abstract SparseTensorType getSparsityType();

        public abstract OnnxJavaType getIndicesType();
    }

    public static enum SparseTensorType {
        UNDEFINED(0),
        COO(1),
        CSRC(2),
        BLOCK_SPARSE(4);

        public final int value;
        private static final SparseTensorType[] values;

        private SparseTensorType(int n3) {
            this.value = n3;
        }

        public static SparseTensorType mapFromInt(int n2) {
            if (n2 > 0 && n2 < values.length) {
                return values[n2];
            }
            return UNDEFINED;
        }

        static {
            values = new SparseTensorType[5];
            SparseTensorType.values[0] = UNDEFINED;
            SparseTensorType.values[1] = COO;
            SparseTensorType.values[2] = CSRC;
            SparseTensorType.values[3] = UNDEFINED;
            SparseTensorType.values[4] = BLOCK_SPARSE;
        }
    }
}

