package com.sun.grizzly;
import com.sun.grizzly.filter.ParserProtocolFilter;
import com.sun.grizzly.util.ByteBufferFactory;
import com.sun.grizzly.util.OutputWriter;
import com.sun.grizzly.util.WorkerThread;
import com.sun.grizzly.utils.ControllerUtils;
import com.sun.grizzly.utils.TCPIOClient;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectableChannel;
import java.util.Arrays;
import junit.framework.TestCase;
/**
 * Tests a simple custom TCP protocol parsed with the Grizzly ProtocolParser API.
 *
 * Each message has the following format:
 * 4 bytes Magic Number - Just to identify Protocol
 * 4 bytes RequestID - Can be used to identify Replies (-1 for Exception)
 * 4 bytes Data length - The length of Data
 * n bytes Data - Message
 *
 * @author John Vieten
 */
public class ProtocolParserStateTest extends TestCase {
/** Request id of the most recent message sent by sendToServer; incremented before each send. */
static int requestId;
/** TCP port on which the test Controller listens and the client connects. */
public static final int PORT = 17505;
/** Packet count constant; not referenced by the tests in this class. */
public static final int PACKETS_COUNT = 100;
/**
 * States of the CustomProtocol parser state machine.
 * (Interface fields are implicitly public static final.)
 * The "MEASSAGE" spelling is kept because existing code refers to these names.
 */
public interface State {
    /** Nothing parsed yet; the next message starts at the beginning of the buffer. */
    int START = 1;
    /** The current message is incomplete; more bytes must be read from the network. */
    int EXPECTING_MORE_DATA = 2;
    /** The buffer still holds unparsed bytes after the message already handed out. */
    int HAS_MORE_BYTES_TO_PARSE = 3;
    /** One message is parsed and it ends exactly at the end of the read bytes. */
    int MEASSAGE_PARSED = 4;
    /** One message is parsed and further bytes follow it in the buffer. */
    int MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE = 5;
}
/**
 * ProtocolParser that parses a custom binary protocol into IncomingMessages,
 * which other ProtocolChain filters can then consume.
 *
 * Messages can be arbitrarily long. Each message has a 12-byte header
 * followed by the data:
 * 4 bytes Magic Number - Just to identify Protocol
 * 4 bytes RequestID - Can be used to identify Replies (-1 for Exception)
 * 4 bytes Data length - The length of Data
 * n bytes Data - Message
 *
 * The protocol assumes that messages are, on average, not too large;
 * otherwise this implementation might not be efficient.
 *
 * To avoid copying byte arrays, each IncomingMessage is a slice of the
 * original Grizzly ByteBuffer. If a message does not fit into the current
 * buffer, a larger buffer is allocated and temporarily handed to the worker
 * thread. A header that arrives split across several reads is kept in the
 * buffer until it is complete.
 *
 * This class is still under construction; error handling, timeouts and
 * recovery in particular still need to be greatly improved.
 *
 * John Vieten
 */
public class CustomProtocol implements ProtocolParser, State {

    public static final int HEADER_LENGTH = 12;
    public static final int START_MARK = 0x77005434;
    public static final int DEFAULT_BUFFER_SIZE = 8192;
    public static final int ERROR_REQUEST_ID = -1;

    /**
     * Buffer being parsed: the one passed to startBuffer, or a larger copy
     * installed on the worker thread by ensureCapacity.
     */
    private ByteBuffer grizzlyBuffer;
    /**
     * Buffer received in startBuffer; releaseBuffer clears it and hands it
     * back to the worker thread.
     * NOTE(review): startBuffer overwrites this even while expecting more
     * data, so after a resize spanning several reads the enlarged buffer
     * (not the original default one) is what gets restored.
     */
    private ByteBuffer restoreHandle;
    /** Data length from the current header (0 while no header is decoded). */
    private int dataLength = 0;
    /** dataLength + HEADER_LENGTH; 0 while the current header is not yet decoded. */
    private int messageLength = 0;
    /** End offset of the valid bytes in the buffer, taken from its position in startBuffer. */
    private int trackBytesRead = 0;
    /**
     * Offset of the current message in the buffer; one buffer may contain
     * several messages.
     */
    int currentMessageStartPosition = 0;
    /** Current state of the state machine (one of the State constants). */
    private int state = State.START;
    /** The message to hand to the next filter in the chain. */
    private IncomingMessage incomingMessage;

    /** @return true while the current message (header or body) is only partially read */
    public boolean isExpectingMoreData() {
        return state == State.EXPECTING_MORE_DATA;
    }

    /** @return true when unparsed bytes follow the message already handed out */
    public boolean hasMoreBytesToParse() {
        return state == State.HAS_MORE_BYTES_TO_PARSE;
    }

    /**
     * Hands out the parsed message and advances the state machine. After the
     * last message in the buffer the state is reset; otherwise parsing
     * continues with the following bytes.
     */
    public IncomingMessage getNextMessage() {
        IncomingMessage tmp = incomingMessage;
        incomingMessage = null;
        switch (state) {
            case MEASSAGE_PARSED:
                resetState();
                break;
            case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                state = State.HAS_MORE_BYTES_TO_PARSE;
                break;
            default:
                break;
        }
        return tmp;
    }

    /** Parses as far as possible and reports whether a message is ready. */
    public boolean hasNextMessage() {
        parseMessage();
        switch (state) {
            case MEASSAGE_PARSED:
            case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                return true;
            default:
                return false;
        }
    }

    /**
     * Advances the state machine using the bytes currently in the buffer.
     * Does nothing while a parsed message is still waiting to be collected.
     */
    public void parseMessage() {
        switch (state) {
            case MEASSAGE_PARSED:
            case MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE:
                return;
            case EXPECTING_MORE_DATA:
                if (messageLength > 0) {
                    // Header already decoded; only the body was incomplete.
                    updateStateAfterRead();
                    return;
                }
                // The header itself was incomplete: try to decode it again.
                break;
            default:
                // START or HAS_MORE_BYTES_TO_PARSE: a new message begins at
                // currentMessageStartPosition.
                break;
        }
        if (parseHeader()) {
            updateStateAfterRead();
        }
    }

    /**
     * Decodes the header of the message at currentMessageStartPosition and
     * creates incomingMessage as a slice over its data.
     *
     * @return true if a valid header was decoded; false if parsing must stop
     *         here (error message produced, or header still incomplete)
     */
    private boolean parseHeader() {
        if (trackBytesRead - currentMessageStartPosition < HEADER_LENGTH) {
            // Only part of the header has arrived: keep the bytes in place
            // and let the next read append after them.
            ensureCapacity(currentMessageStartPosition + HEADER_LENGTH);
            dataLength = 0;
            messageLength = 0;
            state = State.EXPECTING_MORE_DATA;
            grizzlyBuffer.position(trackBytesRead);
            return false;
        }
        grizzlyBuffer.position(currentMessageStartPosition);
        int startMark = grizzlyBuffer.getInt();
        if (startMark != START_MARK) {
            reportError("Bad Startmark");
            return false;
        }
        final int messageRequestId = grizzlyBuffer.getInt();
        final int length = grizzlyBuffer.getInt();
        // Computed as long so a hostile length field cannot overflow the
        // offsets and capacities derived from it below.
        long requiredCapacity = (long) currentMessageStartPosition + HEADER_LENGTH
                + length + DEFAULT_BUFFER_SIZE;
        if (length <= 0 || requiredCapacity > Integer.MAX_VALUE) {
            reportError("Bad Message length");
            return false;
        }
        dataLength = length;
        messageLength = dataLength + HEADER_LENGTH;
        ensureCapacity(currentMessageStartPosition + messageLength);

        // Slice the data part so the message gets its own start/end pointers
        // without copying; later reads fill the shared content.
        grizzlyBuffer.position(currentMessageStartPosition + HEADER_LENGTH);
        grizzlyBuffer.limit(currentMessageStartPosition + messageLength);
        ByteBuffer slicedHandle = grizzlyBuffer.slice();
        grizzlyBuffer.limit(grizzlyBuffer.capacity());

        incomingMessage = new IncomingMessage(dataLength, slicedHandle);
        incomingMessage.setRequestId(messageRequestId);
        return true;
    }

    /**
     * Decides, from the bytes read so far, whether the current message is
     * complete, complete with more bytes following, or still incomplete.
     */
    private void updateStateAfterRead() {
        int remaining = trackBytesRead - currentMessageStartPosition - messageLength;
        if (remaining == 0) {
            state = State.MEASSAGE_PARSED;
        } else if (remaining > 0) {
            state = State.MEASSAGE_PARSED_AND_HAS_MORE_BYTES_TO_PARSE;
            currentMessageStartPosition += messageLength;
        } else {
            // Keep reading bytes from the network, appending after what we have.
            state = State.EXPECTING_MORE_DATA;
            grizzlyBuffer.position(trackBytesRead);
        }
    }

    /** Replaces the current message with an error message and marks it parsed. */
    private void reportError(String reason) {
        incomingMessage = new IncomingMessage(new Exception(reason));
        incomingMessage.setRequestId(ERROR_REQUEST_ID);
        state = State.MEASSAGE_PARSED;
    }

    /**
     * Grows grizzlyBuffer so that offsets up to {@code required} are
     * addressable. The bytes read so far are copied to offset 0 of the new
     * buffer, which is installed on the worker thread. Extra room of
     * DEFAULT_BUFFER_SIZE bytes is left for further chunks.
     */
    private void ensureCapacity(int required) {
        if (grizzlyBuffer.capacity() >= required) {
            return;
        }
        // Copy only the valid bytes; anything past trackBytesRead is stale.
        ByteBuffer readBytes = grizzlyBuffer.duplicate();
        readBytes.limit(trackBytesRead);
        readBytes.position(0);
        ByteBuffer newBuffer = ByteBufferFactory.allocateView(
                required + DEFAULT_BUFFER_SIZE, grizzlyBuffer.isDirect());
        newBuffer.put(readBytes);
        WorkerThread workerThread = (WorkerThread) Thread.currentThread();
        workerThread.setByteBuffer(newBuffer);
        grizzlyBuffer = newBuffer;
    }

    /**
     * Ends a buffer session.
     *
     * @return true if parser state must be kept because more data is
     *         expected; otherwise restoreHandle is cleared and handed back
     *         to the worker thread, and false is returned
     */
    public boolean releaseBuffer() {
        boolean saveCurrentParserState = false;
        if (isExpectingMoreData()) {
            saveCurrentParserState = true;
        } else {
            WorkerThread workerThread = (WorkerThread) Thread.currentThread();
            restoreHandle.clear();
            workerThread.setByteBuffer(restoreHandle);
            grizzlyBuffer = null;
            restoreHandle = null;
        }
        return saveCurrentParserState;
    }

    /**
     * Starts parsing {@code bb}. Its position marks the end of the bytes
     * read so far. Unless a partial message is pending, parsing restarts
     * from offset 0.
     */
    public void startBuffer(ByteBuffer bb) {
        restoreHandle = bb;
        trackBytesRead = bb.position();
        grizzlyBuffer = bb;
        if (!isExpectingMoreData()) {
            grizzlyBuffer.position(0);
            resetState();
        }
    }

    /** Returns the state machine to START for a fresh buffer. */
    private void resetState() {
        state = State.START;
        dataLength = 0;
        messageLength = 0;
        currentMessageStartPosition = 0;
    }
}
/**
 * A message produced by CustomProtocol. It holds either a request id plus a
 * ByteBuffer over the message data, or an Exception describing a parse
 * failure.
 *
 * John Vieten
 */
public class IncomingMessage {

    private int requestId;
    /** Length of the data; -1 when built from an exception. */
    private int messageLength = -1;
    /** Buffer holding the message data (null for error messages or after clear()). */
    private ByteBuffer payload;
    private Exception exception;

    /** Creates an error message carrying {@code exception}. */
    public IncomingMessage(Exception exception) {
        this.exception = exception;
    }

    /** Creates a data message of {@code messageLength} bytes backed by {@code sourceSliced}. */
    public IncomingMessage(int messageLength, ByteBuffer sourceSliced) {
        this.payload = sourceSliced;
        this.messageLength = messageLength;
    }

    public Exception getException() {
        return exception;
    }

    public void setException(Exception exception) {
        this.exception = exception;
    }

    /** @return whether an exception is attached to this message */
    public boolean hasException() {
        return !(exception == null);
    }

    public int getMessageLength() {
        return messageLength;
    }

    /** @return the data buffer, or null for an error message or after clear() */
    public ByteBuffer getMessage() {
        return payload;
    }

    public int getRequestId() {
        return requestId;
    }

    public void setRequestId(int requestId) {
        this.requestId = requestId;
    }

    /** Drops the reference to the data buffer. */
    public void clear() {
        payload = null;
    }
}
/**
 * Filter that writes every parsed IncomingMessage straight back to the
 * client, using the CustomProtocol framing. A message carrying a parse
 * exception is answered with ERROR_REQUEST_ID and the exception text.
 */
public class CustomProtocolEcho implements ProtocolFilter {

    public boolean execute(final Context ctx) {
        IncomingMessage result =
                (IncomingMessage) ctx.removeAttribute(ProtocolParser.MESSAGE);
        if (result == null) {
            // No message was produced for this invocation; nothing to echo.
            return false;
        }
        if (result.hasException()) {
            Exception error = result.getException();
            // getMessage() is null for exceptions created without a message.
            String text = error.getMessage() != null
                    ? error.getMessage() : error.toString();
            writeBack(CustomProtocol.ERROR_REQUEST_ID, text.getBytes(), ctx);
        } else {
            ByteBuffer body = result.getMessage();
            byte[] bytes = new byte[result.getMessageLength()];
            body.get(bytes);
            writeBack(result.getRequestId(), bytes, ctx);
        }
        // This is the last filter added in createController; nothing follows.
        return false;
    }

    /**
     * Frames {@code bytes} with START_MARK, {@code requestId} and the length,
     * then flushes the frame to the context's channel.
     */
    private void writeBack(int requestId, byte[] bytes, Context ctx) {
        SelectableChannel channel = ctx.getSelectionKey().channel();
        int byteLength = bytes.length;
        ByteBuffer newBuffer = ByteBufferFactory
                .allocateView(CustomProtocol.HEADER_LENGTH + byteLength, false);
        newBuffer.putInt(CustomProtocol.START_MARK);
        newBuffer.putInt(requestId);
        newBuffer.putInt(byteLength);
        newBuffer.put(bytes);
        newBuffer.flip();
        try {
            synchronized (channel) {
                OutputWriter.flushChannel(channel, newBuffer);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public boolean postExecute(Context context) throws IOException {
        return true;
    }
}
//-------------------------- Start Test Code ----------------------------
/**
 * Sends one short message and checks that it is echoed back unchanged.
 */
public void testCustomProtocolSmall() throws IOException {
    Controller controller = createController(PORT);
    TCPIOClient client = new TCPIOClient("localhost", PORT);
    try {
        ControllerUtils.startController(controller);
        client.connect();
        sendToServer("Hello World".getBytes(), client);
    } finally {
        try {
            controller.stop();
        } finally {
            // Close the client even if stopping the controller fails.
            client.close();
        }
    }
}
/**
 * Sends messages of 1x to 6x DEFAULT_BUFFER_SIZE bytes, each filled with a
 * distinct byte value, and checks every echo. These messages are larger
 * than DEFAULT_BUFFER_SIZE, so they presumably exercise the parser's
 * buffer-growing path (assuming Grizzly's read buffer has that size).
 */
public void testCustomProtocolBig() throws IOException {
    Controller controller = createController(PORT);
    TCPIOClient client = new TCPIOClient("localhost", PORT);
    try {
        ControllerUtils.startController(controller);
        client.connect();
        for (int multiple = 1; multiple < 7; multiple++) {
            byte[] bigMessage = new byte[CustomProtocol.DEFAULT_BUFFER_SIZE * multiple];
            Arrays.fill(bigMessage, (byte) multiple);
            sendToServer(bigMessage, client);
        }
    } finally {
        try {
            controller.stop();
        } finally {
            // Close the client even if stopping the controller fails.
            client.close();
        }
    }
}
/**
 * Sends raw bytes without a protocol header. The first four bytes are not
 * START_MARK, so the server must answer with an error frame that carries
 * ERROR_REQUEST_ID and the text "Bad Startmark".
 */
public void testException() throws IOException {
    Controller controller = createController(PORT);
    TCPIOClient client = new TCPIOClient("localhost", PORT);
    try {
        ControllerUtils.startController(controller);
        client.connect();
        byte[] data = "Hello World This should be echoed".getBytes();
        client.send(data);

        byte[] header = new byte[CustomProtocol.HEADER_LENGTH];
        client.receive(header);
        ByteBuffer receiveBuffer = ByteBuffer.wrap(header);
        assertEquals(CustomProtocol.START_MARK, receiveBuffer.getInt());
        assertEquals(CustomProtocol.ERROR_REQUEST_ID, receiveBuffer.getInt());
        int receiveLength = receiveBuffer.getInt();

        byte[] message = new byte[receiveLength];
        client.receive(message);
        assertEquals("Bad Startmark", new String(message));
    } finally {
        try {
            controller.stop();
        } finally {
            // Close the client even if stopping the controller fails.
            client.close();
        }
    }
}
/**
 * Frames {@code data} as one CustomProtocol message with a fresh request id
 * and sends it. Then asserts that the echoed reply has the same request id,
 * the same length and the same bytes.
 *
 * @param data   message body to send
 * @param client connected client used for the round trip
 * @throws IOException if sending or receiving fails
 */
public static void sendToServer(byte[] data, TCPIOClient client)
        throws IOException {
    int messageLength = CustomProtocol.HEADER_LENGTH + data.length;
    // A plain heap buffer of exactly messageLength bytes, so array() is the frame.
    ByteBuffer request = ByteBuffer.allocate(messageLength);
    request.putInt(CustomProtocol.START_MARK);
    requestId++;
    request.putInt(requestId);
    request.putInt(data.length);
    request.put(data);
    client.send(request.array());

    byte[] header = new byte[CustomProtocol.HEADER_LENGTH];
    client.receive(header);
    ByteBuffer replyHeader = ByteBuffer.wrap(header);
    assertEquals(CustomProtocol.START_MARK, replyHeader.getInt());
    assertEquals(requestId, replyHeader.getInt());
    int receiveLength = replyHeader.getInt();
    assertEquals(data.length, receiveLength);

    byte[] message = new byte[receiveLength];
    client.receive(message);
    assertTrue(Arrays.equals(data, message));
}
/**
 * Builds a Controller listening on {@code port}. A single shared protocol
 * chain parses CustomProtocol messages and echoes them back.
 * The instance handler always returns that same chain and refuses offers.
 */
private Controller createController(int port) {
    Controller server = new Controller();

    TCPSelectorHandler selectorHandler = new TCPSelectorHandler();
    selectorHandler.setPort(port);
    server.addSelectorHandler(selectorHandler);

    final DefaultProtocolChain sharedChain = new DefaultProtocolChain();
    sharedChain.addFilter(new ParserProtocolFilter() {
        public ProtocolParser newProtocolParser() {
            return new CustomProtocol();
        }
    });
    sharedChain.addFilter(new CustomProtocolEcho());
    sharedChain.setContinuousExecution(true);

    server.setProtocolChainInstanceHandler(
            new DefaultProtocolChainInstanceHandler() {
                public ProtocolChain poll() {
                    return sharedChain;
                }

                public boolean offer(ProtocolChain returned) {
                    return false;
                }
            });
    return server;
}
}