package com.sun.grizzly; import com.sun.grizzly.filter.ParserProtocolFilter; import com.sun.grizzly.util.ByteBufferFactory; import com.sun.grizzly.util.OutputWriter; import com.sun.grizzly.util.WorkerThread; import com.sun.grizzly.utils.ControllerUtils; import com.sun.grizzly.utils.TCPIOClient; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SelectableChannel; import java.util.Arrays; import junit.framework.TestCase; /** * Test a simple custom TCP Protocol parsed by Grizzly ProtocolParser API * The Messages can be arbitrary long. * The Protocol has the following format (12 bytes Header + Data): * 4 bytes Magic Number - Just to identify Protocol * 4 bytes RequestID - Can be used to identify Replies (-1 for Exception) * 4 bytes Data length - The length of Message * n bytes Data - Payload Message * * * John Vieten */ public class ProtocolParserStateTest extends TestCase { public static final int PORT = 17505; static int requestId = 0; //-------------------------- Start Test Code ---------------------------- public void testCustomProtocolSmall() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); sindToServer("Hello World".getBytes(), client); } finally { controller.stop(); client.close(); } } public void testCustomProtocolBig() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); for (int i = 1; i < 7; i++) { byte[] bigMessage = new byte[8192 * i]; Arrays.fill(bigMessage, (byte) i); sindToServer(bigMessage, client); } } finally { controller.stop(); client.close(); } } public void testException() throws IOException { Controller controller = createController(PORT); TCPIOClient client = new TCPIOClient("localhost", PORT); try { ControllerUtils.startController(controller); client.connect(); byte[] data = "Hello World This should be echoed".getBytes(); client.send(data); byte[] header = new byte[CustomProtocol.HEADERLENGTH]; client.receive(header); ByteBuffer receiveBuffer = ByteBuffer.wrap(header); int startMark = receiveBuffer.getInt(); assertTrue(startMark == CustomProtocol.STARTMARK); int receiveRequestId = receiveBuffer.getInt(); assertTrue(receiveRequestId == CustomProtocol.ERROR_REQUEST_ID); int receiveLength = receiveBuffer.getInt(); byte[] message = new byte[receiveLength]; client.receive(message); } finally { controller.stop(); client.close(); } } public static void sindToServer(byte[] data, TCPIOClient client) throws IOException { int messagelength = CustomProtocol.HEADERLENGTH + data.length; ByteBuffer newBuffer = ByteBufferFactory.allocateView(messagelength, false); newBuffer.putInt(CustomProtocol.STARTMARK); requestId++; newBuffer.putInt(requestId); newBuffer.putInt(data.length); newBuffer.put(data); newBuffer.flip(); byte[] sendBytes = new byte[messagelength]; newBuffer.get(sendBytes); client.send(sendBytes); byte[] header = new byte[CustomProtocol.HEADERLENGTH]; client.receive(header); ByteBuffer receiveBuffer = ByteBuffer.wrap(header); int startMark = receiveBuffer.getInt(); assertTrue(startMark == CustomProtocol.STARTMARK); int receiveRequestId = receiveBuffer.getInt(); assertTrue(receiveRequestId == requestId); int receiveLength = receiveBuffer.getInt(); assertTrue(receiveLength == data.length); byte[] message = new byte[receiveLength]; ByteBuffer messageBuffer = ByteBuffer.wrap(message); client.receive(message); messageBuffer.get(message); assertTrue(Arrays.equals(data, message)); } //----------------------- Start Parser Code ------------------------- private Controller createController(int port) { Controller controller = new Controller(); TCPSelectorHandler tcpSelectorHandler = new TCPSelectorHandler(); tcpSelectorHandler.setPort(port); controller.addSelectorHandler(tcpSelectorHandler); final ProtocolFilter parserProtocolFilter = new ParserProtocolFilter() { public ProtocolParser newProtocolParser() { return new CustomProtocol(); } }; final ProtocolChain protocolChain = new DefaultProtocolChain(); protocolChain.addFilter(parserProtocolFilter); protocolChain.addFilter(new CustomProtocolEcho()); ((DefaultProtocolChain) protocolChain).setContinuousExecution(true); ProtocolChainInstanceHandler pciHandler = new DefaultProtocolChainInstanceHandler() { public ProtocolChain poll() { return protocolChain; } public boolean offer(ProtocolChain protocolChain) { return false; } }; controller.setProtocolChainInstanceHandler(pciHandler); return controller; } public interface State { final int START = 1; final int EXPECTING_MORE_DATA = 2; final int HAS_MORE_BYTES_TO_PARSE = 3; final int MEASSAGE_PARSED = 4; final int MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE = 5; } /** * ProtocolParser Filter which parses a Custom Protocol * into IncomingMessage. These IncomingMessages can then be used * by other Protocolchain Filters * * John Vieten * */ public class CustomProtocol implements ProtocolParser, State { public static final int HEADERLENGTH = 12; public static final int STARTMARK = 0x77005434; public static final int ERROR_REQUEST_ID = -1; private int state = START; private IncomingMessage incomingMessage; public boolean isExpectingMoreData() { return (incomingMessage != null) && incomingMessage.isExpectingMoreData(); } public boolean hasMoreBytesToParse() { return state == HAS_MORE_BYTES_TO_PARSE; } public IncomingMessage getNextMessage() { IncomingMessage tmp = incomingMessage; incomingMessage = null; switch (state) { case MEASSAGE_PARSED: resetState(); break; case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: state = State.HAS_MORE_BYTES_TO_PARSE; } return tmp; } public boolean hasNextMessage() { parseMessage(); switch (state) { case MEASSAGE_PARSED: case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: return true; default: return false; } } public void parseMessage() { switch (state) { case MEASSAGE_PARSED: case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE: return; } ByteBuffer bb = getCurrentByteBuffer(); if (!bb.hasRemaining()) { incomingMessage = new IncomingMessage(new Exception("Buffer Full")); state = State.MEASSAGE_PARSED; return; } switch (state) { case START: case HAS_MORE_BYTES_TO_PARSE: if (bb.remaining() < HEADERLENGTH) { return; } int startmark = bb.getInt(); if (startmark != STARTMARK) { incomingMessage = new IncomingMessage(new Exception("Bad Startmark")); state = State.MEASSAGE_PARSED; return; } final int requestId = bb.getInt(); final int len = bb.getInt(); if (len <= 0) { incomingMessage = new IncomingMessage(new Exception("Bad Messagelength")); state = State.MEASSAGE_PARSED; return; } // Now Message is valid incomingMessage = new IncomingMessage(len); incomingMessage.setRequestId(requestId); case EXPECTING_MORE_DATA: if (bb.remaining() >= incomingMessage.getMessageRemaining()) { int oldLimit = bb.limit(); bb.limit(incomingMessage.getMessageRemaining() + bb.position()); incomingMessage.put(bb); bb.limit(oldLimit); state = bb.hasRemaining() ? MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE : MEASSAGE_PARSED; } else if ((bb.capacity() - bb.position()) < incomingMessage.getMessageRemaining()) { incomingMessage.put(bb); // can make space in buffer since messages // are allready read from bytebuffer bb.clear(); state = State.EXPECTING_MORE_DATA; } else { state = State.EXPECTING_MORE_DATA; } } } public boolean releaseBuffer() { boolean saveCurrentParserState = false; if (isExpectingMoreData()) { saveCurrentParserState = true; } else { getCurrentByteBuffer().clear(); } return saveCurrentParserState; } public void startBuffer(ByteBuffer bb) { bb.flip(); if (!isExpectingMoreData()) { resetState(); } } private ByteBuffer getCurrentByteBuffer() { return ((WorkerThread) Thread.currentThread()).getByteBuffer(); } private void resetState() { state = State.START; } } class IncomingMessage { private int requestId; private int messageLength = -1; private ByteBuffer bb; private byte[] bbArray; private Exception exception; public Exception getException() { return exception; } public void setException(Exception exception) { this.exception = exception; } public boolean hasException() { return exception != null; } public IncomingMessage(Exception exception) { this.exception = exception; } public IncomingMessage(int messageLength) { this.messageLength = messageLength; bbArray = new byte[messageLength]; bb = ByteBuffer.wrap(bbArray); } public boolean isExpectingMoreData() { return (bb.position() < messageLength); } public int getMessageLength() { return messageLength; } public int getMessageRemaining() { return bb.remaining(); } public byte[] getMessage() { return bbArray; } public void put(ByteBuffer message) { this.bb.put(message); } public int getRequestId() { return requestId; } public void setRequestId(int requestId) { this.requestId = requestId; } public void clear() { bb = null; } } public class CustomProtocolEcho implements ProtocolFilter { public boolean execute(final Context ctx) { IncomingMessage result = (IncomingMessage) ctx.removeAttribute(ProtocolParser.MESSAGE); if (result.hasException()) { writeBack(CustomProtocol.ERROR_REQUEST_ID, result.getException().getMessage().getBytes(), ctx); } else { writeBack(result.getRequestId(), result.getMessage(), ctx); } return false; } private void writeBack(int requestId, byte[] bytes, Context ctx) { SelectableChannel channel = ctx.getSelectionKey().channel(); int byteLength = bytes.length; ByteBuffer newBuffer = ByteBufferFactory .allocateView(CustomProtocol.HEADERLENGTH + byteLength, false); newBuffer.putInt(CustomProtocol.STARTMARK); newBuffer.putInt(requestId); newBuffer.putInt(byteLength); newBuffer.put(bytes); newBuffer.flip(); try { synchronized (channel) { OutputWriter.flushChannel(channel, newBuffer); } newBuffer.clear(); } catch (IOException e) { e.printStackTrace(); } } public boolean postExecute(Context context) throws IOException { return true; } } }