/*
 * Decompiled with CFR 0.152.
 */
package oracle.net.nt;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.Properties;
import java.util.concurrent.ThreadLocalRandom;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import oracle.jdbc.SecurityInformation;
import oracle.jdbc.diagnostics.Diagnosable;
import oracle.jdbc.internal.Monitor;
import oracle.jdbc.internal.OpaqueString;
import oracle.jdbc.logging.annotations.Blind;
import oracle.jdbc.logging.annotations.PropertiesBlinder;
import oracle.net.nt.SSLSocketChannel;
import oracle.net.nt.SocketChannelWrapper;

public class WSSSocketChannel
extends SocketChannelWrapper
implements Monitor {
    // Frame opcode values, each named after the WebSocket frame type it identifies.
    private static final byte WS_OPCODE_CONTINUE = 0;
    private static final byte WS_OPCODE_TEXTDATA = 1;
    private static final byte WS_OPCODE_BINARYDATA = 2;
    private static final byte WS_OPCODE_CLOSE = 8;
    private static final byte WS_OPCODE_PING = 9;
    private static final byte WS_OPCODE_PONG = 10;
    // Size in bytes associated with the HTTP upgrade response buffer (the handshake helper allocates 1024 bytes).
    private static final int HANDSHAKE_RESPONSE_BUFFER_SIZE = 1024;
    // 0x0F: the low four bits of a frame's first byte, which WSHeader.read extracts as the opcode.
    private static final byte MASK_BYTE_OPCODE = 15;
    // -127 is 0x81 as a signed byte. NOTE(review): not referenced anywhere in this file; the FIN bit
    // is tested with 0x80 in WSHeader, so confirm the intended value before using this constant.
    private static final byte MASK_BYTE_FIN = -127;
    // All-zero masking key placed in the header of every frame this class writes; XOR with zero leaves
    // the payload bytes unchanged. NOTE(review): the array is public and mutable, so any code that
    // modifies it changes the key used by all instances.
    public static final byte[] WS_DUMMY_MASK_KEY = new byte[]{0, 0, 0, 0};
    // Holds the payload of the most recently read frame; read() drains it into the caller's buffer.
    private ByteBuffer payloadBuffer;
    // Set to true by disconnect() once the close exchange has been attempted.
    private boolean isClosed = false;
    // Lock instance returned by getMonitorLock().
    private final Monitor.CloseableLock monitorLock = Monitor.newDefaultLock();

    public WSSSocketChannel(SocketChannel socketChannel, String uri, String server, int port, String authUser, OpaqueString authPwd, Diagnosable diagnosable) throws IOException {
        super(socketChannel, diagnosable);
        this.payloadBuffer = ByteBuffer.allocate(this.bufferSize);
        this.payloadBuffer.limit(0);
        this.doWSHandShake(uri, server, port, authUser, authPwd);
    }

    /**
     * Copies buffered frame payload into {@code dstBuffer}, reading one more frame from the
     * socket first when no payload bytes are left over from a previous call.
     *
     * @param dstBuffer destination for the payload bytes
     * @return the number of bytes transferred (zero when either buffer has no room or data)
     * @throws IOException if reading the next frame fails or the peer closed the connection
     */
    @Override
    public int read(ByteBuffer dstBuffer) throws IOException {
        if (!this.payloadBuffer.hasRemaining()) {
            this.readFromSocket();
        }
        int transferCount = Math.min(this.payloadBuffer.remaining(), dstBuffer.remaining());
        if (transferCount > 0) {
            // Transfer through a bounded view so that only transferCount bytes are moved.
            ByteBuffer window = this.payloadBuffer.duplicate();
            window.limit(window.position() + transferCount);
            dstBuffer.put(window);
            this.payloadBuffer.position(window.position());
        }
        return transferCount;
    }

    /**
     * Sends the remaining bytes of {@code srcBuffer} as a single binary frame.
     *
     * @param srcBuffer data to send; nothing is written when it has no remaining bytes
     * @return the number of bytes that were remaining in {@code srcBuffer} on entry
     * @throws IOException if writing the frame fails
     */
    @Override
    public int write(ByteBuffer srcBuffer) throws IOException {
        final int pendingBytes = srcBuffer.remaining();
        if (pendingBytes <= 0) {
            return pendingBytes;
        }
        WSBinaryFrame dataFrame = new WSBinaryFrame(srcBuffer);
        WSFrame.writeFrame(this.socketChannel, dataFrame);
        return pendingBytes;
    }

    /**
     * Writes a ping frame carrying {@code payload} to the underlying channel.
     *
     * @param payload bytes to include in the ping; may be {@code null} for an empty ping
     * @throws IOException if writing the frame fails
     */
    public void sendPing(ByteBuffer payload) throws IOException {
        WSFrame.writeFrame(this.socketChannel, new WSPingFrame(payload, this.socketChannel));
    }

    /**
     * Changes the buffer size and re-allocates the payload buffer, carrying over any payload
     * bytes that have not been handed to a caller of {@link #read(ByteBuffer)} yet.
     *
     * <p>When {@code newBufferSize} is smaller than the number of unread bytes, the new buffer is
     * allocated just large enough to hold them, so shrinking never loses data or throws
     * {@code BufferOverflowException}.
     *
     * @param newBufferSize the requested buffer size in bytes
     * @throws IllegalArgumentException if {@code newBufferSize} is negative
     */
    @Override
    public void setBufferSize(int newBufferSize) {
        if (this.bufferSize == newBufferSize) {
            return;
        }
        if (newBufferSize < 0) {
            // Reject before touching any state so a bad argument cannot leave bufferSize corrupted.
            throw new IllegalArgumentException("Negative buffer size : " + newBufferSize);
        }
        int unreadBytes = this.payloadBuffer.remaining();
        ByteBuffer newPayloadBuffer = ByteBuffer.allocate(Math.max(newBufferSize, unreadBytes));
        newPayloadBuffer.put(this.payloadBuffer);
        newPayloadBuffer.flip();
        this.bufferSize = newBufferSize;
        this.payloadBuffer = newPayloadBuffer;
    }

    /**
     * Sends the HTTP upgrade request on the underlying channel and validates the response.
     *
     * @param uri      request path for the upgrade request
     * @param host     host name for the {@code Host} header
     * @param port     port for the {@code Host} header
     * @param authUser user for HTTP basic authentication, or {@code null}
     * @param authPwd  password for HTTP basic authentication, or {@code null}
     * @throws IOException if sending fails or the response is not a valid upgrade
     */
    private void doWSHandShake(String uri, String host, int port, String authUser, OpaqueString authPwd) throws IOException {
        // No query parameter is passed for this handshake.
        final String noQueryParam = null;
        WSHandshakeHelper upgrade = new WSHandshakeHelper(uri, noQueryParam, host, port, authUser, authPwd);
        upgrade.sendHandshakeData(this.socketChannel);
        upgrade.receiveHandshakeResponse(this.socketChannel);
    }

    /**
     * Reads frames into the payload buffer until a frame other than a ping arrives.
     *
     * <p>Each ping is answered with a pong that echoes the ping payload. This is done in a loop
     * rather than by recursion, so a peer sending a long run of pings cannot exhaust the stack.
     *
     * @throws IOException if reading or writing a frame fails, or if the peer sent a close frame
     */
    private void readFromSocket() throws IOException {
        while (true) {
            WSFrame frame = WSFrame.readFrame(this.socketChannel, this.payloadBuffer);
            byte opCode = frame.header.opCode;
            if (opCode == WS_OPCODE_CLOSE) {
                throw new IOException("WebSocket : Connection closed. (Error code : " + ((WSCloseFrame)frame).errorCode + ")");
            }
            if (opCode != WS_OPCODE_PING) {
                return;
            }
            // The ping payload is still in payloadBuffer; writing the pong drains it back to the peer.
            WSFrame.writeFrame(this.socketChannel, new WSPongFrame(this.payloadBuffer));
        }
    }

    /**
     * Closes the WebSocket connection on a best-effort basis while holding the monitor lock.
     *
     * <p>If the channel is still open and not already closed, a close frame is sent and incoming
     * frames are read until the peer's close frame arrives. Any failure during that exchange is
     * ignored. The channel is then marked closed and, when the underlying channel is itself a
     * {@code SocketChannelWrapper}, its {@code disconnect()} is invoked; failures there are
     * ignored as well.
     */
    @Override
    public void disconnect() {
        try (Monitor.CloseableLock heldLock = this.acquireCloseableLock()) {
            try {
                boolean closeExchangePossible = !this.isClosed
                        && this.socketChannel != null
                        && this.socketChannel.isOpen();
                if (closeExchangePossible) {
                    WSFrame.writeFrame(this.socketChannel, new WSCloseFrame());
                    WSFrame reply;
                    do {
                        reply = WSFrame.readFrame(this.socketChannel, this.payloadBuffer);
                    } while (reply.header.opCode != WS_OPCODE_CLOSE);
                }
            }
            catch (Exception ignored) {
                // Best effort: a failed close exchange must not prevent the teardown below.
            }
            this.isClosed = true;
            try {
                if (this.socketChannel instanceof SocketChannelWrapper) {
                    SocketChannelWrapper wrapped = (SocketChannelWrapper)this.socketChannel;
                    wrapped.disconnect();
                }
            }
            catch (Exception ignored) {
                // Best effort: nothing more can be done if the underlying disconnect fails.
            }
        }
    }

    /**
     * Returns the lock created for this channel at construction time.
     *
     * @return this instance's monitor lock
     */
    @Override
    public final Monitor.CloseableLock getMonitorLock() {
        return monitorLock;
    }

    /**
     * Reports the DN match status of the underlying channel.
     *
     * @return the status reported by the underlying {@code SSLSocketChannel}, or
     *         {@code NOT_VERIFIED} when the underlying channel is not an {@code SSLSocketChannel}
     */
    public SecurityInformation.DNMatchStatus getDnMatchStatus() {
        if (!(this.socketChannel instanceof SSLSocketChannel)) {
            return SecurityInformation.DNMatchStatus.NOT_VERIFIED;
        }
        SSLSocketChannel tlsChannel = (SSLSocketChannel)this.socketChannel;
        return tlsChannel.getDnMatchStatus();
    }

    /**
     * Frame carrying binary application data. Incoming continuation and binary frames are both
     * represented by this type.
     */
    private static class WSBinaryFrame
    extends WSFrame {
        /** Creates a frame for an incoming header whose payload will be read from {@code sc}. */
        private WSBinaryFrame(WSHeader header, ByteBuffer payloadBuffer, SocketChannel sc) {
            super(header, payloadBuffer, sc);
        }

        /** Creates an outgoing frame whose payload is the remaining bytes of {@code payloadBuffer}. */
        private WSBinaryFrame(ByteBuffer payloadBuffer) {
            this.payloadBuffer = payloadBuffer;
        }

        /** Builds the header for sending: final chunk, masked with the dummy key, binary opcode. */
        @Override
        void prepareForWrite() throws IOException {
            WSHeader outgoing = new WSHeader();
            this.header = outgoing;
            outgoing.isFinalChunk = true;
            outgoing.isPayloadMasked = true;
            outgoing.maskingKey = WS_DUMMY_MASK_KEY;
            outgoing.opCode = WS_OPCODE_BINARYDATA;
            outgoing.payloadLength = this.payloadBuffer.remaining();
        }

        /** Reads the payload announced by the header into the payload buffer. */
        @Override
        void readPayload() throws IOException {
            this.readPayloadFromSocket();
        }
    }

    /**
     * Base type for WebSocket frames: a header plus a payload buffer, with static helpers that
     * read one frame from, or write one frame to, a channel.
     */
    private static abstract class WSFrame {
        protected WSHeader header;
        protected ByteBuffer payloadBuffer;
        protected SocketChannel socketChannel;

        private WSFrame(WSHeader header, ByteBuffer payloadBuffer, SocketChannel sc) {
            this.header = header;
            this.payloadBuffer = payloadBuffer;
            this.socketChannel = sc;
        }

        private WSFrame() {
        }

        /**
         * Reads one complete frame (header and payload) from {@code socketChannel}.
         *
         * @param socketChannel channel to read from
         * @param payloadBuffer buffer that receives the frame payload; its previous content is discarded
         * @return the frame, typed according to its opcode
         * @throws IOException if the channel fails or ends, the opcode is not supported, or the
         *                     payload does not fit into {@code payloadBuffer}
         */
        static WSFrame readFrame(SocketChannel socketChannel, ByteBuffer payloadBuffer) throws IOException {
            WSHeader wsHeader = new WSHeader();
            wsHeader.read(socketChannel);
            WSFrame frame;
            switch (wsHeader.opCode) {
                case WS_OPCODE_CONTINUE:
                case WS_OPCODE_BINARYDATA: {
                    frame = new WSBinaryFrame(wsHeader, payloadBuffer, socketChannel);
                    break;
                }
                case WS_OPCODE_PING: {
                    frame = new WSPingFrame(wsHeader, payloadBuffer, socketChannel);
                    break;
                }
                case WS_OPCODE_PONG: {
                    frame = new WSPongFrame(wsHeader, payloadBuffer, socketChannel);
                    break;
                }
                case WS_OPCODE_CLOSE: {
                    frame = new WSCloseFrame(wsHeader, payloadBuffer, socketChannel);
                    break;
                }
                default: {
                    throw new IOException("Websocket : Invalid frame type : " + wsHeader.opCode);
                }
            }
            frame.readPayload();
            // readPayloadFromSocket() stores the payload from index 0, and a subclass readPayload()
            // may already have consumed leading bytes, so the payload start is passed explicitly.
            frame.maskOrUnmaskPayload(0);
            return frame;
        }

        /**
         * Writes {@code wsFrame} (header, then payload if any) to {@code socketChannel}.
         *
         * @param socketChannel channel to write to
         * @param wsFrame       frame to send; its header is rebuilt by {@code prepareForWrite()}
         * @throws IOException if the channel write fails
         */
        static void writeFrame(SocketChannel socketChannel, WSFrame wsFrame) throws IOException {
            wsFrame.prepareForWrite();
            wsFrame.header.write(socketChannel);
            if (wsFrame.payloadBuffer != null) {
                // An outgoing payload begins at the buffer's current position.
                wsFrame.maskOrUnmaskPayload(wsFrame.payloadBuffer.position());
                while (wsFrame.payloadBuffer.hasRemaining()) {
                    socketChannel.write(wsFrame.payloadBuffer);
                }
            }
        }

        /**
         * XORs the bytes between the buffer's position and limit with the header's masking key.
         * The buffer's position, limit and mark are left untouched.
         *
         * @param payloadStart buffer index of the first payload byte; the key byte applied to
         *                     index {@code i} is {@code maskingKey[(i - payloadStart) % 4]}
         */
        private void maskOrUnmaskPayload(int payloadStart) {
            if (!this.header.isPayloadMasked || this.header.payloadLength <= 0) {
                return;
            }
            byte[] key = this.header.maskingKey;
            if ((key[0] | key[1] | key[2] | key[3]) == 0) {
                // XOR with an all-zero key changes nothing; skipping it also avoids writing into
                // a caller-supplied buffer that may be read-only.
                return;
            }
            int end = this.payloadBuffer.limit();
            for (int i = this.payloadBuffer.position(); i < end; ++i) {
                byte maskedByte = (byte)(this.payloadBuffer.get(i) ^ key[(i - payloadStart) % 4]);
                this.payloadBuffer.put(i, maskedByte);
            }
        }

        /**
         * Fills the payload buffer with exactly {@code header.payloadLength} bytes from the
         * channel and leaves it ready for reading (position 0, limit = payload length).
         *
         * @throws IOException if the announced length is negative or exceeds the buffer capacity,
         *                     or if the channel reaches end of stream before the payload is complete
         */
        protected void readPayloadFromSocket() throws IOException {
            int length = this.header.payloadLength;
            int capacity = this.payloadBuffer.capacity();
            if (length < 0 || length > capacity) {
                throw new IOException("WebSocket : Unsupported payload length : " + length + " (buffer capacity : " + capacity + ")");
            }
            this.payloadBuffer.clear();
            this.payloadBuffer.limit(length);
            if (length == 0) {
                return;
            }
            while (this.payloadBuffer.hasRemaining()) {
                // read() returns -1 at end of stream; without this check the loop would never end.
                if (this.socketChannel.read(this.payloadBuffer) < 0) {
                    throw new IOException("WebSocket : Unexpected end of stream while reading frame payload");
                }
            }
            this.payloadBuffer.flip();
        }

        /** Reads this frame's payload from the channel into the payload buffer. */
        abstract void readPayload() throws IOException;

        /** Builds the header that is sent in front of this frame's payload. */
        abstract void prepareForWrite() throws IOException;
    }

    /** Ping control frame, either received from the peer or built locally for sending. */
    private static class WSPingFrame
    extends WSFrame {
        /** Creates a frame for an incoming header whose payload will be read from {@code sc}. */
        private WSPingFrame(WSHeader header, ByteBuffer payloadBuffer, SocketChannel sc) {
            super(header, payloadBuffer, sc);
        }

        /** Creates an outgoing ping; the header is filled in by {@link #prepareForWrite()}. */
        private WSPingFrame(ByteBuffer payloadBuffer, SocketChannel sc) {
            super(null, payloadBuffer, sc);
        }

        /** Builds the header for sending: final chunk, masked with the dummy key, ping opcode. */
        @Override
        void prepareForWrite() throws IOException {
            WSHeader outgoing = new WSHeader();
            this.header = outgoing;
            outgoing.isFinalChunk = true;
            outgoing.isPayloadMasked = true;
            outgoing.maskingKey = WS_DUMMY_MASK_KEY;
            outgoing.opCode = WS_OPCODE_PING;
            // A ping may be sent without any payload buffer at all.
            int length = 0;
            if (this.payloadBuffer != null) {
                length = this.payloadBuffer.remaining();
            }
            outgoing.payloadLength = length;
        }

        /** Reads the payload announced by the header into the payload buffer. */
        @Override
        void readPayload() throws IOException {
            this.readPayloadFromSocket();
        }
    }

    /**
     * Builds and sends the HTTP upgrade request that opens a WebSocket connection, then reads
     * and validates the server's response (status line, Connection/Upgrade headers and the
     * Sec-WebSocket-Accept hash derived from the request key).
     */
    private static class WSHandshakeHelper {
        /** Fixed GUID appended to the request key before hashing to obtain the accept value. */
        private static final byte[] MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(StandardCharsets.UTF_8);
        /** HTTP status code the server must answer with. */
        private static final int SWITCHING_PROTOCOLS = 101;
        private static final Pattern PAT_STATUS_LINE = Pattern.compile("^HTTP/1.[01]\\s+(\\d+)\\s+(.*)", Pattern.CASE_INSENSITIVE);
        private static final Pattern PAT_HEADER = Pattern.compile("([^:]+):\\s*(.*)");
        private final String uri;
        private final String queryParam;
        private final String host;
        private final int port;
        private final String key;
        private final OpaqueString httpBasicAuthKey;

        /**
         * @param uri        request path; a leading "/" is added when missing, "/" is used when empty
         * @param queryParam query string appended after "?", or {@code null}/empty for none
         * @param host       value for the {@code Host} header
         * @param port       appended to the {@code Host} header when greater than zero
         * @param authUser   basic-auth user; no Authorization header is sent when {@code null}
         * @param authPwd    basic-auth password; no Authorization header is sent when {@code null}
         *                   or {@code OpaqueString.NULL}
         */
        WSHandshakeHelper(String uri, String queryParam, String host, int port, String authUser, OpaqueString authPwd) {
            this.uri = uri;
            this.queryParam = queryParam;
            this.host = host;
            this.port = port;
            this.key = genRandomKey();
            this.httpBasicAuthKey = authUser != null && authPwd != null && authPwd != OpaqueString.NULL ? this.getHTTPAuthHeader(authUser, authPwd) : null;
        }

        /**
         * Writes the complete upgrade request to {@code socketChannel}.
         *
         * @throws IOException if the channel write fails
         */
        void sendHandshakeData(SocketChannel socketChannel) throws IOException {
            ByteBuffer handShakeRequestBuffer = ByteBuffer.wrap(this.generateUpgradeRequest().getBytes(StandardCharsets.ISO_8859_1));
            while (handShakeRequestBuffer.hasRemaining()) {
                socketChannel.write(handShakeRequestBuffer);
            }
        }

        /**
         * Reads the server's response to the upgrade request and validates it.
         *
         * <p>Reading continues until the blank line that ends the header section has arrived,
         * the response buffer is full, or a read delivers no data, so a response that arrives
         * in several pieces is still parsed as a whole.
         *
         * @throws IOException if the read fails, no response arrives, or the response is not a
         *                     valid upgrade to the WebSocket protocol
         */
        void receiveHandshakeResponse(SocketChannel socketChannel) throws IOException {
            ByteBuffer resBuffer = ByteBuffer.allocate(HANDSHAKE_RESPONSE_BUFFER_SIZE);
            while (resBuffer.hasRemaining() && !containsEndOfHeaders(resBuffer)) {
                // -1 is end of stream; 0 means no data is available right now.
                if (socketChannel.read(resBuffer) <= 0) {
                    break;
                }
            }
            resBuffer.flip();
            String response = new String(resBuffer.array(), resBuffer.arrayOffset(), resBuffer.limit(), StandardCharsets.ISO_8859_1);
            BufferedReader responseReader = new BufferedReader(new StringReader(response));
            String currentLine = responseReader.readLine();
            this.validateStatus(currentLine);
            Properties responseHeaders = new Properties();
            currentLine = responseReader.readLine();
            while (currentLine != null && currentLine.trim().length() > 0) {
                this.parseHeader(currentLine, responseHeaders);
                currentLine = responseReader.readLine();
            }
            this.validateResponseHeaders(responseHeaders);
        }

        /** Returns true once the bytes received so far contain an empty line (LF LF or LF CR LF). */
        private static boolean containsEndOfHeaders(ByteBuffer received) {
            int end = received.position();
            for (int i = 0; i < end; ++i) {
                if (received.get(i) != '\n') {
                    continue;
                }
                int next = i + 1;
                if (next < end && received.get(next) == '\r') {
                    ++next;
                }
                if (next < end && received.get(next) == '\n') {
                    return true;
                }
            }
            return false;
        }

        /** Builds the text of the HTTP/1.1 upgrade request, terminated by an empty line. */
        private String generateUpgradeRequest() {
            StringBuilder request = new StringBuilder();
            request.append("GET ");
            if (this.uri != null && this.uri.length() > 0) {
                if (!this.uri.startsWith("/")) {
                    request.append("/");
                }
                request.append(this.uri);
            } else {
                request.append("/");
            }
            if (this.queryParam != null && this.queryParam.length() != 0) {
                request.append("?").append(this.queryParam);
            }
            request.append(" HTTP/1.1\r\n");
            request.append("Host: ").append(this.host);
            if (this.port > 0) {
                request.append(':').append(this.port);
            }
            request.append("\r\n");
            request.append("Upgrade: websocket\r\n");
            request.append("Connection: Upgrade\r\n");
            request.append("Sec-WebSocket-Key: ").append(this.key).append("\r\n");
            request.append("Sec-WebSocket-Version: 13\r\n");
            request.append("Sec-WebSocket-Protocol: sqlnet\r\n");
            if (this.httpBasicAuthKey != null) {
                request.append("Authorization: ").append(this.httpBasicAuthKey.get()).append("\r\n");
            }
            request.append("Pragma: no-cache\r\n");
            request.append("Cache-Control: no-cache\r\n");
            request.append("\r\n");
            return request.toString();
        }

        /**
         * Checks that {@code statusLine} is an HTTP/1.x status line with code 101.
         *
         * @param statusLine first line of the response; {@code null} when nothing was received
         * @throws IOException if the line is missing, malformed, or carries another status code
         */
        private void validateStatus(String statusLine) throws IOException {
            Matcher matcher = statusLine == null ? null : PAT_STATUS_LINE.matcher(statusLine);
            if (matcher == null || !matcher.matches()) {
                throw new IOException("WebSocket: Unexpected HTTP response status line [" + statusLine + "]");
            }
            int statusCode;
            try {
                statusCode = Integer.parseInt(matcher.group(1));
            }
            catch (NumberFormatException e) {
                // The pattern accepts any run of digits, which may not fit into an int.
                throw new IOException("WebSocket: Unexpected HTTP response status line [" + statusLine + "]", e);
            }
            String statusReason = matcher.group(2);
            if (statusCode != SWITCHING_PROTOCOLS) {
                throw new IOException("WebSocket: Unable to upgrade to websocket protocol [" + statusCode + " : " + statusReason + "]");
            }
        }

        /**
         * Validates the Connection, Upgrade and Sec-WebSocket-Accept response headers. Header
         * names are matched case-insensitively and the full accept value is compared against
         * the value expected for this helper's request key.
         *
         * @throws IOException if a header is missing or does not have the expected value
         */
        private void validateResponseHeaders(@Blind(value=PropertiesBlinder.class) Properties responseHeaders) throws IOException {
            String connection = findHeader(responseHeaders, "Connection");
            if (!containsToken(connection, "upgrade")) {
                throw new IOException("WebSocket: value of the header Connection is  " + connection + " (expected 'upgrade')");
            }
            String upgrade = findHeader(responseHeaders, "Upgrade");
            if (!"websocket".equalsIgnoreCase(upgrade)) {
                throw new IOException("WebSocket: value of the header Upgrade is  " + upgrade + " (expected 'websocket')");
            }
            String respHashStr = findHeader(responseHeaders, "Sec-WebSocket-Accept");
            byte[] responseHash = respHashStr == null ? null : respHashStr.getBytes(StandardCharsets.ISO_8859_1);
            if (responseHash == null || responseHash.length < 20) {
                throw new IOException("Invalid Sec-WebSocket-Accept hash");
            }
            byte[] expectedHash = this.expectedAcceptHash(this.key);
            if (!MessageDigest.isEqual(responseHash, expectedHash)) {
                throw new IOException("Sec-WebSocket-Accept hash does not match");
            }
        }

        /** Returns the value of the header called {@code name}, ignoring case, or {@code null}. */
        private static String findHeader(@Blind(value=PropertiesBlinder.class) Properties headers, String name) {
            for (String candidate : headers.stringPropertyNames()) {
                if (name.equalsIgnoreCase(candidate)) {
                    return headers.getProperty(candidate);
                }
            }
            return null;
        }

        /** Returns true when the comma-separated {@code headerValue} contains {@code token}, ignoring case. */
        private static boolean containsToken(String headerValue, String token) {
            if (headerValue == null) {
                return false;
            }
            for (String candidate : headerValue.split(",")) {
                if (token.equalsIgnoreCase(candidate.trim())) {
                    return true;
                }
            }
            return false;
        }

        /** Builds the value of the Authorization header ("Basic " + Base64 of "user:password"). */
        private OpaqueString getHTTPAuthHeader(String user, OpaqueString pwd) {
            String auth = user + ":" + pwd.get();
            byte[] encodedAuth = Base64.getEncoder().encode(auth.getBytes(StandardCharsets.ISO_8859_1));
            String authHeader = "Basic " + new String(encodedAuth, StandardCharsets.ISO_8859_1);
            return OpaqueString.newOpaqueString(authHeader);
        }

        /** Stores the header in {@code headerLine} (name and value trimmed); lines without ':' are ignored. */
        private void parseHeader(String headerLine, @Blind(value=PropertiesBlinder.class) Properties headerProperties) {
            Matcher matcher = PAT_HEADER.matcher(headerLine);
            if (matcher.matches()) {
                headerProperties.setProperty(matcher.group(1).trim(), matcher.group(2).trim());
            }
        }

        /** Generates the Sec-WebSocket-Key value: 16 random bytes, Base64 encoded. */
        private static String genRandomKey() {
            byte[] bytes = new byte[16];
            ThreadLocalRandom.current().nextBytes(bytes);
            return Base64.getEncoder().encodeToString(bytes);
        }

        /**
         * Computes the accept value a server must return for {@code requestKey}: the Base64
         * encoding of SHA-1(requestKey + MAGIC).
         */
        private byte[] expectedAcceptHash(String requestKey) {
            try {
                MessageDigest md = MessageDigest.getInstance("SHA-1");
                md.update(requestKey.getBytes(StandardCharsets.UTF_8));
                md.update(MAGIC);
                return Base64.getEncoder().encode(md.digest());
            }
            catch (NoSuchAlgorithmException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Header of a WebSocket frame: FIN flag, opcode, payload length and optional masking key,
     * together with the code that reads it from and writes it to a channel.
     */
    private static class WSHeader {
        /** Largest possible header: 2 fixed bytes + 8 length bytes + 4 masking-key bytes. */
        private static final int MAX_HEADER_LENGTH = 14;

        private boolean isFinalChunk;
        private byte opCode;
        private int payloadLength;
        private boolean isPayloadMasked;
        private byte[] maskingKey;

        private WSHeader() {
        }

        /**
         * Reads one frame header from {@code socketChannel} and populates this object.
         *
         * @throws IOException if the channel fails or ends before the header is complete, or the
         *                     announced payload length does not fit into an {@code int}
         */
        void read(SocketChannel socketChannel) throws IOException {
            ByteBuffer headerBuffer = ByteBuffer.allocate(MAX_HEADER_LENGTH);
            headerBuffer.limit(2);
            readFully(socketChannel, headerBuffer);
            headerBuffer.flip();
            byte firstByte = headerBuffer.get();
            byte secondByte = headerBuffer.get();
            this.isFinalChunk = (firstByte & 0x80) != 0;
            this.opCode = (byte)(firstByte & 0xF);
            this.isPayloadMasked = (secondByte & 0x80) != 0;
            this.payloadLength = secondByte & 0x7F;
            this.readRemainingHeaderBytes(socketChannel, headerBuffer);
            if (this.payloadLength == 126) {
                // 126 announces a 16-bit unsigned length in the next two bytes.
                this.payloadLength = headerBuffer.getShort() & 0xFFFF;
            } else if (this.payloadLength >= 127) {
                // 127 announces a 64-bit length; anything beyond int range cannot be buffered here.
                long extendedLength = headerBuffer.getLong();
                if (extendedLength < 0L || extendedLength > Integer.MAX_VALUE) {
                    throw new IOException("WebSocket : Unsupported payload length : " + extendedLength);
                }
                this.payloadLength = (int)extendedLength;
            }
            if (this.isPayloadMasked) {
                this.maskingKey = new byte[4];
                headerBuffer.get(this.maskingKey);
            }
        }

        /**
         * Reads the header bytes that follow the first two (extended length and masking key)
         * and leaves {@code headerBuffer} positioned at index 2, ready for them to be consumed.
         */
        private void readRemainingHeaderBytes(SocketChannel socketChannel, ByteBuffer headerBuffer) throws IOException {
            int neededHeaderBytes = 2;
            if (this.payloadLength == 126) {
                neededHeaderBytes += 2;
            } else if (this.payloadLength >= 127) {
                neededHeaderBytes += 8;
            }
            if (this.isPayloadMasked) {
                neededHeaderBytes += 4;
            }
            headerBuffer.position(2);
            headerBuffer.limit(neededHeaderBytes);
            readFully(socketChannel, headerBuffer);
            headerBuffer.flip();
            headerBuffer.position(2);
        }

        /** Reads until {@code target} has no remaining space, failing on end of stream. */
        private static void readFully(SocketChannel socketChannel, ByteBuffer target) throws IOException {
            while (target.hasRemaining()) {
                // read() returns -1 at end of stream; without this check the loop would never end.
                if (socketChannel.read(target) < 0) {
                    throw new IOException("WebSocket : Unexpected end of stream while reading frame header");
                }
            }
        }

        /**
         * Encodes this header and writes all of its bytes to {@code socketChannel}.
         *
         * @throws IOException if the channel write fails
         */
        private void write(SocketChannel socketChannel) throws IOException {
            ByteBuffer headerBuffer = ByteBuffer.allocate(MAX_HEADER_LENGTH);
            byte firstByte = this.opCode;
            if (this.isFinalChunk) {
                firstByte = (byte)(firstByte | 0x80);
            }
            byte secondByte = 0;
            if (this.isPayloadMasked) {
                secondByte = -128;
            }
            headerBuffer.put(firstByte);
            if (this.payloadLength > 65535) {
                secondByte = (byte)(secondByte | 0x7F);
                headerBuffer.put(secondByte);
                headerBuffer.putLong(this.payloadLength);
            } else if (this.payloadLength >= 126) {
                secondByte = (byte)(secondByte | 0x7E);
                headerBuffer.put(secondByte);
                headerBuffer.putShort((short)this.payloadLength);
            } else {
                if (this.payloadLength != 0) {
                    secondByte = (byte)(secondByte | this.payloadLength & 0x7F);
                }
                headerBuffer.put(secondByte);
            }
            if (this.isPayloadMasked) {
                headerBuffer.put(this.maskingKey);
            }
            headerBuffer.flip();
            // A single write() may accept only part of the buffer, so loop until it is drained.
            while (headerBuffer.hasRemaining()) {
                socketChannel.write(headerBuffer);
            }
        }
    }

    /** Pong control frame, either received from the peer or built locally in reply to a ping. */
    private static class WSPongFrame
    extends WSFrame {
        /** Creates a frame for an incoming header whose payload will be read from {@code sc}. */
        private WSPongFrame(WSHeader header, ByteBuffer payloadBuffer, SocketChannel sc) {
            super(header, payloadBuffer, sc);
        }

        /** Creates an outgoing pong whose payload is the remaining bytes of {@code payloadBuffer}. */
        private WSPongFrame(ByteBuffer payloadBuffer) {
            this.payloadBuffer = payloadBuffer;
        }

        /** Builds the header for sending: final chunk, masked with the dummy key, pong opcode. */
        @Override
        void prepareForWrite() throws IOException {
            WSHeader outgoing = new WSHeader();
            this.header = outgoing;
            outgoing.isFinalChunk = true;
            outgoing.isPayloadMasked = true;
            outgoing.maskingKey = WS_DUMMY_MASK_KEY;
            outgoing.opCode = WS_OPCODE_PONG;
            int length = 0;
            if (this.payloadBuffer != null) {
                length = this.payloadBuffer.remaining();
            }
            outgoing.payloadLength = length;
        }

        /** Reads the pong payload and discards it, leaving the payload buffer with nothing remaining. */
        @Override
        void readPayload() throws IOException {
            this.readPayloadFromSocket();
            if (this.header.payloadLength <= 0) {
                return;
            }
            // Skip over the received bytes so they are never delivered as application data.
            this.payloadBuffer.position(this.payloadBuffer.limit());
        }
    }

    /** Close control frame, either received from the peer or built locally to start the close. */
    private static class WSCloseFrame
    extends WSFrame {
        /** Close code reported when a received close frame carries no status code (RFC 6455, 7.1.5). */
        private static final int NO_STATUS_RECEIVED = 1005;

        /** Status code taken from the first two payload bytes of a received close frame. */
        int errorCode;

        /** Creates a frame for an incoming header whose payload will be read from {@code sc}. */
        private WSCloseFrame(WSHeader header, ByteBuffer payloadBuffer, SocketChannel sc) {
            super(header, payloadBuffer, sc);
        }

        /** Creates an outgoing close frame without payload. */
        private WSCloseFrame() {
        }

        /** Builds the header for sending: final chunk, masked with the dummy key, close opcode, no payload. */
        @Override
        void prepareForWrite() throws IOException {
            this.header = new WSHeader();
            this.header.isFinalChunk = true;
            this.header.isPayloadMasked = true;
            this.header.maskingKey = WS_DUMMY_MASK_KEY;
            this.header.opCode = WS_OPCODE_CLOSE;
            this.header.payloadLength = 0;
        }

        /**
         * Reads the close payload and extracts the status code. A close frame is allowed to
         * carry no payload; in that case the code is reported as 1005 instead of failing with a
         * {@code BufferUnderflowException}.
         */
        @Override
        void readPayload() throws IOException {
            this.readPayloadFromSocket();
            if (this.payloadBuffer.remaining() >= 2) {
                // The status code is an unsigned 16-bit value.
                this.errorCode = this.payloadBuffer.getShort() & 0xFFFF;
            } else {
                this.errorCode = NO_STATUS_RECEIVED;
            }
        }
    }
}

