/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxRuntime;
import ai.onnxruntime.OnnxTensorLike;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtAllocator;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
import java.util.logging.Logger;

public final class OnnxSparseTensor
extends OnnxTensorLike {
    private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName());
    private final SparseTensorType sparseTensorType;
    private final Buffer indices;
    private final LongBuffer innerIndices;
    private final Buffer values;

    OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparseType, TensorInfo info) {
        this(nativeHandle, allocatorHandle, SparseTensorType.mapFromInt(sparseType), info, null, null, null);
    }

    OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType sparseType, TensorInfo info, Buffer indices, Buffer values) {
        this(nativeHandle, allocatorHandle, sparseType, info, indices, null, values);
    }

    OnnxSparseTensor(long nativeHandle, long allocatorHandle, SparseTensorType sparseType, TensorInfo info, Buffer indices, LongBuffer innerIndices, Buffer values) {
        super(nativeHandle, allocatorHandle, info);
        this.sparseTensorType = sparseType;
        this.indices = indices;
        this.innerIndices = innerIndices;
        this.values = values;
    }

    public static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment env, SparseTensor<T> tensor) throws OrtException {
        return OnnxSparseTensor.createSparseTensor(env, env.defaultAllocator, tensor);
    }

    static <T extends Buffer> OnnxSparseTensor createSparseTensor(OrtEnvironment env, OrtAllocator allocator, SparseTensor<T> tensor) throws OrtException {
        if (!allocator.isClosed()) {
            TensorInfo info = TensorInfo.constructFromSparseTensor(tensor);
            OnnxJavaType indicesType = tensor.getIndicesType();
            OrtUtil.BufferTuple indicesTuple = OrtUtil.prepareBuffer(tensor.getIndices(), indicesType);
            OrtUtil.BufferTuple valuesTuple = OrtUtil.prepareBuffer(tensor.getValues(), info.type);
            if (!(indicesTuple.data instanceof LongBuffer) && !(indicesTuple.data instanceof IntBuffer)) {
                throw new IllegalStateException("Unexpected type of indices buffer, found " + indicesTuple.data.getClass() + ", expected IntBuffer or LongBuffer");
            }
            switch (tensor.getSparsityType().ordinal()) {
                case 1: 
                case 3: {
                    return new OnnxSparseTensor(OnnxSparseTensor.createSparseTensorFromBuffer(OnnxRuntime.ortApiHandle, allocator.handle, indicesTuple.data, indicesTuple.pos, indicesTuple.size, valuesTuple.data, valuesTuple.pos, info.shape, tensor.getIndicesShape(), tensor.getValuesShape(), info.onnxType.value, tensor.getSparsityType().value), allocator.handle, tensor.getSparsityType(), info, indicesTuple.data, valuesTuple.data);
                }
                case 2: {
                    OrtUtil.BufferTuple innerIndicesTuple = OrtUtil.prepareBuffer(((CSRCTensor)tensor).getInnerIndices(), indicesType);
                    return new OnnxSparseTensor(OnnxSparseTensor.createCSRCSparseTensorFromBuffer(OnnxRuntime.ortApiHandle, allocator.handle, indicesTuple.data, indicesTuple.pos, indicesTuple.size, innerIndicesTuple.data, innerIndicesTuple.pos, innerIndicesTuple.size, valuesTuple.data, valuesTuple.pos, info.shape, tensor.getValuesShape(), info.onnxType.value), allocator.handle, tensor.getSparsityType(), info, indicesTuple.data, (LongBuffer)innerIndicesTuple.data, valuesTuple.data);
                }
            }
            throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor.");
        }
        throw new IllegalStateException("Trying to create an OnnxSparseTensor on a closed OrtAllocator.");
    }

    @Override
    public OnnxValue.OnnxValueType getType() {
        return OnnxValue.OnnxValueType.ONNX_TYPE_SPARSETENSOR;
    }

    @Override
    public SparseTensor<? extends Buffer> getValue() throws OrtException {
        this.checkClosed();
        Buffer buffer = this.getValuesBuffer();
        long[] indicesShape = this.getIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
        switch (this.sparseTensorType.ordinal()) {
            case 1: {
                return new COOTensor((LongBuffer)this.getIndicesBuffer(), indicesShape, buffer, this.info.shape, this.info.type, buffer.remaining());
            }
            case 2: {
                return new CSRCTensor((LongBuffer)this.getIndicesBuffer(), this.getInnerIndicesBuffer(), buffer, this.info.shape, this.info.type, buffer.remaining());
            }
            case 3: {
                long[] valuesShape = this.getValuesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
                return new BlockSparseTensor((IntBuffer)this.getIndicesBuffer(), indicesShape, buffer, valuesShape, this.info.shape, this.info.type, (long)buffer.remaining());
            }
        }
        throw new IllegalStateException("Undefined sparsity type in this sparse tensor.");
    }

    @Override
    public synchronized void close() {
        if (!this.closed) {
            this.close(OnnxRuntime.ortApiHandle, this.nativeHandle);
            this.closed = true;
        } else {
            logger.warning("Closing an already closed OnnxSparseTensor.");
        }
    }

    public SparseTensorType getSparseTensorType() {
        return this.sparseTensorType;
    }

    public Buffer getIndicesBuffer() {
        this.checkClosed();
        switch (this.sparseTensorType.ordinal()) {
            case 1: 
            case 2: {
                LongBuffer longBuf = this.getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asLongBuffer();
                LongBuffer output = LongBuffer.allocate(longBuf.capacity());
                output.put(longBuf);
                output.rewind();
                return output;
            }
            case 3: {
                IntBuffer intBuf = this.getIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asIntBuffer();
                IntBuffer output = IntBuffer.allocate(intBuf.capacity());
                output.put(intBuf);
                output.rewind();
                return output;
            }
        }
        throw new IllegalStateException("UNDEFINED sparse tensor type.");
    }

    public LongBuffer getInnerIndicesBuffer() {
        this.checkClosed();
        if (this.sparseTensorType == SparseTensorType.CSRC) {
            LongBuffer buf = this.getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder()).asLongBuffer();
            LongBuffer output = LongBuffer.allocate(buf.capacity());
            output.put(buf);
            output.rewind();
            return output;
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + (Object)((Object)this.sparseTensorType));
    }

    public Buffer getValuesBuffer() {
        this.checkClosed();
        ByteBuffer buffer = this.getValuesBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder());
        switch (this.info.type) {
            case FLOAT: {
                FloatBuffer floatBuf = buffer.asFloatBuffer();
                FloatBuffer output = FloatBuffer.allocate(floatBuf.capacity());
                output.put(floatBuf);
                output.rewind();
                return output;
            }
            case FLOAT16: {
                ShortBuffer shortBuffer = buffer.asShortBuffer();
                return Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuffer);
            }
            case BFLOAT16: {
                ShortBuffer shortBuffer = buffer.asShortBuffer();
                return Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuffer);
            }
            case DOUBLE: {
                DoubleBuffer doubleBuf = buffer.asDoubleBuffer();
                DoubleBuffer output = DoubleBuffer.allocate(doubleBuf.capacity());
                output.put(doubleBuf);
                output.rewind();
                return output;
            }
            case INT16: {
                ShortBuffer shortBuf = buffer.asShortBuffer();
                ShortBuffer output = ShortBuffer.allocate(shortBuf.capacity());
                output.put(shortBuf);
                output.rewind();
                return output;
            }
            case INT32: {
                IntBuffer intBuf = buffer.asIntBuffer();
                IntBuffer output = IntBuffer.allocate(intBuf.capacity());
                output.put(intBuf);
                output.rewind();
                return output;
            }
            case INT64: {
                LongBuffer longBuf = buffer.asLongBuffer();
                LongBuffer output = LongBuffer.allocate(longBuf.capacity());
                output.put(longBuf);
                output.rewind();
                return output;
            }
            case BOOL: 
            case INT8: 
            case UINT8: {
                ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
                output.put(buffer);
                output.rewind();
                return output;
            }
            case STRING: {
                throw new IllegalStateException("Unsupported data type String");
            }
        }
        throw new IllegalStateException("Unsupported data type");
    }

    public long[] getIndicesShape() {
        this.checkClosed();
        return this.getIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    public long[] getInnerIndicesShape() {
        this.checkClosed();
        if (this.sparseTensorType == SparseTensorType.CSRC) {
            return this.getInnerIndicesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
        }
        throw new IllegalStateException("Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + (Object)((Object)this.sparseTensorType));
    }

    public long[] getValuesShape() {
        this.checkClosed();
        return this.getValuesShape(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    private native long[] getIndicesShape(long var1, long var3);

    private native long[] getInnerIndicesShape(long var1, long var3);

    private native long[] getValuesShape(long var1, long var3);

    private native ByteBuffer getIndicesBuffer(long var1, long var3);

    private native ByteBuffer getInnerIndicesBuffer(long var1, long var3);

    private native ByteBuffer getValuesBuffer(long var1, long var3);

    private native void close(long var1, long var3);

    private static native long createCSRCSparseTensorFromBuffer(long var0, long var2, Buffer var4, int var5, long var6, Buffer var8, int var9, long var10, Buffer var12, int var13, long[] var14, long[] var15, int var16) throws OrtException;

    private static native long createSparseTensorFromBuffer(long var0, long var2, Buffer var4, int var5, long var6, Buffer var8, int var9, long[] var10, long[] var11, long[] var12, int var13, int var14) throws OrtException;

    public static enum SparseTensorType {
        UNDEFINED(0),
        COO(1),
        CSRC(2),
        BLOCK_SPARSE(4);

        public final int value;
        private static final SparseTensorType[] values;

        private SparseTensorType(int value) {
            this.value = value;
        }

        public static SparseTensorType mapFromInt(int value) {
            if (value > 0 && value < values.length) {
                return values[value];
            }
            return UNDEFINED;
        }

        static {
            values = new SparseTensorType[5];
            SparseTensorType.values[0] = UNDEFINED;
            SparseTensorType.values[1] = COO;
            SparseTensorType.values[2] = CSRC;
            SparseTensorType.values[3] = UNDEFINED;
            SparseTensorType.values[4] = BLOCK_SPARSE;
        }
    }

    public static abstract class SparseTensor<T extends Buffer> {
        private final long[] indicesShape;
        private final long[] valuesShape;
        private final long[] denseShape;
        private final OnnxJavaType type;
        private final long numNonZero;
        final T indices;
        final Buffer values;

        SparseTensor(T indices, long[] indicesShape, Buffer values, long[] valuesShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            this.indices = indices;
            this.indicesShape = indicesShape;
            this.values = values;
            this.valuesShape = valuesShape;
            this.denseShape = denseShape;
            this.type = type;
            this.numNonZero = numNonZero;
            if ((long)values.remaining() != numNonZero) {
                throw new IllegalArgumentException("Expected numNonZero and data.remaining to be equal, found " + numNonZero + " and " + values.remaining() + " respectively");
            }
            if (type == OnnxJavaType.STRING) {
                throw new IllegalArgumentException("String SparseTensors are not supported.");
            }
        }

        public long[] getDenseShape() {
            return this.denseShape;
        }

        public OnnxJavaType getType() {
            return this.type;
        }

        public long getNumNonZeroElements() {
            return this.numNonZero;
        }

        public T getIndices() {
            return this.indices;
        }

        public Buffer getValues() {
            return this.values;
        }

        public long[] getValuesShape() {
            return this.valuesShape;
        }

        public long[] getIndicesShape() {
            return this.indicesShape;
        }

        public abstract SparseTensorType getSparsityType();

        public abstract OnnxJavaType getIndicesType();
    }

    public static final class CSRCTensor
    extends SparseTensor<LongBuffer> {
        private final LongBuffer innerIndices;

        public CSRCTensor(LongBuffer outerIndices, LongBuffer innerIndices, Buffer values, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(outerIndices, new long[]{outerIndices.remaining()}, values, new long[]{numNonZero}, denseShape, type, numNonZero);
            this.innerIndices = innerIndices;
            long expectedRows = denseShape[0] + 1L;
            if ((long)outerIndices.remaining() != expectedRows) {
                throw new IllegalArgumentException("Outer indices should be equal to the number of rows + 1 in the dense shape, found " + outerIndices.remaining() + ", expected " + expectedRows);
            }
            if ((long)innerIndices.remaining() != numNonZero) {
                throw new IllegalArgumentException("Inner indices should be equal to the number of non-zero elements, found " + innerIndices.remaining() + ", expected " + numNonZero);
            }
        }

        public long[] getInnerIndicesShape() {
            return new long[]{this.innerIndices.remaining()};
        }

        public LongBuffer getInnerIndices() {
            return this.innerIndices;
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.CSRC;
        }
    }

    public static final class COOTensor
    extends SparseTensor<LongBuffer> {
        public COOTensor(LongBuffer indices, long[] indicesShape, Buffer values, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, values, new long[]{numNonZero}, denseShape, type, numNonZero);
            if (indicesShape.length > 2 || indicesShape.length == 0 || indicesShape[0] != numNonZero) {
                throw new IllegalArgumentException("Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found " + Arrays.toString(indicesShape));
            }
            long elementCount = OrtUtil.elementCount(indicesShape);
            if (elementCount != (long)indices.remaining()) {
                throw new IllegalArgumentException("Unexpected number of indices found in buffer, expected " + elementCount + " found " + indices.remaining());
            }
            if ((long)values.remaining() != numNonZero) {
                throw new IllegalArgumentException("Expected data.remaining() - " + values.remaining() + " to equal numNonZero - " + numNonZero);
            }
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT64;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.COO;
        }
    }

    public static final class BlockSparseTensor
    extends SparseTensor<IntBuffer> {
        public BlockSparseTensor(IntBuffer indices, long[] indicesShape, Buffer values, long[] valuesShape, long[] denseShape, OnnxJavaType type, long numNonZero) {
            super(indices, indicesShape, values, valuesShape, denseShape, type, numNonZero);
            if (OrtUtil.elementCount(valuesShape) != numNonZero) {
                throw new IllegalArgumentException("Expected " + numNonZero + " entries in the data shape, found " + Arrays.toString(valuesShape));
            }
            if (numNonZero != (long)values.remaining()) {
                throw new IllegalArgumentException("Expected " + numNonZero + " elements in the data buffer, found " + values.remaining());
            }
            if (OrtUtil.elementCount(indicesShape) != (long)indices.remaining()) {
                throw new IllegalArgumentException("Expected " + OrtUtil.elementCount(indicesShape) + " elements in the indices buffer, found " + indices.remaining());
            }
            if (valuesShape.length < 3) {
                throw new IllegalArgumentException("Expected [numBlocks, blockSize, blockSize] or larger, but data shape was " + Arrays.toString(valuesShape));
            }
            if (indicesShape.length < 2) {
                throw new IllegalArgumentException("Expected [numBlocks, co-ordinates] or larger, but indices shape was " + Arrays.toString(indicesShape));
            }
        }

        @Override
        public OnnxJavaType getIndicesType() {
            return OnnxJavaType.INT32;
        }

        @Override
        public SparseTensorType getSparsityType() {
            return SparseTensorType.BLOCK_SPARSE;
        }
    }
}

