package com.sun.grizzly; import junit.framework.TestCase; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.net.InetSocketAddress; import com.sun.grizzly.utils.ControllerUtils; import com.sun.grizzly.filter.ReadFilter; import com.sun.grizzly.filter.EchoFilter; /** * Tests {@link TCPConnectorHandler} and {@link UDPConnectorHandler}'s connect() concurrently * This test unit is temporary * * @author Bongjae Chang */ public class RoundRobinConnectionTest extends TestCase { public static final int PORT = 17522; public static final long WAIT_TIMEOUT = 60 * 1000; // ms public static final int RESPONSE_ARRIVE_TIMEOUT = 30 * 1000; // ms public static final int READ_TRY_COUNT = 1000; public void testTCPConnectionRoundRobin() throws IOException { doTest( Controller.Protocol.TCP, 50, 1000 ); } public void testUDPConnectionRoundRobin() throws IOException { doTest( Controller.Protocol.UDP, 50, 1000 ); } private void doTest( final Controller.Protocol protocol, final int clientCount, final int packetCount ) { final Controller controller = createController(); ControllerUtils.startController( controller ); final Thread[] threads = new Thread[clientCount]; final CountDownLatch latch = new CountDownLatch( clientCount ); boolean result = false; long startTime = System.currentTimeMillis(); try { for( int i = 0; i < clientCount; i++ ) { threads[i] = new Thread() { public void run() { try { connectAndSendPackets( controller, protocol, packetCount, latch ); } catch( IOException ie ) { Controller.logger().log( Level.INFO, "Got the unexpected error.", ie ); assertTrue( "Got the unexpected error.", false ); } } }; threads[i].start(); } result = waitOnLatch( latch, WAIT_TIMEOUT, TimeUnit.MILLISECONDS ); assertTrue( result ); Controller.logger().log( Level.INFO, "elapse = " + ( System.currentTimeMillis() - startTime ) + "ms" ); } finally { for( Thread client : threads ) { client.interrupt(); } try { controller.stop(); } catch( Throwable t ) { t.printStackTrace(); } } } @SuppressWarnings( "unchecked" ) private void connectAndSendPackets( final Controller controller, final Controller.Protocol protocol, final int packetCount, final CountDownLatch latch ) throws IOException { final ConnectorHandler connectorHandler = controller.acquireConnectorHandler( protocol ); int packetIndexLength = String.valueOf( packetCount ).length(); String packetIndexString = ""; for ( int i=0; i< packetIndexLength; i++ ) { packetIndexString += "0"; } final byte[] testData = new String( "Hello. Client#" + protocol + " Packet#" + packetIndexString ).getBytes(); final byte[] response = new byte[testData.length]; final ByteBuffer writeBB = ByteBuffer.wrap( testData ); final ByteBuffer readBB = ByteBuffer.wrap( response ); final CountDownLatch[] responseArrivedLatchHolder = new CountDownLatch[1]; CallbackHandler callbackHandler = createCallbackHandler( connectorHandler, responseArrivedLatchHolder, writeBB, readBB, testData ); try { connectorHandler.connect( new InetSocketAddress( "localhost", PORT ), callbackHandler ); for( int j = 0; j < packetCount; j++ ) { CountDownLatch responseArrivedLatch = new CountDownLatch( 1 ); responseArrivedLatchHolder[0] = responseArrivedLatch; readBB.clear(); writeBB.position( writeBB.limit() - packetIndexLength ); byte[] packetNum = Integer.toString( j ).getBytes(); writeBB.put( packetNum ); writeBB.position( 0 ); connectorHandler.write( writeBB, false ); connectorHandler.read( readBB, false ); if( readBB.position() < testData.length ) { waitOnLatch( responseArrivedLatch, RESPONSE_ARRIVE_TIMEOUT, TimeUnit.MILLISECONDS ); } readBB.flip(); String val1 = new String( testData ); String val2 = new String( toArray( readBB ) ); assertEquals( val1, val2 ); } } finally { connectorHandler.close(); controller.releaseConnectorHandler( connectorHandler ); latch.countDown(); } } private Controller createController() { final ProtocolFilter readFilter = new ReadFilter(); final ProtocolFilter echoFilter = new EchoFilter(); TCPSelectorHandler tcpSelectorHandler = new TCPSelectorHandler(); tcpSelectorHandler.setPort( PORT ); UDPSelectorHandler udpSelectorHandler = new UDPSelectorHandler(); udpSelectorHandler.setPort( PORT ); final Controller controller = new Controller(); controller.setSelectorHandler( tcpSelectorHandler ); controller.setSelectorHandler( udpSelectorHandler ); controller.setProtocolChainInstanceHandler( new DefaultProtocolChainInstanceHandler() { @Override public ProtocolChain poll() { ProtocolChain protocolChain = protocolChains.poll(); if( protocolChain == null ) { protocolChain = new DefaultProtocolChain(); protocolChain.addFilter( readFilter ); protocolChain.addFilter( echoFilter ); } return protocolChain; } } ); return controller; } private CallbackHandler createCallbackHandler( final ConnectorHandler connectorHandler, final CountDownLatch[] responseArrivedLatchHolder, final ByteBuffer writeBB, final ByteBuffer readBB, final byte[] testData ) { return new CallbackHandler() { private int readTry; public void onConnect( IOEvent ioEvent ) { SelectionKey key = ioEvent.attachment().getSelectionKey(); try { connectorHandler.finishConnect( key ); } catch( IOException ex ) { ex.printStackTrace(); } ioEvent.attachment().getSelectorHandler().register( key, SelectionKey.OP_READ ); } public void onRead( IOEvent ioEvent ) { SelectionKey key = ioEvent.attachment().getSelectionKey(); SelectorHandler selectorHandler = ioEvent.attachment(). getSelectorHandler(); ReadableByteChannel channel = (ReadableByteChannel)key.channel(); try { int nRead = channel.read( readBB ); if( nRead == 0 && readTry++ < READ_TRY_COUNT ) { selectorHandler.register( key, SelectionKey.OP_READ ); } else if( testData.length <= readBB.position() ) { responseArrivedLatchHolder[0].countDown(); } else { Controller.logger().log( Level.INFO, "###position=" + readBB.position() + ",nRead=" + nRead + ",readTry=" + readTry ); } } catch( IOException ex ) { ex.printStackTrace(); selectorHandler.getSelectionKeyHandler().cancel( key ); } } public void onWrite( IOEvent ioEvent ) { SelectionKey key = ioEvent.attachment().getSelectionKey(); SelectorHandler selectorHandler = ioEvent.attachment(). getSelectorHandler(); WritableByteChannel channel = (WritableByteChannel)key.channel(); try { while( writeBB.hasRemaining() ) { int nWrite = channel.write( writeBB ); if( nWrite == 0 ) { selectorHandler.register( key, SelectionKey.OP_WRITE ); return; } } connectorHandler.read( readBB, false ); } catch( IOException ex ) { ex.printStackTrace(); selectorHandler.getSelectionKeyHandler().cancel( key ); } } }; } private boolean waitOnLatch( CountDownLatch latch, long timeout, TimeUnit timeUnit ) { try { return latch.await( timeout, timeUnit ); } catch( InterruptedException ex ) { ex.printStackTrace(); } return false; } private byte[] toArray( ByteBuffer bb ) { byte[] buf = new byte[bb.remaining()]; bb.get( buf ); return buf; } }