package com.sun.grizzly;

import com.sun.grizzly.filter.ParserProtocolFilter;
import com.sun.grizzly.util.ByteBufferFactory;
import com.sun.grizzly.util.OutputWriter;
import com.sun.grizzly.util.WorkerThread;
import com.sun.grizzly.utils.ControllerUtils;
import com.sun.grizzly.utils.TCPIOClient;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectableChannel;
import java.util.Arrays;
import junit.framework.TestCase;

/**
 * Tests a simple custom TCP protocol parsed by the Grizzly ProtocolParser API.
 *
 * Wire format of every message:
 * 4 bytes Magic Number - Just to identify the protocol
 * 4 bytes RequestID    - Can be used to identify replies (-1 for Exception)
 * 4 bytes Data length  - The length of Data
 * n bytes Data         - Message
 *
 * @author John Vieten
 */
public class ProtocolParserStateTest extends TestCase {
    /** TCP port the test server listens on. */
    public static final int PORT = 17505;
    /** Packet count constant; not referenced by the code in this file. */
    public static final int PACKETS_COUNT = 100;
    /**
     * Request id of the most recent request; sendToServer() increments it
     * before every send. Implicitly starts at 0.
     */
    static int requestId;



    /**
     * States of the CustomProtocol parsing state machine. Interface fields are
     * implicitly public, static and final.
     */
    public interface State {
        /** No message in progress; a new header starts at the buffer front. */
        int START = 1;
        /** The current message is incomplete; more bytes must be read. */
        int EXPECTING_MORE_DATA = 2;
        /** Another message starts after the one just handed out. */
        int HAS_MORE_BYTES_TO_PARSE = 3;
        /** A complete message is ready and the buffer holds nothing more. */
        int MEASSAGE_PARSED = 4;
        /** A complete message is ready and more bytes follow it. */
        int MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE = 5;
    }

    /**
     * ProtocolParser state machine which parses a custom protocol into
     * {@link IncomingMessage}s. These IncomingMessages can then be used by
     * other ProtocolChain filters.
     *
     * Messages can be arbitrarily long.
     * The protocol has the following format (12 bytes header + data):
     * 4 bytes Magic Number - Just to identify the protocol
     * 4 bytes RequestID    - Can be used to identify replies (-1 for Exception)
     * 4 bytes Data length  - The length of Data
     * n bytes Data         - Message
     * <p/>
     * The protocol assumes that on average messages are not too large;
     * otherwise this implementation might not be efficient.
     * <p/>
     * Tries to avoid byte array copying and therefore maps the IncomingMessage
     * onto the original Grizzly ByteBuffer. If a message does not fit into a
     * default-sized Grizzly ByteBuffer, a larger one with enough capacity is
     * temporarily given to the framework; {@link #releaseBuffer()} hands the
     * original buffer back once no more data is expected.
     * <p/>
     * This class is still under construction; for example error handling,
     * timeouts and recovery have to be greatly improved!
     *
     * @author John Vieten
     */
    public class CustomProtocol implements ProtocolParser<IncomingMessage>, State {
        public static final int HEADER_LENGTH = 12;
        public static final int START_MARK = 0x77005434;
        public static final int DEFAULT_BUFFER_SIZE = 8192;
        public static final int ERROR_REQUEST_ID = -1;

        /** Handle to the ByteBuffer which gets filled in by Grizzly. */
        private ByteBuffer grizzlyBuffer;

        /**
         * The buffer handed to startBuffer() when the current message began.
         * If a message is larger than that buffer, a larger one is installed on
         * the worker thread; this handle is used to give Grizzly back its
         * original buffer in releaseBuffer().
         */
        private ByteBuffer restoreHandle;

        /** The length of the data as stated in the protocol header. */
        private int dataLength = 0;

        /** dataLength + HEADER_LENGTH. */
        private int messageLength = 0;

        /**
         * End of the bytes received so far in grizzlyBuffer (the buffer's
         * position when it was handed to startBuffer()).
         */
        private int trackBytesRead = 0;

        /**
         * Start position of the current message in the Grizzly ByteBuffer.
         * Needed because a Grizzly ByteBuffer can contain several messages.
         */
        int currentMessageStartPosition = 0;

        /** Current state of this state machine (one of the State constants). */
        private int state = State.START;

        /** True once the header of the current message has been decoded. */
        private boolean headerParsed = false;

        /** The message given to other filters up the chain. */
        private IncomingMessage incomingMessage;


        public boolean isExpectingMoreData() {
            return state == State.EXPECTING_MORE_DATA;
        }


        public boolean hasMoreBytesToParse() {
            return state == State.HAS_MORE_BYTES_TO_PARSE;
        }

        /**
         * Returns the parsed message and advances the state machine: back to
         * START when the buffer is exhausted, or to HAS_MORE_BYTES_TO_PARSE
         * when another message follows in the same buffer.
         */
        public IncomingMessage getNextMessage() {
            IncomingMessage tmp = incomingMessage;
            incomingMessage = null;
            switch (state) {
                case MEASSAGE_PARSED:
                    resetState();
                    break;
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    state = State.HAS_MORE_BYTES_TO_PARSE;
                    break;
                default:
                    break;
            }
            return tmp;
        }


        public boolean hasNextMessage() {
            parseMessage();
            switch (state) {
                case MEASSAGE_PARSED:
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    return true;
                default:
                    return false;
            }
        }

        /**
         * Advances the state machine over the bytes received so far.
         */
        public void parseMessage() {
            switch (state) {
                case MEASSAGE_PARSED:
                case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                    // A parsed message is still waiting for getNextMessage().
                    return;
                case EXPECTING_MORE_DATA:
                    if (headerParsed) {
                        // Header already decoded; only the body was incomplete.
                        updateMessageState();
                        return;
                    }
                    // The header itself was incomplete; retry it below.
                    break;
                default:
                    // START or HAS_MORE_BYTES_TO_PARSE: a new message begins
                    // at currentMessageStartPosition.
                    break;
            }
            // startBuffer() is not called for HAS_MORE_BYTES_TO_PARSE, so the
            // buffer has to be positioned at the message start here.
            grizzlyBuffer.position(currentMessageStartPosition);
            parseHeader();
        }

        /**
         * Decodes the header at currentMessageStartPosition (grizzlyBuffer is
         * positioned there) and maps the data onto a slice of the buffer.
         */
        private void parseHeader() {
            if (trackBytesRead - currentMessageStartPosition < HEADER_LENGTH) {
                // Never read past the received bytes: stale data from an
                // earlier message could otherwise be taken for a header.
                awaitRestOfHeader();
                return;
            }

            int startmark = grizzlyBuffer.getInt();
            if (startmark != START_MARK) {
                reportError("Bad Startmark");
                return;
            }

            final int messageRequestId = grizzlyBuffer.getInt();
            dataLength = grizzlyBuffer.getInt();
            // The upper bound keeps messageLength from overflowing an int.
            if (dataLength <= 0 || dataLength > Integer.MAX_VALUE - HEADER_LENGTH) {
                reportError("Bad Message length");
                return;
            }
            messageLength = dataLength + HEADER_LENGTH;

            if (grizzlyBuffer.capacity() - grizzlyBuffer.position() < dataLength
                    && !growBuffer()) {
                reportError("Bad Message length");
                return;
            }

            // Create a slice of the Grizzly buffer so that we have start and
            // end pointers on the wanted message data.
            grizzlyBuffer.position(HEADER_LENGTH + currentMessageStartPosition);
            grizzlyBuffer.limit(messageLength + currentMessageStartPosition);
            ByteBuffer slicedHandle = grizzlyBuffer.slice();
            grizzlyBuffer.limit(grizzlyBuffer.capacity());

            incomingMessage = new IncomingMessage(dataLength, slicedHandle);
            incomingMessage.setRequestId(messageRequestId);
            headerParsed = true;
            updateMessageState();
        }

        /**
         * Replaces grizzlyBuffer (and the worker thread's buffer) with a larger
         * one that holds the whole current message plus DEFAULT_BUFFER_SIZE
         * spare bytes for following chunks. Received bytes keep their offsets.
         *
         * @return false if the required capacity does not fit into an int
         */
        private boolean growBuffer() {
            // Size from the message END, not just its length: the message may
            // start deep inside an already enlarged buffer.
            long required = (long) currentMessageStartPosition + messageLength
                    + DEFAULT_BUFFER_SIZE;
            if (required > Integer.MAX_VALUE) {
                return false;
            }
            ByteBuffer newBuffer = ByteBufferFactory.allocateView(
                    (int) required, grizzlyBuffer.isDirect());
            // Copy only the bytes received so far, at their original offsets.
            ByteBuffer received = grizzlyBuffer.duplicate();
            received.clear();
            received.limit(trackBytesRead);
            newBuffer.put(received);
            WorkerThread workerThread = (WorkerThread) Thread.currentThread();
            workerThread.setByteBuffer(grizzlyBuffer = newBuffer);
            return true;
        }

        /**
         * Switches to EXPECTING_MORE_DATA for a partially received header. If
         * a full header cannot fit behind currentMessageStartPosition, the
         * pending bytes are first moved to the front of the buffer.
         */
        private void awaitRestOfHeader() {
            if (grizzlyBuffer.capacity() - currentMessageStartPosition < HEADER_LENGTH) {
                int pending = trackBytesRead - currentMessageStartPosition;
                // NOTE(review): assumes earlier messages sliced from this buffer
                // were already consumed, since compact() overwrites their bytes.
                grizzlyBuffer.limit(trackBytesRead);
                grizzlyBuffer.position(currentMessageStartPosition);
                grizzlyBuffer.compact();
                currentMessageStartPosition = 0;
                trackBytesRead = pending;
            }
            headerParsed = false;
            state = State.EXPECTING_MORE_DATA;
            grizzlyBuffer.position(trackBytesRead);
        }

        /**
         * Compares the received bytes with the end of the current message and
         * moves to MEASSAGE_PARSED, MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE
         * or EXPECTING_MORE_DATA.
         */
        private void updateMessageState() {
            int remaining = trackBytesRead - messageLength - currentMessageStartPosition;
            if (remaining == 0) {
                // Exactly one complete message.
                headerParsed = false;
                state = State.MEASSAGE_PARSED;
            } else if (remaining > 0) {
                headerParsed = false;
                state = State.MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE;
                currentMessageStartPosition += messageLength;
            } else {
                // Keep on reading bytes from the network after the data so far.
                state = State.EXPECTING_MORE_DATA;
                grizzlyBuffer.position(trackBytesRead);
            }
        }

        /** Publishes an error message with ERROR_REQUEST_ID. */
        private void reportError(String reason) {
            incomingMessage = new IncomingMessage(new Exception(reason));
            incomingMessage.setRequestId(ERROR_REQUEST_ID);
            headerParsed = false;
            state = State.MEASSAGE_PARSED;
        }

        /**
         * Gives the worker thread its original buffer back unless more data is
         * expected for the current message.
         *
         * @return true if the parser state must be kept for more data
         */
        public boolean releaseBuffer() {
            if (isExpectingMoreData()) {
                return true;
            }
            if (restoreHandle != null) {
                WorkerThread workerThread = (WorkerThread) Thread.currentThread();
                restoreHandle.clear();
                workerThread.setByteBuffer(restoreHandle);
            }
            grizzlyBuffer = null;
            restoreHandle = null;
            return false;
        }

        /**
         * Starts parsing {@code bb}, whose position marks the end of the bytes
         * read so far.
         */
        public void startBuffer(ByteBuffer bb) {
            // When continuing a message, bb may be the enlarged buffer; keep
            // the handle to the original one so releaseBuffer() restores it.
            if (!isExpectingMoreData() || restoreHandle == null) {
                restoreHandle = bb;
            }
            trackBytesRead = bb.position();

            grizzlyBuffer = bb;
            if (!isExpectingMoreData()) {
                grizzlyBuffer.position(0);
                resetState();
            }
        }


        private void resetState() {
            state = State.START;
            dataLength = 0;
            messageLength = 0;
            currentMessageStartPosition = 0;
            headerParsed = false;
        }
    }


    /**
     * A message decoded by CustomProtocol. It carries either the data of a
     * successfully parsed message (a ByteBuffer view plus its length) or the
     * Exception describing why parsing failed.
     *
     * @author John Vieten
     */
    public class IncomingMessage {
        private int requestId;
        private int messageLength = -1;
        private ByteBuffer payload;
        private Exception exception;

        /** Creates an error message wrapping {@code exception}. */
        public IncomingMessage(Exception exception) {
            this.exception = exception;
        }

        /** Creates a data message of {@code messageLength} bytes backed by {@code sourceSliced}. */
        public IncomingMessage(int messageLength, ByteBuffer sourceSliced) {
            this.payload = sourceSliced;
            this.messageLength = messageLength;
        }

        /** @return true if this message reports a parse failure */
        public boolean hasException() {
            return exception != null;
        }

        public Exception getException() {
            return exception;
        }

        public void setException(Exception exception) {
            this.exception = exception;
        }

        /** @return the data length, or -1 for an error message */
        public int getMessageLength() {
            return messageLength;
        }

        /** @return the buffer view holding the data, or null */
        public ByteBuffer getMessage() {
            return payload;
        }

        public int getRequestId() {
            return requestId;
        }

        public void setRequestId(int requestId) {
            this.requestId = requestId;
        }

        /** Drops the reference to the data buffer. */
        public void clear() {
            payload = null;
        }
    }

    /**
     * Protocol filter that echoes each parsed IncomingMessage back to the
     * sender in the same wire format. Parse failures are answered with
     * CustomProtocol.ERROR_REQUEST_ID and the exception text as payload.
     */
    public class CustomProtocolEcho implements ProtocolFilter {
        public boolean execute(final Context ctx) {

            IncomingMessage result =
                    (IncomingMessage) ctx.removeAttribute(ProtocolParser.MESSAGE);
            if (result == null) {
                // No message was published for this invocation; nothing to echo.
                return false;
            }
            if (result.hasException()) {
                writeBack(CustomProtocol.ERROR_REQUEST_ID,
                        describe(result.getException()).getBytes(), ctx);
            } else {
                ByteBuffer bb = result.getMessage();
                byte[] bytes = new byte[result.getMessageLength()];
                bb.get(bytes);
                writeBack(result.getRequestId(), bytes, ctx);
            }

            return false;
        }

        /** Returns the exception message, falling back to toString() when it is null. */
        private String describe(Exception e) {
            String message = e.getMessage();
            return message != null ? message : e.toString();
        }

        /** Frames {@code bytes} with a protocol header and flushes it to the channel. */
        private void writeBack(int requestId, byte[] bytes, Context ctx) {
            SelectableChannel channel = ctx.getSelectionKey().channel();

            int byteLength = bytes.length;
            ByteBuffer newBuffer
                    = ByteBufferFactory
                    .allocateView(CustomProtocol.HEADER_LENGTH + byteLength, false);
            newBuffer.putInt(CustomProtocol.START_MARK);
            newBuffer.putInt(requestId);
            newBuffer.putInt(byteLength);
            newBuffer.put(bytes);
            newBuffer.flip();

            try {
                // Serialize writers on the same channel.
                synchronized (channel) {
                    OutputWriter.flushChannel(channel, newBuffer);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        public boolean postExecute(Context context) throws IOException {

            return true;
        }

    }


        //-------------------------- Start Test Code ----------------------------


    /**
     * Sends one short message and verifies it is echoed back unchanged.
     */
    public void testCustomProtocolSmall() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);
        try {
            ControllerUtils.startController(controller);
            client.connect();
            sendToServer("Hello World".getBytes(), client);
        } finally {
            try {
                controller.stop();
            } finally {
                // Close the client even if stopping the controller fails.
                client.close();
            }
        }
    }

    /**
     * Sends messages of 1 to 6 times CustomProtocol.DEFAULT_BUFFER_SIZE data
     * bytes (each too large for a DEFAULT_BUFFER_SIZE buffer together with the
     * header) and verifies every echo.
     */
    public void testCustomProtocolBig() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);
        try {
            ControllerUtils.startController(controller);
            client.connect();

            for (int i = 1; i < 7; i++) {
                byte[] bigMessage = new byte[CustomProtocol.DEFAULT_BUFFER_SIZE * i];
                Arrays.fill(bigMessage, (byte) i);
                sendToServer(bigMessage, client);
            }
        } finally {
            try {
                controller.stop();
            } finally {
                // Close the client even if stopping the controller fails.
                client.close();
            }
        }
    }


    /**
     * Sends bytes without a START_MARK header and verifies that the server
     * answers with an error reply: ERROR_REQUEST_ID and "Bad Startmark".
     */
    public void testException() throws IOException {
        Controller controller = createController(PORT);
        TCPIOClient client = new TCPIOClient("localhost", PORT);

        try {
            ControllerUtils.startController(controller);
            client.connect();

            byte[] data = "Hello World This should be echoed".getBytes();
            client.send(data);

            byte[] header = new byte[CustomProtocol.HEADER_LENGTH];
            client.receive(header);
            ByteBuffer receiveBuffer = ByteBuffer.wrap(header);

            assertEquals(CustomProtocol.START_MARK, receiveBuffer.getInt());
            assertEquals(CustomProtocol.ERROR_REQUEST_ID, receiveBuffer.getInt());

            int receiveLength = receiveBuffer.getInt();
            assertTrue(receiveLength > 0);
            byte[] message = new byte[receiveLength];
            client.receive(message);
            // The payload is the parser's exception text.
            assertEquals("Bad Startmark", new String(message));
        } finally {
            try {
                controller.stop();
            } finally {
                // Close the client even if stopping the controller fails.
                client.close();
            }
        }
    }


    /**
     * Sends {@code data} framed with the custom protocol header under the next
     * request id, then reads the echoed reply and asserts that its header and
     * payload match what was sent.
     */
    public static void sendToServer(byte[] data, TCPIOClient client)
            throws IOException {
        requestId++;

        int messagelength = CustomProtocol.HEADER_LENGTH + data.length;
        ByteBuffer request = ByteBuffer.allocate(messagelength);
        request.putInt(CustomProtocol.START_MARK);
        request.putInt(requestId);
        request.putInt(data.length);
        request.put(data);
        client.send(request.array());

        byte[] header = new byte[CustomProtocol.HEADER_LENGTH];
        client.receive(header);
        ByteBuffer receiveBuffer = ByteBuffer.wrap(header);

        assertEquals(CustomProtocol.START_MARK, receiveBuffer.getInt());
        assertEquals(requestId, receiveBuffer.getInt());

        int receiveLength = receiveBuffer.getInt();
        assertEquals(data.length, receiveLength);

        byte[] message = new byte[receiveLength];
        client.receive(message);
        assertTrue(Arrays.equals(data, message));
    }


    /**
     * Builds a Controller listening on {@code port}. Its single, shared
     * protocol chain parses requests with CustomProtocol and echoes them back
     * through CustomProtocolEcho.
     */
    private Controller createController(int port) {
        Controller controller = new Controller();

        TCPSelectorHandler selectorHandler = new TCPSelectorHandler();
        selectorHandler.setPort(port);
        controller.addSelectorHandler(selectorHandler);

        final DefaultProtocolChain chain = new DefaultProtocolChain();
        chain.addFilter(new ParserProtocolFilter() {
            public ProtocolParser newProtocolParser() {
                return new CustomProtocol();
            }
        });
        chain.addFilter(new CustomProtocolEcho());
        chain.setContinuousExecution(true);

        // Every poll() hands out the same chain; offer() never recycles it.
        controller.setProtocolChainInstanceHandler(
                new DefaultProtocolChainInstanceHandler() {
                    public ProtocolChain poll() {
                        return chain;
                    }

                    public boolean offer(ProtocolChain protocolChain) {
                        return false;
                    }
                });
        return controller;
    }



}