se.nbis.sftpsquid.SftpSquid.java Source code

Java tutorial

Introduction

Here is the source code for se.nbis.sftpsquid.SftpSquid.java

Source

/*
 * Copyright 2016 Johan Viklund, NBIS, https://www.nbis.se
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package se.nbis.sftpsquid;

import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.sftp.FileMode;
import net.schmizz.sshj.sftp.OpenMode;
import net.schmizz.sshj.sftp.RemoteFile;
import net.schmizz.sshj.sftp.RemoteResourceInfo;
import net.schmizz.sshj.sftp.SFTPClient;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
import net.schmizz.sshj.userauth.method.AuthMethod;
import net.schmizz.sshj.userauth.method.AuthPassword;

import org.apache.commons.io.FilenameUtils;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import java.io.IOException;
import java.net.ConnectException;
import java.util.EnumSet;
import java.util.LinkedList;
import java.util.List;

/**
 * SftpSquid - Program to transfer files between multiple sftp servers.
 *
 * <p>It is assumed that the user of this program do not have shell access through
 * SSH. Instead it streams the data from one server to the other through the
 * local machine. No data is stored locally.
 *
 * <p>Further the program do not store any authentication information. Passwords
 * and other credentials are just passed through to the servers in question.
 *
 * @author  Johan Viklund <johan.viklund@bils.se>
 * @version 0.1
 * @since   2016-03-31
 */
public class SftpSquid {
    private SSHClient[] ssh_clients;
    private SFTPClient[] sftp_clients;
    private HostFileInfo[] hfs;

    private Logger log = Logger.getLogger(getClass());

    public static void main(String[] args) throws IOException {
        BasicConfigurator.configure();
        Logger.getRootLogger().setLevel(Level.ERROR);
        //Logger.getLogger("se.nbis.sftp_squid.SftpSquid").setLevel(Level.DEBUG);
        HostFileInfo[] hf;

        try {
            hf = parseArgs(args);
            SftpSquid ss = new SftpSquid(hf);
            ss.run();
        } catch (IOException e) {
            System.err.println("ERROR: " + e.getMessage());
            System.exit(1);
        }
    }

    /**
     * Parse command line arguments.
     *
     * @param args the args array from the main method
     */
    public static HostFileInfo[] parseArgs(String[] args) throws IOException {
        if (args.length < 2) {
            throw new IOException("You need to supply at least 2 servers");
        }
        HostFileInfo[] hf = new HostFileInfo[args.length];
        for (int i = 0; i < args.length; i++) {
            try {
                hf[i] = new HostFileInfo(args[i]);
            } catch (IOException e) {
                throw new IOException("Malformed host argument '" + args[i] + "'");
            }
        }
        return hf;
    }

    /**
     * Construct a new SftpSquid object from an array of HostFileInfos
     *
     * @param HostFileInfo[] an array of HostFileInfo
     */
    SftpSquid(HostFileInfo[] hf) throws IOException {
        if (hf.length < 2) {
            throw new IOException("Need at least 2 hosts to transfer between");
        }
        this.hfs = hf;
        this.ssh_clients = new SSHClient[hf.length];
        this.sftp_clients = new SFTPClient[hf.length];
    }

    /**
     * The main loop of the program.
     */
    public void run() throws IOException {
        log.debug("run()");
        try {
            connectAll();
            transfer();
        } finally {
            closeAll();
        }
    }

    /**
     * Connect to the servers specified on the command line
     */
    public void connectAll() throws IOException {
        log.debug("connectAll");
        for (int i = 0; i < hfs.length; i++) {
            try {
                int tries = 3;
                while (tries-- > 0) {
                    try {
                        SSHClient ssh = connect(hfs[i]);
                        SFTPClient sftp = ssh.newSFTPClient();

                        ssh_clients[i] = ssh;
                        sftp_clients[i] = sftp;
                        tries = -1;
                    } catch (UserAuthException e) {
                        System.err.println(
                                "Incorrect username and/or password for " + hfs[i].userHostSpec() + " try again.");
                        if (tries <= 0) {
                            throw e;
                        }
                    }
                }
            } catch (TransportException e) {
                System.err.println("Something went wrong: " + e);
                throw e;
            } catch (ConnectException e) {
                System.err.printf("Connection failure for %s: %s\n", hfs[i].userHostSpec(), e.getMessage());
                throw e;
            }
        }
    }

    /**
     * Close all open connections
     */
    public void closeAll() throws IOException {
        log.debug("closeAll()");
        for (SSHClient ssh_client : ssh_clients) {
            if (ssh_client != null && ssh_client.isConnected()) {
                ssh_client.close();
            }
        }
    }

    /**
     * Connect to one SFTP server
     *
     * @param  HostFileInfo A HostFileInfo object representing the server
     * @return SSHClient
     */
    private SSHClient connect(HostFileInfo hf) throws IOException {
        log.debug("Connecting to " + hf);
        SSHClient ssh = new SSHClient();
        ssh.addHostKeyVerifier(new PromiscuousVerifier());
        ssh.connect(hf.host, hf.port);

        try {
            List<AuthMethod> authmethods = new LinkedList<AuthMethod>();
            authmethods.add(new AuthKeyboardInteractive(new UserKeyboardAuth(hf)));
            authmethods.add(new AuthPassword(new PasswordAuth(hf)));
            ssh.auth(hf.user, authmethods);
        } catch (IOException e) {
            ssh.close(); // We have to clean this up
            throw e;
        }

        return ssh;
    }

    /**
     * Transfer the files
     */
    public void transfer() throws IOException {
        String source = hfs[0].file;
        String destination = hfs[1].file;
        log.debug("transfer(): " + source + " -> " + destination);

        FileMode.Type sourceType = getType(source, sftp_clients[0]);
        FileMode.Type destType = getType(destination, sftp_clients[1]);

        if (sourceType == FileMode.Type.DIRECTORY && destType != sourceType) {
            throw new IOException("Destination has to be a directory");
        }

        List<String> sources = buildTransferList(source, sftp_clients[0]);

        int lastSeparatorInSource = FilenameUtils.indexOfLastSeparator(source);
        if (lastSeparatorInSource == -1) {
            if (sourceType == FileMode.Type.REGULAR) {
                lastSeparatorInSource = 0;
            } else {
                lastSeparatorInSource = source.length();
            }
        }
        log.debug("transfer() lastSeparatorInSource: " + lastSeparatorInSource);

        for (String s : sources) {
            String d = destination;
            if (destType == FileMode.Type.DIRECTORY) {
                d += s.substring(lastSeparatorInSource);
            }

            transferFile(s, d);
        }
    }

    /**
     * Create a list of files to transfer
     */
    private List<String> buildTransferList(String source, SFTPClient c) throws IOException {
        log.debug("Building transferlist from " + source);

        FileMode.Type type = getType(source, c);
        List<String> ret = new LinkedList<String>();

        if (type == FileMode.Type.REGULAR) {
            ret.add(source);
            return ret;
        }
        if (type != FileMode.Type.DIRECTORY) {
            throw new IOException("Can't transfer this type of file (" + type + ")");
        }

        List<RemoteResourceInfo> paths = c.ls(source);
        for (RemoteResourceInfo p : paths) {
            ret.addAll(buildTransferList(p.getPath(), c));
        }

        return ret;
    }

    /**
     * Remove superflous parts of the path, such as double /
     *
     * @param String path to normalize
     * @return String
     */
    private String normalizePath(String path) {
        String normalized = FilenameUtils.normalize(path);
        return FilenameUtils.separatorsToUnix(normalized); // In case we run on windows
    }

    /**
     * Transfer one file between the two systems
     *
     * @param source The source file as a string
     * @param destination the destination file as a string
     */
    private void transferFile(String source, String destination) throws IOException {
        log.debug("Transfer " + source + " -> " + destination);
        createPath(destination, sftp_clients[1]);

        RemoteFile fileSource = sftp_clients[0].open(source);
        RemoteFile fileDestination = sftp_clients[1].open(destination,
                EnumSet.of(OpenMode.WRITE, OpenMode.CREAT, OpenMode.TRUNC));

        try {
            RemoteFile.ReadAheadRemoteFileInputStream streamSource = fileSource.new ReadAheadRemoteFileInputStream(
                    16);
            RemoteFile.RemoteFileOutputStream streamDestination = fileDestination.new RemoteFileOutputStream(0, 16);

            try {
                StreamCopier sc = new StreamCopier(streamSource, streamDestination);
                sc.bufSize(calculateMaxBufferSize(fileSource, fileDestination));
                sc.keepFlushing(false);
                sc.listener(new ProgressBarListener(fileSource.length(), fileNameOnly(source)));
                sc.copy();
            } finally {
                streamSource.close();
                streamDestination.close();
            }
        } finally {
            fileSource.close();
            fileDestination.close();
        }
    }

    /**
     * Helper to create directories if needed on target
     */
    private void createPath(String path, SFTPClient client) {
        path = normalizePath(path);
        log.debug("createPath based on " + path);

        String dir = "";
        if (path.charAt(0) != '/') {
            dir = ".";
        }

        String[] components = path.split("/");
        for (int i = 0; i < components.length - 1; i++) {
            dir += '/' + components[i];
        }
        dir = normalizePath(dir);
        if (dir.length() == 0) {
            return;
        }

        if (dir.charAt(0) != '/') {
            dir = "./" + dir;
        }

        log.debug("createPath: " + dir);
        try {
            client.mkdirs(dir);
        } catch (IOException e) {
            log.debug("Could not createPath: " + e);
        }
    }

    /**
     * Only the filename part of the path
     */
    private String fileNameOnly(String path) {
        String[] parts = path.split("/"); // SFTP Servers always use this separator
        return parts[parts.length - 1];
    }

    /* This code is taken from the library docs and adjusted to match the
     * situation we have with two remote servers */
    private int calculateMaxBufferSize(RemoteFile f1, RemoteFile f2) {
        int remoteMaxPacketSize = sftp_clients[0].getSFTPEngine().getSubsystem().getRemoteMaxPacketSize();
        if (remoteMaxPacketSize > sftp_clients[1].getSFTPEngine().getSubsystem().getRemoteMaxPacketSize()) {
            remoteMaxPacketSize = sftp_clients[1].getSFTPEngine().getSubsystem().getRemoteMaxPacketSize();
        }
        int packetOverhead = f2.getOutgoingPacketOverhead();
        return remoteMaxPacketSize - packetOverhead;
    }

    /**
     * Get filetype
     */
    private FileMode.Type getType(String p, SFTPClient c) {
        FileMode.Type t;
        try {
            t = c.type(p);
            if (t == FileMode.Type.SYMKLINK) {
                t = getType(c.readlink(p), c);
            }
        } catch (Exception e) {
            t = FileMode.Type.UNKNOWN;
        }
        return t;
    }
}