package com.sun.grizzly; import com.sun.grizzly.filter.ParserProtocolFilter; import com.sun.grizzly.util.ByteBufferFactory; import com.sun.grizzly.util.OutputWriter; import com.sun.grizzly.util.WorkerThread; import com.sun.grizzly.utils.ControllerUtils; import com.sun.grizzly.utils.TCPIOClient; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SelectableChannel; import java.util.Arrays; import junit.framework.TestCase; /** * Test a simple custom TCP Protocol parsed by Grizzly ProtocolParser API * * 4 bytes Magic Number - Just to identify Protocol * 4 bytes RequestID - Can be used to identify Replies (-1 for Exception) * 4 bytes Data length - The length of Data * n bytes Data - Message * John Vieten */ public class ProtocolParserStateTest extends TestCase { public static final int PORT = 17505; public static final int PACKETS_COUNT = 100; static int requestId = 0; public interface State { final int START = 1; final int EXPECTING_MORE_DATA = 2; final int HAS_MORE_BYTES_TO_PARSE = 3; final int MEASSAGE_PARSED = 4; final int MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE = 5; } /** * ProtocolParser Filter which parses a Custom Protocol * into a IncomingMessage. These IncomingMessages can then be used * by other Protocolchain Filters * * The Messages can be arbitrary long. * The Protocol has the following format (12 bytes Header + Data): * 4 bytes Magic Number - Just to identify Protocol * 4 bytes RequestID - Can be used to identify Replies (-1 for Exception) * 4 bytes Data length - The length of Data * n bytes Data - Message *

* The Protocol asumes that on average Messages are not too large otherwise * this implementation might not be effective. *

*

* Tries to avoid bytearray copying and therfore maps the IncomingMessage onto * the orginal Grizzly Bytebuffer. If Messages do not fit into an default * sized Grizzly ByteBuffer a larger one with enough capacity is temporarly * given to the framework. *

* This class is still under construction and for example Error Handling, * Timeouts, Recovery has to be greatly improved!!! * * * John Vieten */ public class CustomProtocol implements ProtocolParser, State { public static final int HEADER_LENGTH = 12; public static final int START_MARK = 0x77005434; public static final int DEFAULT_BUFFER_SIZE = 8192; public static final int ERROR_REQUEST_ID = -1; /** * Handle to byteBuffer which gets filled in by grizzly */ private ByteBuffer grizzlyBuffer; /** * If Message are larger than DEFAULT_BUFFER_SIZE * a new larger Grizzly ByteBuffer is created * This handle is used to give Grizzly back its original default ByteBuffer . */ private ByteBuffer restoreHandle; /** * The length of the data depicted by Protocol Header */ private int dataLength = 0; /** * Just dataLength + Header Length */ private int messageLength = 0; /* Keeps track of the bytes read by the Grizzly ReadFilter */ private int trackBytesRead = 0; /** * Keeps track of the start position of the current message in the * Grizzly ByteBuffer. * Needed because an Grizzly ByteBuffer can contaion several Messages */ int currentMessageStartPosition = 0; /** * Holds the state of this Statemachine */ private int state = State.START; /** * The Message given to other Filters up the chain */ private IncomingMessage incomingMessage; public boolean isExpectingMoreData() { return state == State.EXPECTING_MORE_DATA; } public boolean hasMoreBytesToParse() { return state == State.HAS_MORE_BYTES_TO_PARSE; } public IncomingMessage getNextMessage() { IncomingMessage tmp = incomingMessage; incomingMessage = null; switch (state) { case MEASSAGE_PARSED: resetState(); break; case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: state = State.HAS_MORE_BYTES_TO_PARSE; } return tmp; } public boolean hasNextMessage() { parseMessage(); switch (state) { case MEASSAGE_PARSED: case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: return true; default: return false; } } public void parseMessage() { switch (state) { case MEASSAGE_PARSED: case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: return; } switch (state) { case HAS_MORE_BYTES_TO_PARSE: // we need to do some housekeeping // because startBuffer was not called grizzlyBuffer.position(currentMessageStartPosition); case START: int startmark = grizzlyBuffer.getInt(); if (startmark != START_MARK) { incomingMessage = new IncomingMessage(new Exception("Bad Startmark")); incomingMessage.setRequestId(ERROR_REQUEST_ID); state = State.MEASSAGE_PARSED; return; } final int requestId = grizzlyBuffer.getInt(); dataLength = grizzlyBuffer.getInt(); messageLength = dataLength + HEADER_LENGTH; if (dataLength <= 0) { incomingMessage = new IncomingMessage(new Exception("Bad Message length")); incomingMessage.setRequestId(ERROR_REQUEST_ID); state = State.MEASSAGE_PARSED; return; } if ((grizzlyBuffer.capacity() - grizzlyBuffer.position()) < dataLength) { // Not enough room to store message so we // need a bigger buffer and also some extra room for // any new message chunks int newCapacity = messageLength + DEFAULT_BUFFER_SIZE; grizzlyBuffer.position(0); ByteBuffer newBuffer = ByteBufferFactory.allocateView(newCapacity, grizzlyBuffer.isDirect()); newBuffer.put(grizzlyBuffer); WorkerThread workerThread = (WorkerThread) Thread.currentThread(); workerThread.setByteBuffer(grizzlyBuffer = newBuffer); } // Now create a sliced copy of GrizzlyBuffer so that // we have start and end pointers on our wanted message grizzlyBuffer.position(HEADER_LENGTH + currentMessageStartPosition); grizzlyBuffer.limit(messageLength + currentMessageStartPosition); ByteBuffer slicedHandle = grizzlyBuffer.slice(); grizzlyBuffer.limit(grizzlyBuffer.capacity()); // Now Buffer has enough space for message incomingMessage = new IncomingMessage(dataLength, slicedHandle); incomingMessage.setRequestId(requestId); case EXPECTING_MORE_DATA: int remaining = trackBytesRead - messageLength - currentMessageStartPosition; if (remaining == 0) { // Ok we found a message state = State.MEASSAGE_PARSED; } else if (remaining > 0) { state = State.MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE; currentMessageStartPosition = currentMessageStartPosition + messageLength; } else { // we have to keep on ready bytes from the net state = State.EXPECTING_MORE_DATA; grizzlyBuffer.position(trackBytesRead); } } } public boolean releaseBuffer() { boolean saveCurrentParserState = false; if (isExpectingMoreData()) { saveCurrentParserState = true; } else { WorkerThread workerThread = (WorkerThread) Thread.currentThread(); restoreHandle.clear(); workerThread.setByteBuffer(restoreHandle); grizzlyBuffer = null; restoreHandle = null; } return saveCurrentParserState; } public void startBuffer(ByteBuffer bb) { restoreHandle = bb; trackBytesRead = bb.position(); grizzlyBuffer = bb; if (!isExpectingMoreData()) { grizzlyBuffer.position(0); resetState(); } } private void resetState() { state = State.START; dataLength = 0; messageLength = 0; currentMessageStartPosition = 0; } } /** * User: John Vieten */ public class IncomingMessage { private int requestId; private int messageLength = -1; private ByteBuffer bb; // private byte[] bbArray; private Exception exception; public Exception getException() { return exception; } public void setException(Exception exception) { this.exception = exception; } public boolean hasException() { return exception != null; } public IncomingMessage(Exception exception) { this.exception = exception; } public IncomingMessage(int messageLength, ByteBuffer sourceSliced) { this.messageLength = messageLength; // bbArray = new byte[messageLength]; bb = sourceSliced; } public int getMessageLength() { return messageLength; } public ByteBuffer getMessage() { return bb; } public int getRequestId() { return requestId; } public void setRequestId(int requestId) { this.requestId = requestId; } public void clear() { bb = null; } } public class CustomProtocolEcho implements ProtocolFilter { public boolean execute(final Context ctx) { IncomingMessage result = (IncomingMessage) ctx.removeAttribute(ProtocolParser.MESSAGE); if (result.hasException()) { writeBack(CustomProtocol.ERROR_REQUEST_ID, result.getException().getMessage().getBytes(), ctx); } else { ByteBuffer bb=result.getMessage(); byte[] bytes=new byte[result.getMessageLength()]; bb.get(bytes); writeBack(result.getRequestId(),bytes, ctx); } return false; } private void writeBack(int requestId, byte[] bytes, Context ctx) { SelectableChannel channel = ctx.getSelectionKey().channel(); int byteLength = bytes.length; ByteBuffer newBuffer = ByteBufferFactory .allocateView(CustomProtocol.HEADER_LENGTH + byteLength, false); newBuffer.putInt(CustomProtocol.START_MARK); newBuffer.putInt(requestId); newBuffer.putInt(byteLength); newBuffer.put(bytes); newBuffer.flip(); try { synchronized (channel) { OutputWriter.flushChannel(channel, newBuffer); } newBuffer.clear(); } catch (IOException e) { e.printStackTrace(); } } public boolean postExecute(Context context) throws IOException { return true; } } //-------------------------- Start Test Code ---------------------------- public void testCustomProtocolSmall() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); sendToServer("Hello World".getBytes(), client); } finally { controller.stop(); client.close(); } } public void testCustomProtocolBig() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); for (int i = 1; i < 7; i++) { byte[] bigMessage = new byte[8192 * i]; Arrays.fill(bigMessage, (byte) i); sendToServer(bigMessage, client); } } finally { controller.stop(); client.close(); } } public void testException() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); byte[] data = "Hello World This should be echoed".getBytes(); client.send(data); byte[] header = new byte[CustomProtocol.HEADER_LENGTH]; client.receive(header); ByteBuffer receiveBuffer = ByteBuffer.wrap(header); int startMark = receiveBuffer.getInt(); assertTrue(startMark == CustomProtocol.START_MARK); int receiveRequestId = receiveBuffer.getInt(); assertTrue(receiveRequestId == -1); int receiveLength = receiveBuffer.getInt(); byte[] message = new byte[receiveLength]; client.receive(message); } finally { controller.stop(); client.close(); } } public static void sendToServer(byte[] data, TCPIOClient client) throws IOException { int messagelength = CustomProtocol.HEADER_LENGTH + data.length; ByteBuffer newBuffer = ByteBufferFactory.allocateView(messagelength, false); newBuffer.putInt(CustomProtocol.START_MARK); requestId++; newBuffer.putInt(requestId); newBuffer.putInt(data.length); newBuffer.put(data); newBuffer.flip(); byte[] sendBytes = new byte[messagelength]; newBuffer.get(sendBytes); client.send(sendBytes); byte[] header = new byte[CustomProtocol.HEADER_LENGTH]; client.receive(header); ByteBuffer receiveBuffer = ByteBuffer.wrap(header); int startMark = receiveBuffer.getInt(); assertTrue(startMark == CustomProtocol.START_MARK); int receiveRequestId = receiveBuffer.getInt(); assertTrue(receiveRequestId == requestId); int receiveLength = receiveBuffer.getInt(); assertTrue(receiveLength == data.length); byte[] message = new byte[receiveLength]; ByteBuffer messageBuffer = ByteBuffer.wrap(message); client.receive(message); messageBuffer.get(message); assertTrue(Arrays.equals(data, message)); } private Controller createController(int port) { Controller controller = new Controller(); TCPSelectorHandler tcpSelectorHandler = new TCPSelectorHandler(); tcpSelectorHandler.setPort(port); controller.addSelectorHandler(tcpSelectorHandler); final ProtocolFilter parserProtocolFilter = new ParserProtocolFilter() { public ProtocolParser newProtocolParser() { return new CustomProtocol(); } }; final ProtocolChain protocolChain = new DefaultProtocolChain(); protocolChain.addFilter(parserProtocolFilter); protocolChain.addFilter(new CustomProtocolEcho()); ((DefaultProtocolChain) protocolChain).setContinuousExecution(true); ProtocolChainInstanceHandler pciHandler = new DefaultProtocolChainInstanceHandler() { public ProtocolChain poll() { return protocolChain; } public boolean offer(ProtocolChain protocolChain) { return false; } }; controller.setProtocolChainInstanceHandler(pciHandler); return controller; } }