org.apache.sshd.PortForwardingLoadTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sshd.PortForwardingLoadTest.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sshd;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.commons.httpclient.HostConfiguration;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.HttpVersion;
import org.apache.commons.httpclient.MultiThreadedHttpConnectionManager;
import org.apache.commons.httpclient.methods.GetMethod;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.util.BaseTest;
import org.apache.sshd.util.BogusForwardingFilter;
import org.apache.sshd.util.BogusPasswordAuthenticator;
import org.apache.sshd.util.EchoShellFactory;
import org.apache.sshd.util.JSchLogger;
import org.apache.sshd.util.SimpleUserInfo;
import org.apache.sshd.util.Utils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.LoggerFactory;

import static org.apache.sshd.util.Utils.getFreePort;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Port forwarding tests
 */
public class PortForwardingLoadTest extends BaseTest {

    private final org.slf4j.Logger log = LoggerFactory.getLogger(getClass());

    private SshServer sshd;
    private int sshPort;
    private int echoPort;
    private IoAcceptor acceptor;

    @Before
    public void setUp() throws Exception {
        sshPort = getFreePort();
        echoPort = getFreePort();

        sshd = SshServer.setUpDefaultServer();
        sshd.setPort(sshPort);
        sshd.setKeyPairProvider(Utils.createTestHostKeyProvider());
        sshd.setShellFactory(new EchoShellFactory());
        sshd.setPasswordAuthenticator(new BogusPasswordAuthenticator());
        sshd.setTcpipForwardingFilter(new BogusForwardingFilter());
        sshd.start();

        NioSocketAcceptor acceptor = new NioSocketAcceptor();
        acceptor.setHandler(new IoHandlerAdapter() {
            @Override
            public void messageReceived(IoSession session, Object message) throws Exception {
                IoBuffer recv = (IoBuffer) message;
                IoBuffer sent = IoBuffer.allocate(recv.remaining());
                sent.put(recv);
                sent.flip();
                session.write(sent);
            }
        });
        acceptor.setReuseAddress(true);
        acceptor.bind(new InetSocketAddress(echoPort));
        this.acceptor = acceptor;

    }

    @After
    public void tearDown() throws Exception {
        if (sshd != null) {
            sshd.stop();
            Thread.sleep(50);
        }
        if (acceptor != null) {
            acceptor.dispose();
        }
    }

    @Test
    public void testLocalForwardingPayload() throws Exception {
        final int NUM_ITERATIONS = 100;
        final String PAYLOAD_TMP = "This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. ";
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 1000; i++) {
            sb.append(PAYLOAD_TMP);
        }
        final String PAYLOAD = sb.toString();
        Session session = createSession();
        final ServerSocket ss = new ServerSocket(0);
        int forwardedPort = ss.getLocalPort();
        int sinkPort = getFreePort();
        session.setPortForwardingL(sinkPort, "localhost", forwardedPort);
        final AtomicInteger conCount = new AtomicInteger(0);

        new Thread() {
            public void run() {
                try {
                    for (int i = 0; i < NUM_ITERATIONS; ++i) {
                        Socket s = ss.accept();
                        conCount.incrementAndGet();
                        InputStream is = s.getInputStream();
                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
                        byte[] buf = new byte[8192];
                        int l;
                        while (baos.size() < PAYLOAD.length() && (l = is.read(buf)) > 0) {
                            baos.write(buf, 0, l);
                        }
                        if (!PAYLOAD.equals(baos.toString())) {
                            assertEquals(PAYLOAD, baos.toString());
                        }
                        is = new ByteArrayInputStream(baos.toByteArray());
                        OutputStream os = s.getOutputStream();
                        while ((l = is.read(buf)) > 0) {
                            os.write(buf, 0, l);
                        }
                        s.close();
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }.start();
        Thread.sleep(50);

        for (int i = 0; i < NUM_ITERATIONS; i++) {
            Socket s = null;
            try {
                LoggerFactory.getLogger(getClass()).info("Iteration {}", i);
                s = new Socket("localhost", sinkPort);
                s.getOutputStream().write(PAYLOAD.getBytes());
                s.getOutputStream().flush();
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buf = new byte[8192];
                int l;
                while (baos.size() < PAYLOAD.length() && (l = s.getInputStream().read(buf)) > 0) {
                    baos.write(buf, 0, l);
                }
                assertEquals(PAYLOAD, baos.toString());
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                if (s != null) {
                    s.close();
                }
            }
        }
        session.delPortForwardingL(sinkPort);
    }

    @Test
    public void testRemoteForwardingPayload() throws Exception {
        final int NUM_ITERATIONS = 100;
        final String PAYLOAD = "This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. ";
        Session session = createSession();
        final ServerSocket ss = new ServerSocket(0);
        int forwardedPort = ss.getLocalPort();
        int sinkPort = getFreePort();
        session.setPortForwardingR(sinkPort, "localhost", forwardedPort);
        final boolean started[] = new boolean[1];
        started[0] = false;
        final AtomicInteger conCount = new AtomicInteger(0);

        new Thread() {
            public void run() {
                started[0] = true;
                try {
                    for (int i = 0; i < NUM_ITERATIONS; ++i) {
                        Socket s = ss.accept();
                        conCount.incrementAndGet();
                        s.getOutputStream().write(PAYLOAD.getBytes());
                        s.getOutputStream().flush();
                        s.close();
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }.start();
        Thread.sleep(50);
        Assert.assertTrue("Server not started", started[0]);

        final boolean lenOK[] = new boolean[NUM_ITERATIONS];
        final boolean dataOK[] = new boolean[NUM_ITERATIONS];
        for (int i = 0; i < NUM_ITERATIONS; i++) {
            final int ii = i;
            Socket s = null;
            try {
                s = new Socket("localhost", sinkPort);
                byte b1[] = new byte[PAYLOAD.length() / 2];
                byte b2[] = new byte[PAYLOAD.length()];
                int read1 = s.getInputStream().read(b1);
                Thread.sleep(50);
                int read2 = s.getInputStream().read(b2);
                lenOK[ii] = PAYLOAD.length() == read1 + read2;
                dataOK[ii] = PAYLOAD.equals(new String(b1, 0, read1) + new String(b2, 0, read2));
                if (!lenOK[ii] || !dataOK[ii]) {
                    throw new Exception("Bad data");
                }
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                if (s != null) {
                    s.close();
                }
            }
        }
        int ok = 0;
        for (int i = 0; i < NUM_ITERATIONS; i++) {
            ok += lenOK[i] ? 1 : 0;
        }
        Thread.sleep(50);
        for (int i = 0; i < NUM_ITERATIONS; i++) {
            Assert.assertTrue(lenOK[i]);
            Assert.assertTrue(dataOK[i]);
        }
        session.delPortForwardingR(forwardedPort);
    }

    @Test
    public void testForwardingOnLoad() throws Exception {
        //        final String path = "/history/recent/troubles/";
        //        final String host = "www.bbc.co.uk";
        //        final String path = "";
        //        final String host = "www.bahn.de";
        final String path = "";
        final String host = "localhost";
        final int nbThread = 2;
        final int nbDownloads = 2;
        final int nbLoops = 2;

        final int port = getFreePort();
        StringBuilder resp = new StringBuilder();
        resp.append("<html><body>\n");
        for (int i = 0; i < 1000; i++) {
            resp.append("0123456789\n");
        }
        resp.append("</body></html>\n");
        final StringBuilder sb = new StringBuilder();
        sb.append("HTTP/1.1 200 OK").append('\n');
        sb.append("Content-Type: text/HTML").append('\n');
        sb.append("Content-Length: ").append(resp.length()).append('\n');
        sb.append('\n');
        sb.append(resp);
        NioSocketAcceptor acceptor = new NioSocketAcceptor();
        acceptor.setHandler(new IoHandlerAdapter() {
            @Override
            public void messageReceived(IoSession session, Object message) throws Exception {
                session.write(IoBuffer.wrap(sb.toString().getBytes()));
            }
        });
        acceptor.setReuseAddress(true);
        acceptor.bind(new InetSocketAddress(port));

        Session session = createSession();

        final int forwardedPort1 = getFreePort();
        final int forwardedPort2 = getFreePort();
        System.err.println("URL: http://localhost:" + forwardedPort2);

        session.setPortForwardingL(forwardedPort1, host, port);
        session.setPortForwardingR(forwardedPort2, "localhost", forwardedPort1);

        final CountDownLatch latch = new CountDownLatch(nbThread * nbDownloads * nbLoops);

        final Thread[] threads = new Thread[nbThread];
        final List<Throwable> errors = new CopyOnWriteArrayList<Throwable>();
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread() {
                public void run() {
                    for (int j = 0; j < nbLoops; j++) {
                        final MultiThreadedHttpConnectionManager mgr = new MultiThreadedHttpConnectionManager();
                        final HttpClient client = new HttpClient(mgr);
                        client.getHttpConnectionManager().getParams().setDefaultMaxConnectionsPerHost(100);
                        client.getHttpConnectionManager().getParams().setMaxTotalConnections(1000);
                        for (int i = 0; i < nbDownloads; i++) {
                            try {
                                checkHtmlPage(client, new URL("http://localhost:" + forwardedPort2 + path));
                            } catch (Throwable e) {
                                errors.add(e);
                            } finally {
                                latch.countDown();
                                System.err.println("Remaining: " + latch.getCount());
                            }
                        }
                        mgr.shutdown();
                    }
                }
            };
        }
        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
        latch.await();
        for (Throwable t : errors) {
            t.printStackTrace();
        }
        assertEquals(0, errors.size());
    }

    protected Session createSession() throws JSchException {
        JSchLogger.init();
        JSch sch = new JSch();
        Session session = sch.getSession("sshd", "localhost", sshPort);
        session.setUserInfo(new SimpleUserInfo("sshd"));
        session.connect();
        return session;
    }

    protected void checkHtmlPage(HttpClient client, URL url) throws IOException {
        client.setHostConfiguration(new HostConfiguration());
        client.getHostConfiguration().setHost(url.getHost(), url.getPort());
        GetMethod get = new GetMethod("");
        get.getParams().setVersion(HttpVersion.HTTP_1_1);
        client.executeMethod(get);
        String str = get.getResponseBodyAsString();
        if (str.indexOf("</html>") <= 0) {
            System.err.println(str);
        }
        assertTrue((str.indexOf("</html>") > 0));
        get.releaseConnection();
        //        url.openConnection().setDefaultUseCaches(false);
        //        Reader reader = new BufferedReader(new InputStreamReader(url.openStream()));
        //        try {
        //            StringWriter sw = new StringWriter();
        //            char[] buf = new char[8192];
        //            while (true) {
        //                int len = reader.read(buf);
        //                if (len < 0) {
        //                    break;
        //                }
        //                sw.write(buf, 0, len);
        //            }
        //            assertTrue(sw.toString().indexOf("</html>") > 0);
        //        } finally {
        //            reader.close();
        //        }
    }

}