package com.sun.grizzly;

import com.sun.grizzly.filter.ParserProtocolFilter;
import com.sun.grizzly.util.ByteBufferFactory;
import com.sun.grizzly.util.OutputWriter;
import com.sun.grizzly.util.WorkerThread;
import com.sun.grizzly.utils.ControllerUtils;
import com.sun.grizzly.utils.TCPIOClient;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectableChannel;
import java.util.Arrays;
import junit.framework.TestCase;
/**
 * Tests a simple custom TCP protocol parsed with the Grizzly
 * {@code ProtocolParser} API. Messages can be arbitrarily long.
 * Each message is a 12-byte header followed by the payload:
 * <pre>
 * 4 bytes Magic number - identifies the protocol
 * 4 bytes Request ID   - matches replies to requests (-1 for an error reply)
 * 4 bytes Data length  - length of the payload in bytes
 * n bytes Data         - payload
 * </pre>
 *
 * @author John Vieten
 */
public class ProtocolParserStateTest extends TestCase {
    /** TCP port on which every test's controller listens. */
    public static final int PORT = 17505;

    /** Request id of the most recent framed request; bumped once per request. */
    static int requestId;

    //-------------------------- Start Test Code ----------------------------

    /**
     * Sends one short framed message and checks that the server echoes it
     * back unchanged.
     *
     * @throws IOException if starting, talking to, or stopping the server fails
     */
    public void testCustomProtocolSmall() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);
        try {
            ControllerUtils.startController(controller);
            client.connect();
            sindToServer("Hello World".getBytes(), client);
        } finally {
            // Nested so the client socket is closed even if stop() throws.
            try {
                controller.stop();
            } finally {
                client.close();
            }
        }
    }

    /**
     * Sends six progressively larger framed messages over one connection
     * (8192 * i bytes for i = 1..6, each filled with the byte value i) and
     * checks that every one is echoed back unchanged.
     *
     * @throws IOException if starting, talking to, or stopping the server fails
     */
    public void testCustomProtocolBig() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);
        try {
            ControllerUtils.startController(controller);
            client.connect();

            for (int i = 1; i < 7; i++) {
                byte[] bigMessage = new byte[8192 * i];
                Arrays.fill(bigMessage, (byte) i);
                sindToServer(bigMessage, client);
            }
        } finally {
            // Nested so the client socket is closed even if stop() throws.
            try {
                controller.stop();
            } finally {
                client.close();
            }
        }
    }


    /**
     * Sends raw text with no protocol header. Its first four bytes are read
     * as the start mark and do not match, so the server must answer with an
     * error frame: request id {@link CustomProtocol#ERROR_REQUEST_ID} and
     * the parser's "Bad Startmark" text as payload.
     *
     * @throws IOException if starting, talking to, or stopping the server fails
     */
    public void testException() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);
        try {
            ControllerUtils.startController(controller);
            client.connect();

            byte[] data = "Hello World This should be echoed".getBytes();
            client.send(data);

            byte[] header = new byte[CustomProtocol.HEADERLENGTH];
            client.receive(header);
            ByteBuffer receiveBuffer = ByteBuffer.wrap(header);

            assertEquals(CustomProtocol.STARTMARK, receiveBuffer.getInt());
            assertEquals(CustomProtocol.ERROR_REQUEST_ID, receiveBuffer.getInt());

            int receiveLength = receiveBuffer.getInt();
            byte[] message = new byte[receiveLength];
            client.receive(message);
            // Previously the payload was read but never checked.
            assertEquals("Bad Startmark", new String(message));
        } finally {
            // Nested so the client socket is closed even if stop() throws.
            try {
                controller.stop();
            } finally {
                client.close();
            }
        }
    }


    /**
     * Frames {@code data} with the custom protocol header under a fresh
     * request id, sends it, and asserts that the server echoes back an
     * identical frame: same start mark, request id, length and payload.
     *
     * @param data   payload to send; its length becomes the header's length field
     * @param client connected client used for the round trip
     * @throws IOException if sending or receiving fails
     */
    public static void sindToServer(byte[] data, TCPIOClient client)
            throws IOException {
        int messageLength = CustomProtocol.HEADERLENGTH + data.length;
        // A plain heap buffer is exactly messageLength bytes long, so its
        // backing array can be sent directly without an extra copy.
        ByteBuffer request = ByteBuffer.allocate(messageLength);
        request.putInt(CustomProtocol.STARTMARK);
        requestId++;
        request.putInt(requestId);
        request.putInt(data.length);
        request.put(data);
        client.send(request.array());

        byte[] header = new byte[CustomProtocol.HEADERLENGTH];
        client.receive(header);
        ByteBuffer receiveBuffer = ByteBuffer.wrap(header);

        assertEquals(CustomProtocol.STARTMARK, receiveBuffer.getInt());
        assertEquals(requestId, receiveBuffer.getInt());

        int receiveLength = receiveBuffer.getInt();
        assertEquals(data.length, receiveLength);

        // receive() fills the array directly; no further copying is needed.
        byte[] message = new byte[receiveLength];
        client.receive(message);
        assertTrue(Arrays.equals(data, message));
    }

    //----------------------- Start Parser Code -------------------------
    /**
     * Builds, but does not start, a Controller that listens on {@code port}
     * and serves every connection with one shared parser-plus-echo chain.
     *
     * @param port TCP port for the selector handler
     * @return the configured, unstarted controller
     */
    private Controller createController(int port) {
        Controller controller = new Controller();
        TCPSelectorHandler selectorHandler = new TCPSelectorHandler();
        selectorHandler.setPort(port);
        controller.addSelectorHandler(selectorHandler);

        final ProtocolChain sharedChain = createEchoChain();
        ProtocolChainInstanceHandler instanceHandler =
                new DefaultProtocolChainInstanceHandler() {

                    // Every request is handed the same chain instance.
                    public ProtocolChain poll() {
                        return sharedChain;
                    }

                    // The shared chain is never returned to any pool.
                    public boolean offer(ProtocolChain released) {
                        return false;
                    }
                };

        controller.setProtocolChainInstanceHandler(instanceHandler);
        return controller;
    }

    /**
     * Creates the chain used by {@link #createController(int)}: the custom
     * protocol parser followed by the echo filter, with continuous execution
     * enabled.
     */
    private ProtocolChain createEchoChain() {
        ProtocolFilter parserFilter = new ParserProtocolFilter() {
            public ProtocolParser newProtocolParser() {
                return new CustomProtocol();
            }
        };
        DefaultProtocolChain chain = new DefaultProtocolChain();
        chain.addFilter(parserFilter);
        chain.addFilter(new CustomProtocolEcho());
        chain.setContinuousExecution(true);
        return chain;
    }

    /**
     * Parser states used by CustomProtocol. Interface fields are implicitly
     * public static final. The MEASSAGE_* spelling is kept as-is because the
     * parser refers to these names.
     */
    public interface State {
        /** Nothing parsed yet; the next bytes must start a header. */
        int START = 1;
        /** A header was parsed; its payload is still incomplete. */
        int EXPECTING_MORE_DATA = 2;
        /** A message was handed out and further bytes followed it. */
        int HAS_MORE_BYTES_TO_PARSE = 3;
        /** A message or error result is ready to be collected. */
        int MEASSAGE_PARSED = 4;
        /** A message is ready and more bytes remain after it in the buffer. */
        int MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE = 5;
    }
    /**
     * ProtocolParser that parses the custom protocol into IncomingMessage
     * objects, which the following filters of the ProtocolChain can then
     * consume.
     *
     * @author John Vieten
     */
    
    public class CustomProtocol 
            implements ProtocolParser<IncomingMessage>, State {

        /** Header size in bytes: three 4-byte ints (start mark, request id, payload length). */
        public static final int HEADERLENGTH = 12;
        /** Value the first header int must have; any other value is rejected. */
        public static final int STARTMARK = 0x77005434;
        /** Request id that CustomProtocolEcho writes on error replies. */
        public static final int ERROR_REQUEST_ID = -1;
        /** Current parser state, one of the {@link State} constants. */
        private int state = START;

        /**
         * Message being assembled, or the finished message waiting to be
         * collected by getNextMessage(); null when there is neither.
         */
        private IncomingMessage incomingMessage;

        /**
         * Returns true while a header has been parsed but its payload is
         * still incomplete. In that case releaseBuffer() leaves the buffer
         * untouched and reports that parser state must be kept.
         */
        public boolean isExpectingMoreData() {
            return (incomingMessage != null) 
                    && incomingMessage.isExpectingMoreData();
        }


        /**
         * Returns true in state HAS_MORE_BYTES_TO_PARSE. getNextMessage()
         * enters that state after handing out a message that was followed
         * by more bytes in the same buffer.
         * NOTE(review): parseMessage() returns without changing this state
         * when fewer than HEADERLENGTH bytes remain. This method can then
         * keep returning true even though no progress is possible. If the
         * caller re-parses on this signal without reading first, it may
         * spin. Confirm against ParserProtocolFilter.
         */
        public boolean hasMoreBytesToParse() {
            return state == HAS_MORE_BYTES_TO_PARSE;
        }

        /**
         * Returns the message produced by the last parse (may be null) and
         * clears the field. State transitions:
         * <ul>
         * <li>MEASSAGE_PARSED goes back to START.</li>
         * <li>MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE goes to
         * HAS_MORE_BYTES_TO_PARSE.</li>
         * <li>Any other state is left unchanged.</li>
         * </ul>
         */
        public IncomingMessage getNextMessage() {
            IncomingMessage tmp = incomingMessage;
            incomingMessage = null;
            switch (state) {
                case MEASSAGE_PARSED:
                    resetState();
                    break;
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    // Last case of the switch, so no break is needed.
                    state = State.HAS_MORE_BYTES_TO_PARSE;

            }
            return tmp;
        }


        /**
         * Runs parseMessage() and reports whether a complete message (or an
         * error result) is now waiting to be collected.
         */
        public boolean hasNextMessage() {
            parseMessage();
            switch (state) {
                case MEASSAGE_PARSED:
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    return true;
                default:
                    return false;
            }
        }


        /**
         * Advances parsing of the current worker thread's buffer.
         * <p>
         * Does nothing while a parsed message is still waiting to be
         * collected. Each of the following produces an IncomingMessage
         * carrying an exception and sets state MEASSAGE_PARSED:
         * <ul>
         * <li>no remaining bytes;</li>
         * <li>a wrong start mark;</li>
         * <li>a non-positive length.</li>
         * </ul>
         * Otherwise, in START or HAS_MORE_BYTES_TO_PARSE, the header is read
         * and a new IncomingMessage is created. Then, and also in
         * EXPECTING_MORE_DATA, as much payload as is available is handled.
         */
        public void parseMessage() {
            // A finished message has not been collected yet: parse nothing.
            switch (state) {
                case MEASSAGE_PARSED:
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    return;
            }

            ByteBuffer bb = getCurrentByteBuffer();
            // NOTE(review): this is reported as "Buffer Full", but the check
            // is actually that no bytes remain between position and limit.
            if (!bb.hasRemaining()) {
                incomingMessage = 
                        new IncomingMessage(new Exception("Buffer Full"));
                state = State.MEASSAGE_PARSED;
                return;
            }


            switch (state) {
                case START:
                case HAS_MORE_BYTES_TO_PARSE:
                    // Not even a complete header is available yet; state and
                    // incomingMessage are left as they are.
                    // NOTE(review): in START, if releaseBuffer() is called
                    // next, it clears the buffer and the partial header bytes
                    // are dropped. In HAS_MORE_BYTES_TO_PARSE the state stays
                    // put (see hasMoreBytesToParse()). Confirm whether headers
                    // split across reads can occur.
                    if (bb.remaining() < HEADERLENGTH) {
                        return;
                    }
                    int startmark = bb.getInt();

                    if (startmark != STARTMARK) {
                        incomingMessage = 
                                new IncomingMessage(new Exception("Bad Startmark"));
                        state = State.MEASSAGE_PARSED;
                        return;
                    }

                    final int requestId = bb.getInt();
                    final int len = bb.getInt();

                    if (len <= 0) {
                        incomingMessage = 
                                new IncomingMessage(new Exception("Bad Messagelength"));
                        state = State.MEASSAGE_PARSED;
                        return;
                    }

                    // Now Message is valid
                    incomingMessage = new IncomingMessage(len);
                    incomingMessage.setRequestId(requestId);

                    // Intentional fall-through: go straight on to copying
                    // whatever payload bytes are already available.
                case EXPECTING_MORE_DATA:
                    if (bb.remaining() >= incomingMessage.getMessageRemaining()) {
                        // The rest of the payload is here. Temporarily shrink
                        // the limit so put() copies only this message's bytes,
                        // then restore it so any following bytes stay in bb.
                        int oldLimit = bb.limit();
                        bb.limit(incomingMessage.getMessageRemaining() + bb.position());
                        incomingMessage.put(bb);
                        bb.limit(oldLimit);
                        state = bb.hasRemaining() ?
                                MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE :
                                MEASSAGE_PARSED;

                    } else if ((bb.capacity() - bb.position())
                            < incomingMessage.getMessageRemaining()) {
                        // The missing payload would not fit in the space left
                        // in bb: copy what is available and clear the buffer.
                        incomingMessage.put(bb);
                        // can make space in  buffer since messages 
                        // are allready read  from bytebuffer
                        bb.clear();
                        state = State.EXPECTING_MORE_DATA;
                    } else {
                        // The missing payload fits in the space left in bb:
                        // leave the available bytes in place and wait.
                        // NOTE(review): this relies on the caller keeping the
                        // unconsumed bytes (position..limit) and appending the
                        // next read after them. bb's position may be just past
                        // a header here. Verify against the read path.
                        state = State.EXPECTING_MORE_DATA;
                    }
            }

        }

        /**
         * Called when parsing of the current buffer is finished.
         * <p>
         * If a payload is still incomplete, the buffer is left untouched and
         * true is returned, meaning parser state must be kept. Otherwise the
         * worker thread's buffer is cleared and false is returned.
         */
        public boolean releaseBuffer() {
            boolean saveCurrentParserState = false;
            if (isExpectingMoreData()) {
                saveCurrentParserState = true;
            } else {
                getCurrentByteBuffer().clear();
            }
            return saveCurrentParserState;

        }

        /**
         * Prepares a newly filled buffer for parsing.
         * <p>
         * Flips bb, presumably because its position marks the end of the
         * bytes just read (TODO confirm with the caller). Unless a payload is
         * still incomplete, the state is also reset to START.
         * NOTE(review): parseMessage() reads getCurrentByteBuffer(), not bb;
         * this assumes both are the same buffer.
         */
        public void startBuffer(ByteBuffer bb) {
            bb.flip();
            if (!isExpectingMoreData()) {
                resetState();
            }
        }

        /**
         * Returns the ByteBuffer owned by the current WorkerThread.
         * <p>
         * Throws ClassCastException when called from any other kind of
         * thread.
         */
        private ByteBuffer getCurrentByteBuffer() {
            return ((WorkerThread) Thread.currentThread()).getByteBuffer();
        }

        /** Returns the parser to the START state. */
        private void resetState() {
            state = State.START;
        }
    }


    /**
     * One frame of the custom protocol as seen by the server. It is either
     * a payload being assembled, together with its request id, or a parse
     * error carrying an exception.
     */
    class IncomingMessage {
        /** Request id copied from the frame header. */
        private int requestId;
        /** Declared payload length; stays -1 for error messages. */
        private int messageLength = -1;
        /** Write cursor over bbArray; null for error messages and after clear(). */
        private ByteBuffer bb;
        /** Payload storage; null for error messages. */
        private byte[] bbArray;

        /** Parse error, or null for a regular message. */
        private Exception exception;

        /**
         * Creates an error message.
         *
         * @param exception the parse error to report
         */
        public IncomingMessage(Exception exception) {
            this.exception = exception;
        }

        /**
         * Creates an empty message that will receive {@code messageLength}
         * payload bytes.
         *
         * @param messageLength declared payload length from the header
         */
        public IncomingMessage(int messageLength) {
            this.messageLength = messageLength;
            bbArray = new byte[messageLength];
            bb = ByteBuffer.wrap(bbArray);
        }

        public Exception getException() {
            return exception;
        }

        public void setException(Exception exception) {
            this.exception = exception;
        }

        public boolean hasException() {
            return exception != null;
        }

        /**
         * Returns true while fewer than messageLength payload bytes have
         * been stored. Error messages and cleared messages have no buffer;
         * they report false instead of throwing NullPointerException.
         */
        public boolean isExpectingMoreData() {
            return bb != null && bb.position() < messageLength;
        }

        public int getMessageLength() {
            return messageLength;
        }

        /**
         * Returns the number of payload bytes still missing, or 0 when there
         * is no buffer (error messages and cleared messages).
         */
        public int getMessageRemaining() {
            return (bb == null) ? 0 : bb.remaining();
        }

        /** Returns the payload array; null for error messages. */
        public byte[] getMessage() {
            return bbArray;
        }

        /**
         * Appends all remaining bytes of {@code message} to the payload.
         *
         * @throws java.nio.BufferOverflowException if more bytes are supplied
         *         than are still missing
         */
        public void put(ByteBuffer message) {
            this.bb.put(message);
        }

        public int getRequestId() {
            return requestId;
        }

        public void setRequestId(int requestId) {
            this.requestId = requestId;
        }

        /** Drops the write cursor; the array returned by getMessage() is kept. */
        public void clear() {
            bb = null;
        }
    }

    /**
     * Filter that echoes each parsed IncomingMessage back to the sender
     * using the same framing. A message that carries an exception is answered
     * with request id CustomProtocol.ERROR_REQUEST_ID and the exception text
     * as payload.
     */
    public class CustomProtocolEcho implements ProtocolFilter {
        public boolean execute(final Context ctx) {

            IncomingMessage result =
                    (IncomingMessage) ctx.removeAttribute(ProtocolParser.MESSAGE);
            if (result.hasException()) {
                writeBack(CustomProtocol.ERROR_REQUEST_ID,
                        errorText(result.getException()).getBytes(), ctx);
            } else {
                writeBack(result.getRequestId(), result.getMessage(), ctx);
            }

            // Presumably tells the chain not to invoke a following filter;
            // this filter is added last in createController anyway.
            return false;
        }

        /**
         * Returns the text sent back for a failed message.
         * <p>
         * Falls back to toString() because getMessage() may be null, which
         * would otherwise cause a NullPointerException.
         */
        private String errorText(Exception exception) {
            String text = exception.getMessage();
            return (text != null) ? text : exception.toString();
        }

        /**
         * Writes one framed reply (start mark, request id, length, payload)
         * to the channel of the current selection key.
         * <p>
         * An IOException is printed rather than propagated.
         */
        private void writeBack(int requestId, byte[] bytes, Context ctx) {
            SelectableChannel channel = ctx.getSelectionKey().channel();

            int byteLength = bytes.length;
            ByteBuffer newBuffer
                    = ByteBufferFactory
                    .allocateView(CustomProtocol.HEADERLENGTH + byteLength, false);
            newBuffer.putInt(CustomProtocol.STARTMARK);
            newBuffer.putInt(requestId);
            newBuffer.putInt(byteLength);
            newBuffer.put(bytes);
            newBuffer.flip();

            try {
                // Serializes writers that lock on the same channel object.
                synchronized (channel) {
                    OutputWriter.flushChannel(channel, newBuffer);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        public boolean postExecute(Context context) throws IOException {
            return true;
        }

    }


}