dk.deck.remoteconsole.SshRemoteConsole.java Source code

Java tutorial

Introduction

Here is the source code for dk.deck.remoteconsole.SshRemoteConsole.java

Source

/*
 * Copyright 2011 Jesper Terkelsen.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 * 
 */
package dk.deck.remoteconsole;

import com.jcraft.jsch.Channel;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.ChannelShell;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.UserInfo;
import dk.deck.console.CommandResult;
import dk.deck.remoteconsole.userinfo.LoggingUserInfo;
import dk.deck.remoteconsole.util.StreamUtil;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.io.Writer;
import java.net.URL;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 *
 * @author Jesper Terkelsen
 */
public class SshRemoteConsole extends AbstractRemoteConsole implements UserInfoProvider {

    private static Log log = LogFactory.getLog(SshRemoteConsole.class);
    private final JSch jsch;
    private String user;
    private String host;
    private int port;
    private File identityFile;
    private Session session = null;
    private UserInfoProvider userInfoProvider = this;
    private boolean enablePty;

    public SshRemoteConsole(JSch jSch) {
        this.jsch = jSch;
    }

    @Override
    public void setUser(String user) {
        this.user = user;
    }

    @Override
    public String getUser() {
        return user;
    }

    @Override
    public void setHost(String host) {
        this.host = host;
    }

    @Override
    public String getHost() {
        return host;
    }

    @Override
    public void setPort(int port) {
        this.port = port;
    }

    @Override
    public int getPort() {
        return port;
    }

    @Override
    public void setIdentityFile(File identityFile) {
        this.identityFile = identityFile;
    }

    @Override
    public File getIdentityFile() {
        return identityFile;
    }

    public void setEnablePty(boolean enablePty) {
        this.enablePty = enablePty;
    }

    public boolean isEnablePty() {
        return enablePty;
    }

    public void setUserInfoProvider(UserInfoProvider userInfoProvider) {
        if (userInfoProvider == null) {
            throw new IllegalArgumentException("UserInfoProvicer cannot be null");
        }
        this.userInfoProvider = userInfoProvider;
    }

    private static int MAX_CONTENT_LENGTH = 1024 * 1024; // 1m buffer should be enugh for most commands

    @Override
    public CommandResult executeCommandResult(String command, boolean failOnExitNotZero, long disconnectAfterMillis,
            String disconnectAafterContent) throws IOException {
        return executeCommandResult(command, failOnExitNotZero, disconnectAfterMillis, disconnectAafterContent,
                null);
    }

    /**
     * Executes a command on the remote server, via a ssh channel.
     *
     * Captures stdout in the result, while stderr is only logged
     *
     * @param command The unix command to execute
     * @param failOnExitNotZero throw an exception if the unix command does not
     * return zero (0)
     * @param disconnectAfterMillis disconnect after a periods (this is usefull
     * when starting deamons)
     * @param disconnectAfterContent disconnect after this string has appeard
     * in output (this is usefull when starting deamons), can be used in
     * combination
     * @return a CommandResult entity with the output and errorcode.
     * @throws IOException on communication errors
     * @throws IllegalStateException if the exit code check is on
     * @todo Cleanup and split up into several methods
     */
    public CommandResult executeCommandResult(String command, boolean failOnExitNotZero, long disconnectAfterMillis,
            String disconnectAfterContent, Writer liveOutput) throws IOException {
        try {
            CommandResult result = new CommandResult();
            SshRemoteConsole.log.debug("Executing > " + command);
            boolean connect = session == null || !session.isConnected();
            if (connect) {
                connect();
            }
            try {
                Channel channel = session.openChannel("exec"); // shell
                ((ChannelExec) channel).setCommand(command);
                if (enablePty) {
                    ((ChannelExec) channel).setPty(true);
                }
                InputStream error = ((ChannelExec) channel).getErrStream();
                // channel.setOutputStream(System.err);
                InputStream in = channel.getInputStream();
                channel.connect();
                long start = System.currentTimeMillis();
                boolean contentReached = true;
                if (disconnectAfterContent != null && !disconnectAfterContent.equals("")) {
                    contentReached = false;
                }
                StringBuilder output = new StringBuilder();
                StringBuilder errorOutput = new StringBuilder();
                try {
                    byte[] inTmp = new byte[1024];
                    byte[] errorTmp = new byte[1024];
                    while (true) {
                        while (in.available() > 0) {
                            int i = in.read(inTmp, 0, 1024);
                            if (i < 0) {
                                break;
                            }
                            output(new String(inTmp, 0, i), liveOutput, output);
                            log.trace(new String(inTmp, 0, i));
                        }
                        while (error.available() > 0) {
                            int i = error.read(errorTmp, 0, 1024);
                            if (i < 0) {
                                break;
                            }
                            if (output.toString().length() < MAX_CONTENT_LENGTH) {
                                errorOutput.append(new String(errorTmp, 0, i));
                            }
                            log.debug("ERROR: " + new String(errorTmp, 0, i));
                        }
                        if (channel.isClosed()) {
                            result.setExitCode(channel.getExitStatus());
                            if (failOnExitNotZero && channel.getExitStatus() != 0) {
                                log.debug("exit-status: " + channel.getExitStatus());
                                throw new IllegalStateException(
                                        "Exitstatus was: " + channel.getExitStatus() + " output: "
                                                + output.toString() + " error-output: " + errorOutput.toString());
                            }
                            break;
                        }
                        if (disconnectAfterMillis > 0 && contentReached) {
                            long now = System.currentTimeMillis();
                            if (now - start > disconnectAfterMillis) {
                                log.trace("exiting before command is finished after: " + ((now - start) / 1000)
                                        + " seconds.");
                                break;
                            }
                        }
                        // TODO fix possible flaw that clashes with MAX_CONTENT_LENGTH
                        if (!contentReached && (output.toString().contains(disconnectAfterContent)
                                || errorOutput.toString().contains(disconnectAfterContent))) {
                            contentReached = true;
                            start = System.currentTimeMillis();
                            if (disconnectAfterMillis == 0) {
                                break;
                            }
                        }
                        try {
                            Thread.sleep(1000);
                        } catch (InterruptedException ex) {
                            log.warn("Interrupted in sleep", ex);
                        }
                    }
                } finally {
                    channel.disconnect();

                }
                result.setOutput(output.toString());
                result.setErrorOutput(errorOutput.toString());
            } finally {
                if (connect) {
                    disconnect();
                }
            }
            return result;
        } catch (JSchException ex) {
            IOException ioe = new IOException(ex.getMessage());
            ioe.initCause(ex);
            throw ioe;
        }
    }

    private void output(String data, Writer liveOutput, StringBuilder output) throws IOException {
        if (liveOutput != null) {
            liveOutput.append(data);
            liveOutput.flush();
        }
        if (output.toString().length() < MAX_CONTENT_LENGTH) {
            output.append(data);
        }
    }

    /**
     * Opens a shell channel to the server
     *
     * @return The ChannelShell object
     * @throws IOException
     */
    @Override
    public ChannelShell openShell() throws IOException {
        try {
            Channel channel = session.openChannel("shell"); // shell
            ChannelShell shell = (ChannelShell) channel;
            return shell;
        } catch (JSchException ex) {
            IOException ioe = new IOException();
            ioe.initCause(ex);
            throw ioe;
        }
    }

    /**
     * Uploads a file from the URL to the location specified on the server
     *
     * @param lfile URL with the file to upload
     * @param rfile Text location relative to home directory
     * @throws IOException
     */
    @Override
    public void uploadFile(URL lfile, String rfile) throws IOException {
        uploadFile(lfile, rfile, "scp -p -t " + rfile);
    }

    /**
     * Uploads a file to the server using the protocol from scp -t
     *
     * @param lfile URL with the file to upload
     * @param rfile Text location relative to home directory
     * @param command the scp -t command (with or without sudo)
     * @throws IOException
     */
    @Override
    public void uploadFile(URL lfile, String rfile, String command) throws IOException {
        try {
            log.debug("Upload file " + lfile.getFile() + " to " + rfile);
            boolean connect = (session == null || !session.isConnected());
            if (connect) {
                connect();
            }
            // exec 'scp -t rfile' remotely

            Channel channel = session.openChannel("exec");
            ((ChannelExec) channel).setCommand(command);
            // get I/O streams for remote scp
            OutputStream out = channel.getOutputStream();
            InputStream in = channel.getInputStream();
            channel.connect();
            int ack = checkAck(in);
            if (ack != 0) {
                throw new IllegalStateException("checkAck failed value " + ack);
            }
            //out.flush();
            InputStream fis = lfile.openStream();
            // Messure length
            long filesize = StreamUtil.messureContentLenth(fis);
            fis = lfile.openStream();
            command = "C0644 " + filesize + " ";
            if (lfile.getFile().lastIndexOf('/') > 0) {
                command += lfile.getFile().substring(lfile.getFile().lastIndexOf('/') + 1);
            } else {
                command += lfile.getFile();
            }
            command += "\n";
            out.write(command.getBytes());
            out.flush();
            ack = checkAck(in);
            if (ack != 0) {
                throw new IllegalStateException("checkAck failed value " + ack);
            }
            // send a content of lfile
            long written = 0;
            long percent = 0;
            long lastpercent = -1;
            byte[] buf = new byte[1024];
            while (true) {
                int len = fis.read(buf, 0, buf.length);
                if (len <= 0) {
                    break;
                }
                written += len;
                percent = ((written * 100 / filesize));
                if (lastpercent != percent) {
                    lastpercent = percent;
                    log.trace("written " + written + "/" + filesize + " bytes " + percent + "%");
                }
                out.write(buf, 0, len); //out.flush();
            }
            fis.close();
            fis = null;
            // send '\0'
            buf[0] = 0;
            out.write(buf, 0, 1);
            out.flush();
            ack = checkAck(in);
            if (ack != 0) {
                throw new IllegalStateException("checkAck failed value " + ack);
            }
            out.close();
            channel.disconnect();
            if (connect) {
                disconnect();
            }
        } catch (JSchException ex) {
            IOException ioe = new IOException(ex.getMessage());
            ioe.initCause(ex);
            throw ioe;
        }
    }

    @Override
    public void downloadFile(String rfile, OutputStream finalOutput) throws IOException {
        try {
            log.debug("Download file " + rfile);

            String command = "scp -f " + rfile;
            Channel channel = session.openChannel("exec");
            ((ChannelExec) channel).setCommand(command);

            // get I/O streams for remote scp
            OutputStream out = channel.getOutputStream();
            InputStream in = channel.getInputStream();

            channel.connect();

            byte[] buf = new byte[1024];

            // send '\0'
            buf[0] = 0;
            out.write(buf, 0, 1);
            out.flush();

            while (true) {
                int c = checkAck(in);
                if (c != 'C') {
                    break;
                }

                // read '0644 '
                in.read(buf, 0, 5);

                long filesize = 0L;
                while (true) {
                    if (in.read(buf, 0, 1) < 0) {
                        // error
                        break;
                    }
                    if (buf[0] == ' ') {
                        break;
                    }
                    filesize = filesize * 10L + (long) (buf[0] - '0');
                }

                String file = null;
                for (int i = 0;; i++) {
                    in.read(buf, i, 1);
                    if (buf[i] == (byte) 0x0a) {
                        file = new String(buf, 0, i);
                        break;
                    }
                }

                //System.out.println("filesize="+filesize+", file="+file);

                // send '\0'
                buf[0] = 0;
                out.write(buf, 0, 1);
                out.flush();
                long readed = 0;
                long percent = 0;
                long lastpercent = -1;
                long totalFileSize = filesize;

                // read a content of lfile
                int len;
                while (true) {
                    if (buf.length < filesize) {
                        len = buf.length;
                    } else {
                        len = (int) filesize;
                    }
                    len = in.read(buf, 0, len);
                    if (len < 0) {
                        // error 
                        break;
                    }
                    readed += len;
                    percent = ((readed * 100 / filesize));
                    if (lastpercent != percent) {
                        lastpercent = percent;
                        log.trace("reading " + readed + "/" + totalFileSize + " bytes " + percent + "%");
                    }

                    finalOutput.write(buf, 0, len);
                    filesize -= len;
                    if (filesize == 0L) {
                        break;
                    }
                }
                finalOutput.close();

                if (checkAck(in) != 0) {
                    System.exit(0);
                }

                // send '\0'
                buf[0] = 0;
                out.write(buf, 0, 1);
                out.flush();
            }

        } catch (JSchException ex) {
            IOException ioe = new IOException(ex.getMessage());
            ioe.initCause(ex);
            throw ioe;
        }
    }

    /**
     * Check for acknogede
     *
     * @param in The inputstream, at a point where acknogede is expected.
     * @return 0 for success, 1 for error, 2 for fatal error, -1 end of stream
     * @throws IOException
     */
    private static int checkAck(InputStream in) throws IOException {
        int b = in.read();
        // b may be 0 for success,
        //          1 for error,
        //          2 for fatal error,
        //          -1
        if (b == 0) {
            return b;
        }
        if (b == -1) {
            return b;
        }

        if (b == 1 || b == 2) {
            StringBuffer sb = new StringBuffer();
            int c;
            do {
                c = in.read();
                sb.append((char) c);
            } while (c != '\n');
            if (b == 1) { // error
                log.warn(sb.toString());
                throw new IllegalStateException("error: " + sb.toString());
            }
            if (b == 2) { // fatal error
                log.warn(sb.toString());
                throw new IllegalStateException("fatal error: " + sb.toString());
            }
        }
        return b;
    }

    /**
     * Creates a tcp connection with a ssh session. This is where the
     * authentication occours. From here you can execute commands or open a
     * shell, via channels.
     *
     * This method is using the credentials added on the setter methods.
     *
     * @throws IOException
     */
    @Override
    public void connect() throws IOException {
        try {
            if (isConnected()) {
                throw new IllegalStateException("Already connected");
            }
            session = createSession();
            log.debug("Connecting to " + getUser() + "@" + getHost());
            session.connect();
        } catch (JSchException ex) {
            IOException ioe = new IOException(ex.getMessage());
            ioe.initCause(ex);
            throw ioe;
        }
    }

    /**
     * Disconnects the tcp connection, and logout any channel open.
     */
    @Override
    public void disconnect() {
        if (isConnected()) {
            log.debug("Disconnecting");
            session.disconnect();
            session = null;
        } else {
            throw new IllegalStateException("Not connected");
        }
    }

    @Override
    public boolean isConnected() {
        if (session != null && session.isConnected()) {
            return true;
        } else {
            return false;
        }
    }

    private Session createSession() throws JSchException {
        jsch.removeAllIdentity();
        if (identityFile != null) {
            if (identityFile.exists()) {
                jsch.addIdentity(identityFile.getAbsolutePath(), "passphrase");
            } else {
                log.warn("Identity file " + identityFile + " does not exists");
            }
        }
        Session mysession = jsch.getSession(user, host, port);
        UserInfo info = userInfoProvider.getUserInfo();
        mysession.setUserInfo(info);
        return mysession;
    }

    @Override
    public UserInfo getUserInfo() {
        return new LoggingUserInfo();
    }
}