package com.sun.grizzly;

import junit.framework.TestCase;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;

import com.sun.grizzly.utils.ControllerUtils;
import com.sun.grizzly.filter.ReadFilter;
import com.sun.grizzly.filter.EchoFilter;

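/**
 * Stress-tests concurrent {@code connect()} calls and echo round trips through
 * {@link TCPConnectorHandler} and {@link UDPConnectorHandler}: many client
 * threads share one {@link Controller}, and each one sends a series of packets
 * to an echo server and verifies every response.
 * This test unit is temporary.
 *
 * @author Bongjae Chang
 */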
public class RoundRobinConnectionTest extends TestCase {

    public static final int PORT = 17522;
    public static final long WAIT_TIMEOUT = 60 * 1000; // ms
    public static final int RESPONSE_ARRIVE_TIMEOUT = 30 * 1000; // ms
    public static final int READ_TRY_COUNT = 1000;

    public void testTCPConnectionRoundRobin() throws IOException {
        doTest( Controller.Protocol.TCP, 50, 1000 );
    }

    public void testUDPConnectionRoundRobin() throws IOException {
        doTest( Controller.Protocol.UDP, 50, 1000 );
    }

    private void doTest( final Controller.Protocol protocol, final int clientCount, final int packetCount ) {
        final Controller controller = createController();
        ControllerUtils.startController( controller );
        final Thread[] threads = new Thread[clientCount];
        final CountDownLatch latch = new CountDownLatch( clientCount );
        boolean result = false;
        long startTime = System.currentTimeMillis();
        try {
            for( int i = 0; i < clientCount; i++ ) {
                threads[i] = new Thread() {
                    public void run() {
                        try {
                            connectAndSendPackets( controller, protocol, packetCount, latch );
                        } catch( IOException ie ) {
                            Controller.logger().log( Level.INFO, "Got the unexpected error.", ie );
                            assertTrue( "Got the unexpected error.", false );
                        }
                    }
                };
                threads[i].start();
            }
            result = waitOnLatch( latch, WAIT_TIMEOUT, TimeUnit.MILLISECONDS );
            assertTrue( result );
            Controller.logger().log( Level.INFO, "elapse = " + ( System.currentTimeMillis() - startTime ) + "ms" );
        } finally {
            for( Thread client : threads ) {
                client.interrupt();
            }
            try {
                controller.stop();
            } catch( Throwable t ) {
                t.printStackTrace();
            }
        }
    }

    @SuppressWarnings( "unchecked" )
    private void connectAndSendPackets( final Controller controller,
                                       final Controller.Protocol protocol,
                                       final int packetCount,
                                       final CountDownLatch latch ) throws IOException {
        final ConnectorHandler connectorHandler =
                controller.acquireConnectorHandler( protocol );
        int packetIndexLength = String.valueOf( packetCount ).length();
        String packetIndexString = "";
        for ( int i=0; i< packetIndexLength; i++ ) {
            packetIndexString += "0";
        }
        final byte[] testData = new String( "Hello. Client#" + protocol + " Packet#" + packetIndexString ).getBytes();
        final byte[] response = new byte[testData.length];

        final ByteBuffer writeBB = ByteBuffer.wrap( testData );
        final ByteBuffer readBB = ByteBuffer.wrap( response );

        final CountDownLatch[] responseArrivedLatchHolder = new CountDownLatch[1];
        CallbackHandler callbackHandler = createCallbackHandler( connectorHandler,
                                                                 responseArrivedLatchHolder,
                                                                 writeBB,
                                                                 readBB,
                                                                 testData );

        try {
            connectorHandler.connect( new InetSocketAddress( "localhost", PORT ), callbackHandler );
            for( int j = 0; j < packetCount; j++ ) {
                CountDownLatch responseArrivedLatch = new CountDownLatch( 1 );
                responseArrivedLatchHolder[0] = responseArrivedLatch;
                readBB.clear();
                writeBB.position( writeBB.limit() - packetIndexLength );
                byte[] packetNum = Integer.toString( j ).getBytes();
                writeBB.put( packetNum );
                writeBB.position( 0 );
                connectorHandler.write( writeBB, false );
                connectorHandler.read( readBB, false );

                if( readBB.position() < testData.length ) {
                    waitOnLatch( responseArrivedLatch, RESPONSE_ARRIVE_TIMEOUT, TimeUnit.MILLISECONDS );
                }

                readBB.flip();
                String val1 = new String( testData );
                String val2 = new String( toArray( readBB ) );
                assertEquals( val1, val2 );
            }
        } finally {
            connectorHandler.close();
            controller.releaseConnectorHandler( connectorHandler );
            latch.countDown();
        }
    }

    private Controller createController() {
        final ProtocolFilter readFilter = new ReadFilter();
        final ProtocolFilter echoFilter = new EchoFilter();

        TCPSelectorHandler tcpSelectorHandler = new TCPSelectorHandler();
        tcpSelectorHandler.setPort( PORT );
        UDPSelectorHandler udpSelectorHandler = new UDPSelectorHandler();
        udpSelectorHandler.setPort( PORT );

        final Controller controller = new Controller();

        controller.setSelectorHandler( tcpSelectorHandler );
        controller.setSelectorHandler( udpSelectorHandler );

        controller.setProtocolChainInstanceHandler(
                new DefaultProtocolChainInstanceHandler() {
                    @Override
                    public ProtocolChain poll() {
                        ProtocolChain protocolChain = protocolChains.poll();
                        if( protocolChain == null ) {
                            protocolChain = new DefaultProtocolChain();
                            protocolChain.addFilter( readFilter );
                            protocolChain.addFilter( echoFilter );
                        }
                        return protocolChain;
                    }
                } );

        return controller;
    }

    private CallbackHandler createCallbackHandler( final ConnectorHandler connectorHandler,
                                                   final CountDownLatch[] responseArrivedLatchHolder,
                                                   final ByteBuffer writeBB,
                                                   final ByteBuffer readBB,
                                                   final byte[] testData ) {
        return new CallbackHandler<Context>() {

            private int readTry;

            public void onConnect( IOEvent<Context> ioEvent ) {
                SelectionKey key = ioEvent.attachment().getSelectionKey();
                try {
                    connectorHandler.finishConnect( key );
                } catch( IOException ex ) {
                    ex.printStackTrace();
                }
                ioEvent.attachment().getSelectorHandler().register( key,
                                                                    SelectionKey.OP_READ );
            }

            public void onRead( IOEvent<Context> ioEvent ) {
                SelectionKey key = ioEvent.attachment().getSelectionKey();
                SelectorHandler selectorHandler = ioEvent.attachment().
                        getSelectorHandler();
                ReadableByteChannel channel = (ReadableByteChannel)key.channel();

                try {
                    int nRead = channel.read( readBB );
                    if( nRead == 0 && readTry++ < READ_TRY_COUNT ) {
                        selectorHandler.register( key, SelectionKey.OP_READ );
                    } else if( testData.length <= readBB.position() ) {
                        responseArrivedLatchHolder[0].countDown();
                    } else {
                        Controller.logger().log( Level.INFO,
                                                 "###position=" + readBB.position() + ",nRead=" + nRead + ",readTry=" + readTry );
                    }
                } catch( IOException ex ) {
                    ex.printStackTrace();
                    selectorHandler.getSelectionKeyHandler().cancel( key );
                }
            }

            public void onWrite( IOEvent<Context> ioEvent ) {
                SelectionKey key = ioEvent.attachment().getSelectionKey();
                SelectorHandler selectorHandler = ioEvent.attachment().
                        getSelectorHandler();
                WritableByteChannel channel = (WritableByteChannel)key.channel();
                try {
                    while( writeBB.hasRemaining() ) {
                        int nWrite = channel.write( writeBB );

                        if( nWrite == 0 ) {
                            selectorHandler.register( key, SelectionKey.OP_WRITE );
                            return;
                        }
                    }

                    connectorHandler.read( readBB, false );
                } catch( IOException ex ) {
                    ex.printStackTrace();
                    selectorHandler.getSelectionKeyHandler().cancel( key );
                }

            }
        };
    }

    private boolean waitOnLatch( CountDownLatch latch, long timeout, TimeUnit timeUnit ) {
        try {
            return latch.await( timeout, timeUnit );
        } catch( InterruptedException ex ) {
            ex.printStackTrace();
        }
        return false;
    }

    private byte[] toArray( ByteBuffer bb ) {
        byte[] buf = new byte[bb.remaining()];
        bb.get( buf );
        return buf;
    }
}