SubProcessRunner.java
package net.morimekta.io.proc;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static java.util.Objects.requireNonNull;
/**
 * Helper class to run a subprocess and handle the process' input and
 * output streams. This can be done to run a subprocess as an interactive
 * shell, like this:
 *
 * <pre>{@code
 * SubProcessRunner runner = new SubProcessRunner();
 * runner.setOut(System.out);
 * runner.setErr(System.err);
 * runner.exec(System.in, "sh", "-c", "cat file | less");
 * }</pre>
 * <p>
 * Its main design is to handle simple commands and their output.
 * Setting the sub-process' input (part of {@link #exec(InputStream, String...)} params)
 * will act the same as controlling the program's input, whether it
 * asks for key-presses or piped input.
 * <p>
 * Streams supplied by the caller (output, error and input) are never closed by
 * the runner. Output streams are flushed once the program's corresponding stream
 * has been fully copied. The process' own streams are always closed.
 */
public class SubProcessRunner {
    private static final long NS_IN_MS = TimeUnit.MILLISECONDS.toNanos(1);
    /** Flush deadline used when {@link #setDeadlineFlushMs(long)} is left at 0. */
    private static final long DEFAULT_FLUSH_DEADLINE_MS = TimeUnit.SECONDS.toMillis(1L);

    private Runtime runtime;
    private ThreadFactory threadFactory;
    private OutputStream out;
    private OutputStream err;
    private Map<String, String> env;
    private Path workingDir;
    private long deadlineMs;
    private long deadlineFlushMs;

    /**
     * Create a SubProcessRunner instance. By default it will ignore all output,
     * pass the current process' environment to the program, and wait forever for
     * the program to exit.
     */
    public SubProcessRunner() {
        this.threadFactory = Executors.defaultThreadFactory();
        this.runtime = Runtime.getRuntime();
        this.out = OutputStream.nullOutputStream();
        this.err = OutputStream.nullOutputStream();
        this.env = System.getenv();
        this.deadlineMs = 0;
        this.deadlineFlushMs = 0;
    }

    /**
     * Set the runtime implementation used to run processes.
     *
     * @param runtime The runtime implementation.
     */
    public void setRuntime(Runtime runtime) {
        requireNonNull(runtime, "runtime == null");
        this.runtime = runtime;
    }

    /**
     * Set the thread factory used to create the threads that pipe IO between
     * this process and the program, and the shutdown-hook thread.
     *
     * @param threadFactory The thread factory.
     */
    public void setThreadFactory(ThreadFactory threadFactory) {
        requireNonNull(threadFactory, "threadFactory == null");
        this.threadFactory = threadFactory;
    }

    /**
     * Set output stream to receive the program's standard output stream.
     * The stream is flushed, but not closed, by the runner.
     *
     * @param out The output stream
     */
    public void setOut(OutputStream out) {
        requireNonNull(out, "out == null");
        this.out = out;
    }

    /**
     * Set output stream to receive the program's standard error stream.
     * The stream is flushed, but not closed, by the runner.
     *
     * @param err The output stream
     */
    public void setErr(OutputStream err) {
        requireNonNull(err, "err == null");
        this.err = err;
    }

    /**
     * Set the working dir where the program should be run.
     *
     * @param workingDir The working dir path.
     */
    public void setWorkingDir(Path workingDir) {
        requireNonNull(workingDir, "workingDir == null");
        this.workingDir = workingDir;
    }

    /**
     * Set the environment for the process. The map is copied. Like
     * {@link Map#copyOf(Map)}, null keys or values are rejected.
     *
     * @param env The environment map.
     */
    public void setEnv(Map<String, String> env) {
        this.env = Map.copyOf(requireNonNull(env, "env == null"));
    }

    /**
     * Add a set of variables to the environment for the process. Existing
     * variables with the same name are overwritten.
     *
     * @param env Env variables to set.
     */
    public void addToEnv(Map<String, String> env) {
        HashMap<String, String> tmp = new HashMap<>(this.env);
        tmp.putAll(requireNonNull(env, "env == null"));
        setEnv(tmp);
    }

    /**
     * Set deadline in millis for the program to finish. If the program does not
     * finish in the assigned deadline, it will be forcefully terminated and an
     * {@link IOException} will be thrown from {@link #exec(String...)}.
     * A deadline of 0 (the default) means wait forever.
     *
     * @param deadlineMs The deadline in MS.
     */
    public void setDeadlineMs(long deadlineMs) {
        if (deadlineMs < 0L) {
            throw new IllegalArgumentException("deadlineMs: " + deadlineMs + " < 0");
        }
        this.deadlineMs = deadlineMs;
    }

    /**
     * Set deadline in millis for the IO handlers to finish after the program has
     * exited. If the streams are not fully handled within that time, the {@link #exec(String...)}
     * method will throw an {@link IOException}. The default is 1s, which is also
     * used if a deadline of 0 is set.
     *
     * @param deadlineFlushMs The deadline in MS.
     */
    public void setDeadlineFlushMs(long deadlineFlushMs) {
        if (deadlineFlushMs < 0L) {
            throw new IllegalArgumentException("deadlineFlushMs: " + deadlineFlushMs + " < 0");
        }
        this.deadlineFlushMs = deadlineFlushMs;
    }

    /**
     * Execute command with no input. This will immediately close the stream to the process,
     * so if the program tries to read it sees end-of-input instead of blocking.
     *
     * @param command The command to be executed.
     * @return The exit code of the program.
     * @throws IOException If the program execution failed, timed out or was interrupted.
     */
    public int exec(String... command) throws IOException {
        return exec(null, command);
    }

    /**
     * Execute command with specified input. The input is copied to the program's
     * standard input on a separate thread, and the program's input is closed once
     * {@code in} reaches end-of-stream. {@code in} itself is not closed.
     * <p>
     * NOTE: the IO handling for {@code in} is awaited like the output handling
     * after the program exits. If {@code in} is still blocked reading at that point,
     * the flush deadline applies and an {@link IOException} is thrown when it is exceeded.
     *
     * @param in      The input stream to read input from, or null for no input.
     * @param command The command to be executed.
     * @return The exit code of the program.
     * @throws IOException If the program execution failed, timed out or was interrupted.
     */
    public int exec(InputStream in, String... command) throws IOException {
        return runSubProcess(runtime,
                             threadFactory,
                             env,
                             out,
                             err,
                             in,
                             workingDir,
                             deadlineMs,
                             deadlineFlushMs,
                             command);
    }

    // --- Private ---

    /** Convert the environment map to the {@code KEY=value} array expected by {@link Runtime#exec}. */
    private static String[] makeEnvP(Map<String, String> env) {
        return env.entrySet()
                  .stream()
                  .map(e -> e.getKey() + "=" + e.getValue())
                  .toArray(String[]::new);
    }

    private static int runSubProcess(
            Runtime runtime,
            ThreadFactory threadFactory,
            Map<String, String> env,
            OutputStream out,
            OutputStream err,
            InputStream in,
            Path workingDir,
            long deadlineMs,
            long deadlineFlushMs,
            String... command) throws IOException {
        requireNonNull(runtime, "runtime == null");
        requireNonNull(out, "out == null");
        requireNonNull(err, "err == null");
        requireNonNull(command, "cmd == null");
        if (command.length == 0) {
            throw new IllegalArgumentException("empty command");
        }

        // One thread each for stderr, stdout and (optionally) stdin.
        ExecutorService executor = Executors.newFixedThreadPool(3, threadFactory);
        try {
            AtomicReference<IOException> ioException = new AtomicReference<>();
            long startNano = System.nanoTime();
            Process process = workingDir == null
                    ? runtime.exec(command, makeEnvP(env))
                    : runtime.exec(command, makeEnvP(env), workingDir.toFile());

            executor.execute(() -> pipeFromProcess(ioException, process.getErrorStream(), err));
            executor.execute(() -> pipeFromProcess(ioException, process.getInputStream(), out));
            if (in != null) {
                executor.execute(() -> pipeToProcess(ioException, in, process.getOutputStream()));
            } else {
                // Always close the program's input stream to force it to stop reading.
                // NOTE: This is not identical to how interactive apps work, but avoids
                // the test to halt because of problems reading from std input stream.
                //
                // TODO: Figure out the correct default behavior + deadline.
                process.getOutputStream().close();
            }

            // Kill the program if the JVM shuts down while we are waiting for it.
            Thread shutDownThread = threadFactory.newThread(process::destroyForcibly);
            shutDownThread.setDaemon(false);
            try {
                runtime.addShutdownHook(shutDownThread);
                if (deadlineMs > 0) {
                    if (!process.waitFor(deadlineMs, TimeUnit.MILLISECONDS)) {
                        process.destroyForcibly();
                        long ms = (System.nanoTime() - startNano) / NS_IN_MS;
                        throw new IOException(makeTimeoutExceptionMessage(ms, deadlineMs, command));
                    }
                } else {
                    process.waitFor();
                }
            } catch (InterruptedException e) {
                // The IO threads are stopped by the outer finally block.
                process.destroyForcibly();
                throw e;
            } finally {
                removeShutdownHookQuietly(runtime, shutDownThread);
            }

            // The program has exited; give the IO threads a bounded time to drain its streams.
            executor.shutdown();
            long flushDeadlineMs = deadlineFlushMs == 0 ? DEFAULT_FLUSH_DEADLINE_MS : deadlineFlushMs;
            if (!executor.awaitTermination(flushDeadlineMs, TimeUnit.MILLISECONDS)) {
                throw new IOException("IO thread handling timeout");
            }
            IOException failure = ioException.get();
            if (failure != null) {
                throw new IOException(failure.getMessage(), failure);
            }
            return process.exitValue();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for the caller; the interruption itself is
            // reported as an IOException to keep the method signature.
            Thread.currentThread().interrupt();
            throw new IOException(e.getMessage(), e);
        } finally {
            // Interrupt IO threads still running after a timeout, interruption or other
            // failure. This is a no-op when all IO handling already completed.
            if (!executor.isTerminated()) {
                executor.shutdownNow();
            }
        }
    }

    /**
     * Remove the shutdown hook, ignoring the case where the JVM is already shutting
     * down (the hook is then running or has run, and cannot be removed). Throwing
     * here would otherwise mask the real result or exception of the caller.
     */
    private static void removeShutdownHookQuietly(Runtime runtime, Thread hook) {
        try {
            runtime.removeShutdownHook(hook);
        } catch (IllegalStateException ignored) {
            // JVM shutdown in progress: the hook only destroys the process, which is fine.
        }
    }

    /**
     * Build the message for a deadline timeout, rendering the command in a
     * shell-like form. Arguments are double-quoted when they are empty, contain a
     * space or single quote, or contain characters that need escaping
     * (backslash, tab, form-feed, newline, carriage-return or double quote).
     */
    private static String makeTimeoutExceptionMessage(long ms, long deadlineMs, String[] cmd) {
        StringBuilder bld = new StringBuilder();
        boolean first = true;
        for (String c : cmd) {
            if (first) {
                first = false;
            } else {
                bld.append(" ");
            }
            // Literal (non-regex) replacement; the backslash must be escaped first so the
            // escapes added below are not doubled.
            String esc = c.replace("\\", "\\\\")
                          .replace("\t", "\\t")
                          .replace("\f", "\\f")
                          .replace("\n", "\\n")
                          .replace("\r", "\\r")
                          .replace("\"", "\\\"");
            // Any escaped character leaves a backslash in esc, so that check covers them all.
            if (c.isEmpty() || c.contains(" ") || c.contains("'") || esc.contains("\\")) {
                bld.append('\"')
                   .append(esc)
                   .append('\"');
            } else {
                bld.append(c);
            }
        }
        return "deadline exceeded: " + ms + " > " + deadlineMs + ": " + bld;
    }

    /**
     * Copy one of the program's output streams into a caller-provided sink.
     * The process stream is closed afterwards. The sink is flushed by
     * {@link #streamCopy} but never closed, since it belongs to the caller.
     * Failures are recorded in {@code ioException}.
     */
    private static void pipeFromProcess(
            AtomicReference<IOException> ioException,
            InputStream processStream,
            OutputStream sink) {
        try (InputStream source = processStream) {
            streamCopy(source, sink);
        } catch (IOException e) {
            ioException.updateAndGet(old -> maybeSuppress(old, e));
        }
    }

    /**
     * Copy caller-provided input into the program's standard input, then close the
     * program's input so it sees end-of-input. The caller's stream is not closed.
     * Failures are recorded in {@code ioException}.
     */
    private static void pipeToProcess(
            AtomicReference<IOException> ioException,
            InputStream source,
            OutputStream processStream) {
        try (OutputStream target = processStream) {
            streamCopy(source, target);
        } catch (IOException e) {
            ioException.updateAndGet(old -> maybeSuppress(old, e));
        }
    }

    /** Copy all bytes from {@code in} to {@code out} and flush {@code out}. */
    private static void streamCopy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[4 * 1024];
        int b;
        while ((b = in.read(buffer)) >= 0) {
            out.write(buffer, 0, b);
        }
        out.flush();
    }

    /** Keep the first exception seen, attaching later ones as suppressed exceptions. */
    private static IOException maybeSuppress(IOException old, IOException e) {
        if (old != null) {
            old.addSuppressed(e);
            return old;
        }
        return e;
    }
}