ConsoleManager.java

package net.morimekta.testing.console;

import net.morimekta.io.tty.TTY;
import net.morimekta.io.tty.TTYMode;
import net.morimekta.io.tty.TTYSize;
import net.morimekta.strings.chr.CharUtil;
import net.morimekta.strings.chr.Color;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.util.concurrent.atomic.AtomicReference;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static net.morimekta.testing.console.Console.DEFAULT_TERMINAL_SIZE;

/**
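 * Manager for a fully virtual TTY and I/O for testing, as used by {@code ConsoleExtension}. Between
 * {@link #doBeforeEach()} and {@link #doAfterEach()} it forcefully replaces standard in, out and err, and afterwards
 * restores the original system streams. Anything a test writes to the system streams while it runs is captured, and
 * is only passed on to the original streams when forking is enabled (see {@link #setForkOutput(boolean)} and
 * {@link #setForkError(boolean)}), so tests relying on normal system I/O to print ongoing status may see no output.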
 *
 * <pre>{@code
 * {@literal@}ExtendWith(ConsoleExtension.class)
 * public class MyTest {
 *     {@literal@}Test
 *     public void testMyThing(Console console) {
 *         // use the console I/O or TTY or both.
 *     }
 * }
 * }</pre>
 */
public class ConsoleManager {
    /** Create a console manager. */
    public ConsoleManager() {
        originalOut = System.out;
        originalErr = System.err;
        originalIn = System.in;

        in = new WrappedInputStream();
        out = new PrintStream(new WrappedOutputStream(), false, UTF_8);
        err = new PrintStream(new WrappedErrorStream(), true, UTF_8);

        tty = new TTYImpl();
        console = new ConsoleImpl();

        setUpStreams();
    }

    /** @return The fake TTY for the console. */
    public TTY getTTY() {
        return tty;
    }

    /** @return The console interface. */
    public Console getConsole() {
        return console;
    }

    /**
     * @param terminalSize The new terminal size.
     */
    public void setTerminalSize(TTYSize terminalSize) {
        this.terminalSize = requireNonNull(terminalSize, "terminalSize == null");
    }

    /**
     * @param interactive If the console should be interactive.
     */
    public void setInteractive(boolean interactive) {
        this.interactive = interactive;
    }

    /**
     * @param dumpErrorOnFailure If standard ERR should be dumped on failures.
     */
    public void setDumpErrorOnFailure(boolean dumpErrorOnFailure) {
        this.dumpErrorOnFailure = dumpErrorOnFailure;
    }

    /**
     * @param forkError If standard ERR should be forked to original output.
     */
    public void setForkError(boolean forkError) {
        this.forkError = forkError;
    }

    /**
     * @param dumpOutputOnFailure If standard OUT should be dumped on failures.
     */
    public void setDumpOutputOnFailure(boolean dumpOutputOnFailure) {
        this.dumpOutputOnFailure = dumpOutputOnFailure;
    }

    /**
     * @param forkOutput If standard OUT should be forked to original output.
     */
    public void setForkOutput(boolean forkOutput) {
        this.forkOutput = forkOutput;
    }

    /**
     * Trigger start of test. Clears IO streams and overrides native IO.
     */
    public void doBeforeEach() {
        setUpStreams();

        System.setIn(in);
        System.setErr(err);
        System.setOut(out);
    }

    /**
     * Trigger test failure. Handles eventual printing on test failure, if there
     * is anything to print.
     *
     * @param displayName Display name of test.
     */
    public void onTestFailed(String displayName) {
        if (dumpOutputOnFailure && outStream.size() > 0) {
            out.flush();
            originalErr.println(
                    Color.BOLD + " <<< --- stdout : " + displayName + " --- >>>" + Color.CLEAR);
            originalErr.print(getOutputInternal());
            originalErr.println(Color.BOLD + " <<< --- stdout : END --- >>>" + Color.CLEAR);
            if (dumpErrorOnFailure && errStream.size() > 0) {
                originalErr.println();
            }
        }
        if (dumpErrorOnFailure && errStream.size() > 0) {
            originalErr.println(
                    Color.BOLD + " <<< --- stderr : " + displayName + " --- >>>" + Color.CLEAR);
            originalErr.print(Color.RED + getErrorInternal() + Color.CLEAR);
            originalErr.println(Color.BOLD + " <<< --- stderr : END --- >>>" + Color.CLEAR);
        }
    }

    /**
     * Trigger end of test. Set native IO streams back to original.
     */
    public void doAfterEach() {
        System.setErr(originalErr);
        System.setOut(originalOut);
        System.setIn(originalIn);
        try {
            tty.getAndUpdateMode(TTYMode.COOKED);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // ---- Private ----

    /**
     * @return Get the normal output.
     */
    private String getOutputInternal() {
        return outStream.toString(UTF_8);
    }

    /**
     * @return Get the error output.
     */
    private String getErrorInternal() {
        return errStream.toString(UTF_8);
    }

    /**
     * Set input to return the given bytes.
     *
     * @param in The bytes for input.
     */
    private void setInputInternal(byte[] in) {
        inStream = new ByteArrayInputStream(in);
    }

    /**
     * Set input with dynamic content.
     *
     * @param in The input values.
     */
    private void setInputInternal(Object... in) {
        assert in.length > 0 : "Require at least one input item";
        setInputInternal(CharUtil.inputBytes(in));
    }

    /**
     * Set input to return the given bytes.
     *
     * @return An output stream that when written to will
     */
    private OutputStream createInputSourceInternal() throws IOException {
        var out = new PipedOutputStream();
        try {
            inStream.close();
        } catch (IOException ignore) {
        } finally {
            inStream = new PipedInputStream(out);
        }
        return out;
    }

    private void setUpStreams() {
        outStream = new ByteArrayOutputStream();
        errStream = new ByteArrayOutputStream();
        inStream = new ByteArrayInputStream(new byte[0]);
    }

    private class WrappedOutputStream
            extends OutputStream {
        @Override
        public void write(int i) {
            if (forkOutput) {
                originalOut.write(i);
            }
            outStream.write(i);
        }

        @Override
        public void write(byte[] bytes, int off, int len) {
            if (forkOutput) {
                originalOut.write(bytes, off, len);
            }
            outStream.write(bytes, off, len);
        }

        @Override
        public void flush() {
            originalOut.flush();
        }
    }

    private class WrappedErrorStream
            extends OutputStream {
        @Override
        public void write(int i) {
            if (forkError) {
                originalErr.write(i);
            }
            errStream.write(i);
        }

        @Override
        public void write(byte[] bytes, int off, int len) {
            if (forkError) {
                originalErr.write(bytes, off, len);
            }
            errStream.write(bytes, off, len);
        }

        @Override
        public void flush() {
            originalErr.flush();
        }
    }

    private class WrappedInputStream
            extends InputStream {
        @Override
        public int read() throws IOException {
            return inStream.read();
        }

        @Override
        public int read(byte[] bytes) throws IOException {
            return inStream.read(bytes);
        }

        @Override
        public int read(byte[] bytes, int i, int i1) throws IOException {
            return inStream.read(bytes, i, i1);
        }

        @Override
        public long skip(long l) throws IOException {
            return inStream.skip(l);
        }

        @Override
        public void close() {
            try {
                inStream.close();
            } catch (IOException ignore) {
            } finally {
                inStream = new ByteArrayInputStream(new byte[0]);
            }
        }

        @Override
        public int available() throws IOException {
            return inStream.available();
        }

        @Override
        public boolean markSupported() {
            return inStream.markSupported();
        }

        @Override
        public synchronized void mark(int readlimit) {
            inStream.mark(readlimit);
        }

        @Override
        public synchronized void reset() throws IOException {
            inStream.reset();
        }
    }

    private class ConsoleImpl
            implements Console {
        @Override
        public void reset() {
            setUpStreams();
        }

        @Override
        public String error() {
            return getErrorInternal();
        }

        @Override
        public String output() {
            return getOutputInternal();
        }

        @Override
        public void setInput(Object... in) {
            setInputInternal(in);
        }

        @Override
        public void setInput(byte[] in) {
            setInputInternal(in);
        }

        @Override
        public OutputStream createInputSource() {
            try {
                return createInputSourceInternal();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        @Override
        public TTY tty() {
            return tty;
        }
    }

    private class TTYImpl
            extends TTY {
        @Override
        public TTYMode getAndUpdateMode(TTYMode mode) {
            return currentMode.getAndSet(mode);
        }

        @Override
        public TTYSize getTerminalSize() {
            if (!isInteractive()) {
                throw new UncheckedIOException(new IOException("Non-interactive test-console"));
            }
            return terminalSize;
        }

        @Override
        public boolean isInteractive() {
            return interactive;
        }
    }

    private final AtomicReference<TTYMode> currentMode = new AtomicReference<>(TTYMode.COOKED);

    private final TTY     tty;
    private final Console console;

    private final InputStream in;
    private final PrintStream out;
    private final PrintStream err;
    private final PrintStream originalOut;
    private final PrintStream originalErr;
    private final InputStream originalIn;

    private ByteArrayOutputStream outStream = null;
    private ByteArrayOutputStream errStream = null;
    private InputStream           inStream  = null;

    private TTYSize terminalSize        = DEFAULT_TERMINAL_SIZE;
    private boolean interactive         = true;
    private boolean dumpOutputOnFailure = false;
    private boolean dumpErrorOnFailure  = false;
    private boolean forkOutput          = true;
    private boolean forkError           = true;
}