org.apache.sshd.common.forward.PortForwardingLoadTest.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sshd.common.forward.PortForwardingLoadTest.java, a JUnit load test for SSH port forwarding in Apache MINA SSHD.

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sshd.common.forward;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;

import org.apache.commons.httpclient.HostConfiguration;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.HttpVersion;
import org.apache.commons.httpclient.MultiThreadedHttpConnectionManager;
import org.apache.commons.httpclient.methods.GetMethod;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.JSchLogger;
import org.apache.sshd.util.test.SimpleUserInfo;
import org.apache.sshd.util.test.Utils;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Port forwarding load tests.
 * <p>
 * Each test starts an embedded {@link SshServer} that accepts every forwarding request, connects to it with JSch and
 * pushes repeated payloads through local and/or remote tunnels, checking that the data arrives intact.
 * </p>
 */
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class PortForwardingLoadTest extends BaseTestSupport {
    private final Logger log;

    /** Server-side listener that only logs every tunnel lifecycle event - useful when diagnosing failures. */
    @SuppressWarnings({ "checkstyle:anoninnerlength", "synthetic-access" })
    private final PortForwardingEventListener serverSideListener = new PortForwardingEventListener() {
        @Override
        public void establishingExplicitTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress local, SshdSocketAddress remote, boolean localForwarding) throws IOException {
            log.info("establishingExplicitTunnel(session={}, local={}, remote={}, localForwarding={})", session,
                    local, remote, localForwarding);
        }

        @Override
        public void establishedExplicitTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress local, SshdSocketAddress remote, boolean localForwarding,
                SshdSocketAddress boundAddress, Throwable reason) throws IOException {
            log.info("establishedExplicitTunnel(session={}, local={}, remote={}, bound={}, localForwarding={}): {}",
                    session, local, remote, boundAddress, localForwarding, reason);
        }

        @Override
        public void tearingDownExplicitTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress address, boolean localForwarding) throws IOException {
            log.info("tearingDownExplicitTunnel(session={}, address={}, localForwarding={})", session, address,
                    localForwarding);
        }

        @Override
        public void tornDownExplicitTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress address, boolean localForwarding, Throwable reason) throws IOException {
            log.info("tornDownExplicitTunnel(session={}, address={}, localForwarding={}, reason={})", session,
                    address, localForwarding, reason);
        }

        @Override
        public void establishingDynamicTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress local) throws IOException {
            log.info("establishingDynamicTunnel(session={}, local={})", session, local);
        }

        @Override
        public void establishedDynamicTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress local, SshdSocketAddress boundAddress, Throwable reason) throws IOException {
            log.info("establishedDynamicTunnel(session={}, local={}, bound={}, reason={})", session, local,
                    boundAddress, reason);
        }

        @Override
        public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session,
                SshdSocketAddress address) throws IOException {
            log.info("tearingDownDynamicTunnel(session={}, address={})", session, address);
        }

        @Override
        public void tornDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address,
                Throwable reason) throws IOException {
            log.info("tornDownDynamicTunnel(session={}, address={}, reason={})", session, address, reason);
        }
    };

    private SshServer sshd;
    private int sshPort;
    /** MINA echo server started in {@link #setUp()} and disposed in {@link #tearDown()}. */
    private IoAcceptor acceptor;

    public PortForwardingLoadTest() {
        log = LoggerFactory.getLogger(getClass());
    }

    @BeforeClass
    public static void jschInit() {
        JSchLogger.init();
    }

    /**
     * Starts the SSH server (accepting all forwarding requests) and a MINA echo server on an ephemeral port.
     *
     * @throws Exception if either server fails to start
     */
    @Before
    public void setUp() throws Exception {
        sshd = setupTestServer();
        sshd.setTcpipForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
        sshd.addPortForwardingEventListener(serverSideListener);
        sshd.start();
        sshPort = sshd.getPort();

        NioSocketAcceptor echoAcceptor = new NioSocketAcceptor();
        echoAcceptor.setHandler(new IoHandlerAdapter() {
            @Override
            public void messageReceived(IoSession session, Object message) throws Exception {
                // Copy the received bytes into a fresh buffer and write them straight back
                IoBuffer recv = (IoBuffer) message;
                IoBuffer sent = IoBuffer.allocate(recv.remaining());
                sent.put(recv);
                sent.flip();
                session.write(sent);
            }
        });
        echoAcceptor.setReuseAddress(true);
        echoAcceptor.bind(new InetSocketAddress(0));
        log.info("setUp() echo address = {}", echoAcceptor.getLocalAddress());
        this.acceptor = echoAcceptor;
    }

    /**
     * Stops the SSH server and disposes the echo acceptor.
     *
     * @throws Exception if stopping either component fails
     */
    @After
    public void tearDown() throws Exception {
        try {
            if (sshd != null) {
                sshd.stop(true);
            }
        } finally {
            // Release the echo acceptor even if stopping the SSH server failed
            if (acceptor != null) {
                acceptor.dispose(true);
            }
        }
    }

    /**
     * Sends a large payload through a JSch local (-L) forward many times.
     * <p>
     * A plain server socket behind the tunnel checks each payload and echoes it back. The client then checks the
     * echoed data.
     * </p>
     *
     * @throws Exception on setup failure
     */
    @Test
    @SuppressWarnings("checkstyle:nestedtrydepth")
    public void testLocalForwardingPayload() throws Exception {
        final int numIterations = 100;
        final String payloadTmpData = "This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. ";
        StringBuilder sb = new StringBuilder(payloadTmpData.length() * 1000);
        for (int i = 0; i < 1000; i++) {
            sb.append(payloadTmpData);
        }
        final String payload = sb.toString();
        // Encode/decode explicitly with UTF-8 on both sides so the comparison never depends on the platform charset
        final byte[] bytes = payload.getBytes(StandardCharsets.UTF_8);

        Session session = createSession();
        try (ServerSocket ss = new ServerSocket()) {
            ss.setReuseAddress(true);
            ss.bind(new InetSocketAddress((InetAddress) null, 0));
            int forwardedPort = ss.getLocalPort();
            int sinkPort = session.setPortForwardingL(0, TEST_LOCALHOST, forwardedPort);
            final AtomicInteger conCount = new AtomicInteger(0);
            final Semaphore iterationsSignal = new Semaphore(0);
            Thread tAcceptor = new Thread(getCurrentTestName() + "Acceptor") {
                @SuppressWarnings("synthetic-access")
                @Override
                public void run() {
                    try {
                        byte[] buf = new byte[8192];
                        log.info("Started...");
                        for (int i = 0; i < numIterations; ++i) {
                            try (Socket s = ss.accept()) {
                                conCount.incrementAndGet();

                                try (InputStream sockIn = s.getInputStream();
                                        ByteArrayOutputStream baos = new ByteArrayOutputStream(bytes.length)) {

                                    while (baos.size() < bytes.length) {
                                        int l = sockIn.read(buf);
                                        if (l < 0) {
                                            break;
                                        }
                                        baos.write(buf, 0, l);
                                    }

                                    byte[] received = baos.toByteArray();
                                    assertEquals("Mismatched received data at iteration #" + i, payload,
                                            new String(received, StandardCharsets.UTF_8));

                                    // Echo back from the buffered copy. The client keeps its side open while
                                    // waiting for the echo, so reading sockIn again here would block until the
                                    // client's read timeout and the echo would never be sent.
                                    try (InputStream inputCopy = new ByteArrayInputStream(received);
                                            OutputStream sockOut = s.getOutputStream()) {
                                        for (int l = inputCopy.read(buf); l >= 0; l = inputCopy.read(buf)) {
                                            sockOut.write(buf, 0, l);
                                        }
                                        sockOut.flush();
                                    }
                                }
                            }
                            log.info("Finished iteration {}", i);
                            iterationsSignal.release();
                        }
                        log.info("Done");
                    } catch (Exception e) {
                        log.error("Failed to complete run loop", e);
                    }
                }
            };
            tAcceptor.start();
            Thread.sleep(TimeUnit.SECONDS.toMillis(1L));

            byte[] buf = new byte[8192];
            for (int i = 0; i < numIterations; i++) {
                log.info("Iteration {}", i);
                try (Socket s = new Socket(TEST_LOCALHOST, sinkPort); OutputStream sockOut = s.getOutputStream()) {

                    s.setSoTimeout((int) FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);

                    sockOut.write(bytes);
                    sockOut.flush();

                    try (InputStream sockIn = s.getInputStream();
                            ByteArrayOutputStream baos = new ByteArrayOutputStream(bytes.length)) {
                        while (baos.size() < bytes.length) {
                            int l = sockIn.read(buf);
                            if (l < 0) {
                                break;
                            }
                            baos.write(buf, 0, l);
                        }
                        assertEquals("Mismatched payload at iteration #" + i, payload,
                                new String(baos.toByteArray(), StandardCharsets.UTF_8));
                    }
                } catch (Exception e) {
                    log.error("Error in iteration #" + i, e);
                }
            }

            try {
                assertTrue("Failed to await pending iterations=" + numIterations,
                        iterationsSignal.tryAcquire(numIterations, numIterations, TimeUnit.SECONDS));
            } finally {
                session.delPortForwardingL(sinkPort);
            }

            ss.close();
            tAcceptor.join(TimeUnit.SECONDS.toMillis(11L));
        } finally {
            session.disconnect();
        }
    }

    /**
     * Opens a JSch remote (-R) forward to a local writer socket and connects to it many times.
     * <p>
     * On each connection the client reads the payload in two phases and checks its length and content.
     * </p>
     *
     * @throws Exception on setup failure
     */
    @Test
    public void testRemoteForwardingPayload() throws Exception {
        final int numIterations = 100;
        final String payload = "This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. This is significantly longer Test Data. This is significantly "
                + "longer Test Data. ";
        final byte[] bytes = payload.getBytes(StandardCharsets.UTF_8);
        Session session = createSession();
        try (ServerSocket ss = new ServerSocket()) {
            ss.setReuseAddress(true);
            ss.bind(new InetSocketAddress((InetAddress) null, 0));
            int forwardedPort = ss.getLocalPort();
            int sinkPort = Utils.getFreePort();
            session.setPortForwardingR(sinkPort, TEST_LOCALHOST, forwardedPort);
            // A latch (unlike a plain boolean[] flag) guarantees cross-thread visibility of the "started" signal
            final CountDownLatch started = new CountDownLatch(1);
            final AtomicInteger conCount = new AtomicInteger(0);

            Thread tWriter = new Thread(getCurrentTestName() + "Writer") {
                @SuppressWarnings("synthetic-access")
                @Override
                public void run() {
                    started.countDown();
                    try {
                        for (int i = 0; i < numIterations; ++i) {
                            try (Socket s = ss.accept()) {
                                conCount.incrementAndGet();

                                try (OutputStream sockOut = s.getOutputStream()) {
                                    sockOut.write(bytes);
                                    sockOut.flush();
                                }
                            }
                        }
                    } catch (Exception e) {
                        log.error("Failed to complete run loop", e);
                    }
                }
            };
            tWriter.start();
            assertTrue("Server not started", started.await(TimeUnit.SECONDS.toMillis(5L), TimeUnit.MILLISECONDS));

            final RuntimeException[] lenOK = new RuntimeException[numIterations];
            final RuntimeException[] dataOK = new RuntimeException[numIterations];
            byte[] b2 = new byte[bytes.length];
            byte[] b1 = new byte[b2.length / 2];

            for (int i = 0; i < numIterations; i++) {
                try (Socket s = new Socket(TEST_LOCALHOST, sinkPort); InputStream sockIn = s.getInputStream()) {
                    s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(10L));

                    // First phase: at most half the payload. EOF (-1) counts as zero bytes, so a short stream is
                    // reported as a length mismatch. Otherwise it would raise an exception that gets swallowed
                    // below and mistaken for success.
                    int read1 = Math.max(0, sockIn.read(b1));
                    String part1 = new String(b1, 0, read1, StandardCharsets.UTF_8);
                    Thread.sleep(50);

                    // Second phase: one read() may legally return fewer bytes than are pending, so keep reading
                    // until the whole payload has arrived or the writer closes the connection
                    int read2 = 0;
                    while (read1 + read2 < b2.length) {
                        int l = sockIn.read(b2, read2, b2.length - read2);
                        if (l < 0) {
                            break;
                        }
                        read2 += l;
                    }
                    String part2 = new String(b2, 0, read2, StandardCharsets.UTF_8);
                    int totalRead = read1 + read2;
                    lenOK[i] = (payload.length() == totalRead) ? null
                            : new IndexOutOfBoundsException(
                                    "Mismatched length: expected=" + payload.length() + ", actual=" + totalRead);

                    String readData = part1 + part2;
                    dataOK[i] = payload.equals(readData) ? null : new IllegalStateException("Mismatched content");
                    if (lenOK[i] != null) {
                        throw lenOK[i];
                    }

                    if (dataOK[i] != null) {
                        throw dataOK[i];
                    }
                } catch (Exception e) {
                    if (e instanceof IOException) {
                        log.warn("I/O exception in iteration #" + i, e);
                    } else {
                        log.error("Failed to complete iteration #" + i, e);
                    }
                }
            }
            int ok = 0;
            for (RuntimeException lenFailure : lenOK) {
                if (lenFailure == null) {
                    ok++;
                }
            }
            log.info("Successful iteration: {} out of {}", ok, numIterations);
            Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
            for (int i = 0; i < numIterations; i++) {
                assertNull("Bad length at iteration " + i, lenOK[i]);
                assertNull("Bad data at iteration " + i, dataOK[i]);
            }
            Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
            // The remote forward was registered on sinkPort (the server-side bind port), so cancel that one
            session.delPortForwardingR(sinkPort);
            ss.close();
            tWriter.join(TimeUnit.SECONDS.toMillis(11L));
        } finally {
            session.disconnect();
        }
    }

    /**
     * Chains a local forward and a remote forward in front of a minimal HTTP responder.
     * <p>
     * Several threads download the page concurrently through the chain. The test fails if any download fails.
     * </p>
     *
     * @throws Exception on setup failure
     */
    @Test
    public void testForwardingOnLoad() throws Exception {
        final String path = "";
        final String host = TEST_LOCALHOST;
        final int nbThread = 2;
        final int nbDownloads = 2;
        final int nbLoops = 2;

        StringBuilder resp = new StringBuilder();
        resp.append("<html><body>\n");
        for (int i = 0; i < 1000; i++) {
            resp.append("0123456789\n");
        }
        resp.append("</body></html>\n");
        final StringBuilder sb = new StringBuilder();
        sb.append("HTTP/1.1 200 OK").append('\n');
        sb.append("Content-Type: text/HTML").append('\n');
        sb.append("Content-Length: ").append(resp.length()).append('\n');
        sb.append('\n');
        sb.append(resp);
        final byte[] response = sb.toString().getBytes(StandardCharsets.UTF_8);

        // Local to this test, so tearDown() does not dispose it; it is disposed in the finally clause below
        NioSocketAcceptor httpAcceptor = new NioSocketAcceptor();
        try {
            httpAcceptor.setHandler(new IoHandlerAdapter() {
                @Override
                public void messageReceived(IoSession session, Object message) throws Exception {
                    // Answer every incoming chunk with the same canned HTTP response
                    session.write(IoBuffer.wrap(response));
                }
            });
            httpAcceptor.setReuseAddress(true);
            httpAcceptor.bind(new InetSocketAddress(0));
            final int port = httpAcceptor.getLocalAddress().getPort();

            Session session = createSession();
            try {
                final int forwardedPort1 = session.setPortForwardingL(0, host, port);
                final int forwardedPort2 = Utils.getFreePort();
                session.setPortForwardingR(forwardedPort2, TEST_LOCALHOST, forwardedPort1);
                outputDebugMessage("URL: http://localhost:%s", forwardedPort2);

                final CountDownLatch latch = new CountDownLatch(nbThread * nbDownloads * nbLoops);
                final Thread[] threads = new Thread[nbThread];
                final List<Throwable> errors = new CopyOnWriteArrayList<>();
                for (int i = 0; i < threads.length; i++) {
                    threads[i] = new Thread(getCurrentTestName() + "[" + i + "]") {
                        @SuppressWarnings("synthetic-access")
                        @Override
                        public void run() {
                            for (int loop = 0; loop < nbLoops; loop++) {
                                MultiThreadedHttpConnectionManager mgr = new MultiThreadedHttpConnectionManager();
                                try {
                                    HttpClient client = new HttpClient(mgr);
                                    client.getHttpConnectionManager().getParams().setDefaultMaxConnectionsPerHost(100);
                                    client.getHttpConnectionManager().getParams().setMaxTotalConnections(1000);
                                    for (int download = 0; download < nbDownloads; download++) {
                                        try {
                                            checkHtmlPage(client,
                                                    new URL("http://localhost:" + forwardedPort2 + path));
                                        } catch (Throwable e) {
                                            errors.add(e);
                                        } finally {
                                            latch.countDown();
                                            log.info("Remaining: {}", latch.getCount());
                                        }
                                    }
                                } finally {
                                    mgr.shutdown();
                                }
                            }
                        }
                    };
                }
                for (Thread thread : threads) {
                    thread.start();
                }
                latch.await();
                for (Thread thread : threads) {
                    thread.join(TimeUnit.SECONDS.toMillis(11L));
                }
                for (Throwable t : errors) {
                    log.error("Download failed", t);
                }
                assertEquals("Mismatched number of download errors", 0, errors.size());
            } finally {
                session.disconnect();
            }
        } finally {
            httpAcceptor.dispose(true);
        }
    }

    /**
     * Opens and connects a JSch session to the test server as user {@code sshd}.
     *
     * @return the connected session; the caller must {@code disconnect()} it
     * @throws JSchException if the connection or authentication fails
     */
    protected Session createSession() throws JSchException {
        JSch sch = new JSch();
        Session session = sch.getSession("sshd", TEST_LOCALHOST, sshPort);
        session.setUserInfo(new SimpleUserInfo("sshd"));
        session.connect();
        return session;
    }

    /**
     * Issues an HTTP/1.1 GET against {@code url}'s host/port and asserts that the body contains a closing
     * {@code </html>} tag. The connection is always released back to the manager.
     *
     * @param client the HTTP client to use
     * @param url the target URL (only host and port are used)
     * @throws IOException if the request fails
     */
    protected void checkHtmlPage(HttpClient client, URL url) throws IOException {
        client.setHostConfiguration(new HostConfiguration());
        client.getHostConfiguration().setHost(url.getHost(), url.getPort());
        GetMethod get = new GetMethod("");
        get.getParams().setVersion(HttpVersion.HTTP_1_1);
        try {
            client.executeMethod(get);
            String str = get.getResponseBodyAsString();
            // The body may be absent (null) - treat that as a missing tag rather than an NPE
            boolean hasCloseTag = (str != null) && (str.indexOf("</html>") > 0);
            if (!hasCloseTag) {
                log.warn("checkHtmlPage({}) unexpected response body: {}", url, str);
            }
            assertTrue("Missing HTML close tag", hasCloseTag);
        } finally {
            get.releaseConnection();
        }
    }
}