de.alexkamp.sandbox.ChrootSandbox.java Source code

Java tutorial

Introduction

Here is the source code for de.alexkamp.sandbox.ChrootSandbox.java, a client-side abstraction around a connection to a sandbox server.

Source

package de.alexkamp.sandbox;

/**
 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
 *
 * If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import de.alexkamp.sandbox.model.SandboxData;

import java.io.*;
import java.net.Socket;
import java.nio.file.FileVisitResult;
import java.nio.file.FileVisitor;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Stack;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Abstraction around a connection to a sandbox.
 */
public class ChrootSandbox implements Runnable, Sandbox {
    private final JsonFactory factory;
    private final Socket socket;
    private final JsonGenerator sender;

    private final SandboxData data;

    private final Thread reader;

    private final Lock socketLock = new ReentrantLock();
    private final Condition socketCondition = socketLock.newCondition();

    private ChrootSandboxProcess currentProcess;
    private final Stack<SandboxException> expStack = new Stack<>();

    public ChrootSandbox(JsonFactory factory, String host, int port, SandboxData data) {
        this.factory = new JsonFactory();

        this.data = data;

        try {
            socket = new Socket(host, port);
            this.sender = factory.createGenerator(socket.getOutputStream());
        } catch (IOException ex) {
            throw new SandboxException(ex);
        }

        this.reader = new Thread(this);
        this.reader.start();
        try {
            connect();
        } catch (IOException ex) {
            throw new SandboxException(ex);
        }
    }

    private void connect() throws IOException {
        data.toJson(sender);
        sender.flush();
    }

    @Override
    public SandboxProcess newProcess(String workdir, long timeout, String binary, String... args) {
        checkState();

        return new ChrootSandboxProcess(this, workdir, binary, args, timeout);
    }

    protected void checkState() {
        if (!expStack.isEmpty()) {
            if (1 == expStack.size()) {
                throw expStack.pop();
            }
            throw new MultipleSandboxException(expStack);
        }
        if (socket.isClosed()) {
            throw new SandboxException("Connection is already closed.");
        }
    }

    @Override
    public void close() throws Exception {
        if (socket.isClosed()) {
            return;
        }

        try {
            socketLock.lock();
            sender.writeStartObject();
            sender.writeObjectField("Executable", "exit");
            sender.writeEndObject();
            sender.flush();
            while (!socket.isClosed()) {
                try {
                    socketCondition.await();
                } catch (InterruptedException ex) {
                }
            }
        } catch (IOException ex) {
            throw new SandboxException(ex);
        } finally {
            socketLock.unlock();
        }
    }

    public void closeSocket() throws Exception {
        try {
            socketLock.lock();
            socket.close();
            socketCondition.signalAll();
        } finally {
            socketLock.unlock();
        }
    }

    @Override
    public void run() {
        try {
            JsonParser parser = factory.createParser(socket.getInputStream());

            String caller = null;
            String channel = null;
            String message = null;
            JsonToken jt = null;
            while (null != (jt = parser.nextToken())) {
                if (jt.isStructStart()) {
                    caller = null;
                    channel = null;
                    message = null;
                } else if (jt.isStructEnd()) {
                    try {
                        handle(caller, channel, message);
                    } catch (Exception ex) {
                        handleAsyncError(ex);
                    }
                } else {
                    String name = parser.getCurrentName();
                    parser.nextToken();
                    switch (name) {
                    case "Caller":
                        caller = parser.getText();
                        break;
                    case "Channel":
                        channel = parser.getText();
                        break;
                    case "Message":
                        message = parser.getText();
                        break;
                    }
                }
            }

        } catch (Exception e) {
            handleAsyncError(e);
        }
    }

    private void handleAsyncError(Exception e) {
        if (e instanceof SandboxException) {
            expStack.push((SandboxException) e);
            if (null != currentProcess) {
                currentProcess.wakeUp();
            }
        } else {
            expStack.push(new SandboxException(e));
            if (null != currentProcess) {
                currentProcess.wakeUp();
            }
        }
    }

    private void handle(String caller, String channel, String message) throws Exception {
        switch (caller) {
        case "server":
            // the server talks only in case of errors
            if (!"error".equals(channel)) {
                throw new SandboxException(
                        "Unexpected message: " + caller + " " + channel + " \"" + message + "\"");
            }
            closeSocket();
            throw new SandboxException(message);
        case "sandbox":
            // if the sandbox says error, something bad is going on
            if ("error".equals(channel)) {
                closeSocket();
                throw new SandboxException(
                        "Unexpected message: " + caller + " " + channel + " \"" + message + "\"");
            } else if ("exit".equals(channel)) {
                closeSocket();
            }
            break;
        case "process":
            try {
                if (!currentProcess.handle(channel, message)) {
                    processDone();
                }
            } catch (SandboxException ex) {
                processDone();
                throw ex;
            }
            break;
        default:
            throw new SandboxException("Message from " + caller
                    + ". Are you combining different versions? Message is \"" + message + "\"");
        }
    }

    private void processDone() {
        ChrootSandboxProcess sp = currentProcess;
        currentProcess = null;
        sp.wakeUp();
    }

    protected void start(ChrootSandboxProcess process) {
        this.checkState();
        if (null != currentProcess) {
            throw new SandboxException("There seems to be a process running in this sandbox.");
        }
        currentProcess = process;

        try {
            currentProcess.toJson(sender);
            sender.flush();
        } catch (IOException e) {
            throw new SandboxException(e);
        }
    }

    @Override
    public void setEnv(String variable, String value) {
        try {
            SandboxProcess proc = newProcess("/", -1, "export", variable + "=" + value);
            proc.start();
        } catch (SandboxTimeoutException e) {
            throw new IllegalStateException("A timeout on export? Something is really wrong!");
        }
    }

    @Override
    public void copyTo(String path, File target) throws IOException {
        if (!target.exists()) {
            if (!target.mkdirs()) {
                throw new SandboxException("Can not create directory " + target.getName());
            }
        }
        if (!target.isDirectory()) {
            throw new IllegalArgumentException(target.getName() + " is not a directory.");
        }

        walkDirectoryTree(path, new CopyWalker(target));
    }

    @Override
    public void walkDirectoryTree(String basePath, final DirectoryWalker walker) throws IOException {
        final int baseNameCount = data.getBaseDir().toPath().getNameCount();

        File base = new File(data.getBaseDir(), basePath);

        Files.walkFileTree(base.toPath(), new FileVisitor<Path>() {
            @Override
            public FileVisitResult preVisitDirectory(Path path, BasicFileAttributes basicFileAttributes)
                    throws IOException {
                if (walker.visitDirectory(calcSubpath(path))) {
                    return FileVisitResult.CONTINUE;
                }
                return FileVisitResult.SKIP_SUBTREE;
            }

            private String calcSubpath(Path path) {
                if (path.getNameCount() == baseNameCount) {
                    return "/";
                }
                return "/" + path.subpath(baseNameCount, path.getNameCount()).toString();
            }

            @Override
            public FileVisitResult visitFile(Path path, BasicFileAttributes basicFileAttributes)
                    throws IOException {
                String subpath = calcSubpath(path);
                if (walker.visitFile(subpath)) {
                    try (InputStream is = Files.newInputStream(path)) {
                        walker.visitFileContent(subpath, is);
                    }
                }
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult visitFileFailed(Path path, IOException e) throws IOException {
                if (walker.failed(e)) {
                    return FileVisitResult.CONTINUE;
                } else {
                    return FileVisitResult.TERMINATE;
                }
            }

            @Override
            public FileVisitResult postVisitDirectory(Path path, IOException e) throws IOException {
                return FileVisitResult.CONTINUE;
            }
        });
    }
}