// SubProcessRunner.java

package net.morimekta.io.proc;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Objects.requireNonNull;

/**
 * Helper class to run a subprocess and handle the process' input and
 * output streams. This can be done to run a subprocess as an interactive
 * shell, like this:
 *
 * <pre>{@code
 * SubProcessRunner runner = new SubProcessRunner();
 * runner.setOut(System.out);
 * runner.setErr(System.err);
 * runner.exec(System.in, "sh", "-c", "cat file | less");
 * }</pre>
 * <p>
 * Its main design is to handle simple commands with their output.
 * Setting the sub-process' input (part of {@link #exec(InputStream, String...)} params)
 * will act the same as controlling the program's input, whether it
 * asks for key-presses or piped input.
 * <p>
 * Note that the configured output and error streams, as well as the input
 * stream given to {@link #exec(InputStream, String...)}, are closed once the
 * respective stream has been fully copied.
 */
public class SubProcessRunner {
    private static final long NS_IN_MS = TimeUnit.MILLISECONDS.toNanos(1);
    /** Flush deadline applied when the configured flush deadline is 0. */
    private static final long DEFAULT_FLUSH_DEADLINE_MS = TimeUnit.SECONDS.toMillis(1L);

    private Runtime             runtime;
    private ThreadFactory       threadFactory;
    private OutputStream        out;
    private OutputStream        err;
    private Map<String, String> env;
    private Path                workingDir;
    private long                deadlineMs;
    private long                deadlineFlushMs;

    /**
     * Create a SubProcessRunner instance. By default it will ignore all output
     * and wait forever for program to exit.
     */
    public SubProcessRunner() {
        this.threadFactory = Executors.defaultThreadFactory();
        this.runtime = Runtime.getRuntime();
        this.out = OutputStream.nullOutputStream();
        this.err = OutputStream.nullOutputStream();
        this.env = System.getenv();
        this.deadlineMs = 0;
        this.deadlineFlushMs = 0;
    }

    /**
     * Set the runtime implementation used to run processes.
     *
     * @param runtime The runtime implementation.
     */
    public void setRuntime(Runtime runtime) {
        requireNonNull(runtime, "runtime == null");
        this.runtime = runtime;
    }

    /**
     * @param threadFactory Set the thread factory used to generate local threads used
     *                      for handling IO piping.
     */
    public void setThreadFactory(ThreadFactory threadFactory) {
        requireNonNull(threadFactory, "threadFactory == null");
        this.threadFactory = threadFactory;
    }

    /**
     * Set output stream to receive the program's standard output stream.
     * The stream is closed when the program's output has been fully copied.
     *
     * @param out The output stream
     */
    public void setOut(OutputStream out) {
        requireNonNull(out, "out == null");
        this.out = out;
    }

    /**
     * Set output stream to receive the program's standard error stream.
     * The stream is closed when the program's error output has been fully copied.
     *
     * @param err The output stream
     */
    public void setErr(OutputStream err) {
        requireNonNull(err, "err == null");
        this.err = err;
    }

    /**
     * Set the working dir where the program should be run.
     *
     * @param workingDir The working dir path.
     */
    public void setWorkingDir(Path workingDir) {
        requireNonNull(workingDir, "workingDir == null");
        this.workingDir = workingDir;
    }

    /**
     * Set the environment for the process. The map is copied, so later
     * changes to the given map do not affect this runner.
     *
     * @param env The environment map.
     */
    public void setEnv(Map<String, String> env) {
        this.env = Map.copyOf(requireNonNull(env, "env == null"));
    }

    /**
     * Add a set of variables to the environment for the process. Variables
     * already present are overwritten by the given values.
     *
     * @param env Env variables to set.
     */
    public void addToEnv(Map<String, String> env) {
        HashMap<String, String> tmp = new HashMap<>(this.env);
        tmp.putAll(requireNonNull(env, "env == null"));
        setEnv(tmp);
    }

    /**
     * Set deadline in millis for the program to finish. If the program does not
     * finish in the assigned deadline, it will be forcefully terminated and an
     * {@link IOException} will be thrown from {@link #exec(String...)}. A deadline
     * of 0 means to wait until the program exits on its own.
     *
     * @param deadlineMs The deadline in MS.
     * @throws IllegalArgumentException If the deadline is negative.
     */
    public void setDeadlineMs(long deadlineMs) {
        if (deadlineMs < 0L) {
            throw new IllegalArgumentException("deadlineMs: " + deadlineMs + " < 0");
        }
        this.deadlineMs = deadlineMs;
    }

    /**
     * Set deadline in millis for the output handlers to finish after the program has
     * exited. If the streams are not fully handled within that time, the {@link #exec(String...)}
     * method will throw an {@link IOException}. The default is 1s, which is also the
     * deadline used if 0 is set.
     *
     * @param deadlineFlushMs The deadline in MS.
     * @throws IllegalArgumentException If the deadline is negative.
     */
    public void setDeadlineFlushMs(long deadlineFlushMs) {
        if (deadlineFlushMs < 0L) {
            throw new IllegalArgumentException("deadlineFlushMs: " + deadlineFlushMs + " < 0");
        }
        this.deadlineFlushMs = deadlineFlushMs;
    }

    /**
     * Execute command with no input. This will immediately close the stream to the process,
     * so if it tries to read it will not block, but exit.
     *
     * @param command The command to be executed.
     * @return The exit code of the program.
     * @throws IOException If the program execution failed, timed out, or the
     *                     calling thread was interrupted while waiting.
     * @throws IllegalArgumentException If the command is empty.
     */
    public int exec(String... command) throws IOException {
        return exec((InputStream) null, command);
    }

    /**
     * Execute command with specified input.
     *
     * @param in      The input stream to read input from, or null to close the
     *                program's input immediately. A non-null stream is closed
     *                when it has been fully copied to the program.
     * @param command The command to be executed.
     * @return The exit code of the program.
     * @throws IOException If the program execution failed, timed out, or the
     *                     calling thread was interrupted while waiting.
     * @throws IllegalArgumentException If the command is empty.
     */
    public int exec(InputStream in, String... command) throws IOException {
        return runSubProcess(runtime,
                             threadFactory,
                             env,
                             out,
                             err,
                             in,
                             workingDir,
                             deadlineMs,
                             deadlineFlushMs,
                             command);
    }

    // --- Private ---

    /**
     * Convert the environment map to the "KEY=value" array form expected
     * by {@link Runtime#exec(String[], String[])}.
     */
    private static String[] makeEnvP(Map<String, String> env) {
        return env.entrySet()
                  .stream()
                  .map(e -> e.getKey() + "=" + e.getValue())
                  .toArray(String[]::new);
    }

    /**
     * Start the process, pipe its streams on background threads, and wait for
     * it to exit (optionally bounded by a deadline).
     *
     * @return The exit code of the process.
     * @throws IOException If starting the process or stream handling failed,
     *                     a deadline was exceeded or the wait was interrupted.
     */
    private static int runSubProcess(
            Runtime runtime,
            ThreadFactory threadFactory,
            Map<String, String> env,
            OutputStream out,
            OutputStream err,
            InputStream in,
            Path workingDir,
            long deadlineMs,
            long deadlineFlushMs,
            String... command) throws IOException {
        requireNonNull(runtime, "runtime == null");
        requireNonNull(threadFactory, "threadFactory == null");
        requireNonNull(env, "env == null");
        requireNonNull(out, "out == null");
        requireNonNull(err, "err == null");
        requireNonNull(command, "cmd == null");
        if (command.length == 0) {
            throw new IllegalArgumentException("empty command");
        }

        // One thread each for stderr, stdout and (optionally) stdin piping.
        ExecutorService executor = Executors.newFixedThreadPool(3, threadFactory);
        try {
            // Collects the first IO failure from the piping threads; later
            // failures are attached to it as suppressed exceptions.
            AtomicReference<IOException> ioException = new AtomicReference<>();
            long startNano = System.nanoTime();

            Process process = workingDir == null
                              ? runtime.exec(command, makeEnvP(env))
                              : runtime.exec(command, makeEnvP(env), workingDir.toFile());
            executor.execute(() -> handleStreamInternal(ioException, process.getErrorStream(), err));
            executor.execute(() -> handleStreamInternal(ioException, process.getInputStream(), out));
            if (in != null) {
                executor.execute(() -> handleStreamInternal(ioException, in, process.getOutputStream()));
            } else {
                // Always close the program's input stream to force it to stop reading.
                // NOTE: This is not identical to how interactive apps work, but avoids
                // the test to halt because of problems reading from std input stream.
                //
                // TODO: Figure out the correct default behavior + deadline.
                process.getOutputStream().close();
            }
            // Make sure the child process is killed if the JVM shuts down
            // while we are waiting for it.
            Thread shutDownThread = threadFactory.newThread(process::destroyForcibly);
            shutDownThread.setDaemon(false);
            try {
                runtime.addShutdownHook(shutDownThread);
                if (deadlineMs > 0) {
                    if (!process.waitFor(deadlineMs, TimeUnit.MILLISECONDS)) {
                        process.destroyForcibly();
                        long endNano = System.nanoTime();
                        long ms = (endNano - startNano) / NS_IN_MS;
                        throw new IOException(makeTimeoutExceptionMessage(ms, deadlineMs, command));
                    }
                } else {
                    process.waitFor();
                }
            } catch (InterruptedException e) {
                executor.shutdown();
                process.destroyForcibly();
                throw e;
            } finally {
                runtime.removeShutdownHook(shutDownThread);
            }

            // The process has exited; give the piping threads a bounded time
            // to drain whatever output is still pending.
            executor.shutdown();
            long flushDeadline = deadlineFlushMs == 0 ? DEFAULT_FLUSH_DEADLINE_MS : deadlineFlushMs;
            if (!executor.awaitTermination(flushDeadline, TimeUnit.MILLISECONDS)) {
                executor.shutdownNow();
                throw new IOException("IO thread handling timeout");
            }

            IOException failure = ioException.get();
            if (failure != null) {
                throw new IOException(failure.getMessage(), failure);
            }
            return process.exitValue();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up the stack can
            // still observe the interruption after we translate the exception.
            Thread.currentThread().interrupt();
            throw new IOException(e.getMessage(), e);
        } finally {
            if (!executor.isShutdown()) {
                executor.shutdownNow();
            }
        }
    }

    /**
     * Build the message for the deadline-exceeded exception, including a
     * readable rendering of the command line. Arguments that are empty,
     * contain a space or single quote, or needed escaping are wrapped in
     * double quotes.
     *
     * @param ms         The time actually spent in millis.
     * @param deadlineMs The configured deadline in millis.
     * @param cmd        The command that was executed.
     * @return The exception message.
     */
    private static String makeTimeoutExceptionMessage(long ms, long deadlineMs, String[] cmd) {
        StringBuilder bld = new StringBuilder();
        boolean first = true;
        for (String c : cmd) {
            if (first) {
                first = false;
            } else {
                bld.append(" ");
            }

            // Literal (non-regex) replacements. The backslash MUST be escaped
            // first, otherwise the backslashes introduced by the later
            // replacements would be escaped a second time.
            String esc = c.replace("\\", "\\\\")
                          .replace("\t", "\\t")
                          .replace("\f", "\\f")
                          .replace("\n", "\\n")
                          .replace("\r", "\\r")
                          .replace("\"", "\\\"");
            // Quote where escaping is needed, OR the argument is empty or
            // contains a literal space or single quote.
            if (c.isEmpty() || c.contains(" ") || c.contains("'") || esc.contains("\\")) {
                bld.append('\"')
                   .append(esc)
                   .append('\"');
            } else {
                bld.append(c);
            }
        }

        return "deadline exceeded: " + ms + " > " + deadlineMs + ": " + bld;
    }

    /**
     * Copy everything from the input stream to the output stream, then close
     * both. Any {@link IOException} is recorded in the shared reference instead
     * of being thrown, as this runs on a background thread.
     */
    private static void handleStreamInternal(
            AtomicReference<IOException> ioException,
            InputStream inputStream,
            OutputStream outputStream) {
        try {
            streamCopy(inputStream, outputStream);
        } catch (IOException e) {
            ioException.updateAndGet(old -> maybeSuppress(old, e));
        } finally {
            try {
                outputStream.close();
            } catch (IOException e) {
                ioException.updateAndGet(old -> maybeSuppress(old, e));
            }
            try {
                inputStream.close();
            } catch (IOException e) {
                ioException.updateAndGet(old -> maybeSuppress(old, e));
            }
        }
    }

    /**
     * Copy all bytes from in to out until end of stream. The output is flushed
     * after every chunk so that data passes through buffered streams
     * immediately (needed for interactive use) rather than only at end of
     * stream.
     */
    private static void streamCopy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[4 * 1024];
        int b;
        while ((b = in.read(buffer)) >= 0) {
            out.write(buffer, 0, b);
            out.flush();
        }
        out.flush();
    }

    /**
     * Combine two failures: keep the first one as the primary exception and
     * attach the new one as suppressed.
     */
    private static IOException maybeSuppress(IOException old, IOException e) {
        if (old != null) {
            old.addSuppressed(e);
            return old;
        }
        return e;
    }
}