TTY.java

/*
 * Copyright (c) 2016, Stein Eldar Johnsen
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.io.tty;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Clock;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

/**
 * A terminal controller helper.
 * <p>
 * Switches the terminal between {@link TTYMode#COOKED} and raw mode (any other
 * mode), and reads the terminal size. It does this by running {@code stty}
 * through {@code /bin/sh} with {@code /dev/tty} as its input. The terminal size
 * is cached for about half a second between {@code stty size} invocations.
 * <p>
 * The current mode is kept in a static field, so it is shared by all instances
 * in the JVM. Mode changes are serialized JVM-wide, and size refreshes are
 * serialized per instance.
 */
public class TTY {
    private final Runtime runtime;
    private final Clock   clock;

    // Clock millis up to (and including) which the cached size is considered fresh.
    private final AtomicLong                   termTime;
    // Last successfully read terminal size, or null if none has been read yet.
    private final AtomicReference<TTYSize>     termSize;
    // Most recent failure from reading the terminal size, if any.
    private final AtomicReference<IOException> termException;
    // Guards check-and-refresh of the cached size, so that 'stty size' is
    // never run concurrently (or redundantly) for the same instance.
    private final Object                       termLock;

    /**
     * Create default instance.
     */
    public TTY() {
        this(Runtime.getRuntime(), Clock.systemUTC());
    }

    /**
     * Create instance with runtime and clock. Visible for testing.
     *
     * @param runtime Runtime used to run system commands.
     * @param clock   System clock. Used for timing the caching of TTY size.
     */
    public TTY(Runtime runtime, Clock clock) {
        this.runtime = requireNonNull(runtime, "runtime == null");
        this.clock = requireNonNull(clock, "clock == null");

        this.termException = new AtomicReference<>();
        this.termSize = new AtomicReference<>();
        this.termTime = new AtomicLong(NO_CACHED_SIZE);
        this.termLock = new Object();
    }

    /**
     * Set terminal mode. If the requested mode equals the currently tracked
     * mode, no command is run.
     *
     * @param mode The mode to set.
     * @return The previous mode.
     * @throws IOException If setting mode failed. The tracked mode is then left unchanged.
     */
    public TTYMode getAndUpdateMode(TTYMode mode) throws IOException {
        return requireNonNull(
                getAndUpdateSttyMode(runtime, requireNonNull(mode, "mode == null")),
                "returning null");
    }

    /**
     * Get the currently active TTY mode. This is the mode last set through
     * {@link #getAndUpdateMode(TTYMode)} by any instance, or
     * {@link TTYMode#COOKED} if none has been set. The terminal itself is not
     * queried.
     *
     * @return The active mode.
     */
    public TTYMode getCurrentMode() {
        return requireNonNull(singletonCurrentMode.get());
    }

    /**
     * Get the terminal size. The value may be cached for up to about half a
     * second. Once a size has been read successfully, a later failing check
     * keeps returning the last known size.
     *
     * @return the terminal size.
     * @throws UncheckedIOException If getting the terminal size failed and no size was known.
     */
    public TTYSize getTerminalSize() {
        return Optional.ofNullable(updateAndGetTerminalSize())
                       .orElseThrow(() -> new UncheckedIOException(lastTermException()));
    }

    /**
     * Clear the cached terminal size regardless of when it was last checked.
     * The next size lookup will run {@code stty size} again.
     *
     * @return The TTY.
     */
    public TTY clearCachedTerminalSize() {
        termTime.set(NO_CACHED_SIZE);
        return this;
    }

    /**
     * Check whether a terminal size could be determined. Once a size has been
     * read successfully, this stays true even if later checks fail, since the
     * last known size is kept.
     *
     * @return True if this is an interactive TTY terminal.
     */
    public boolean isInteractive() {
        return updateAndGetTerminalSize() != null;
    }

    // -- Object

    @Override
    public String toString() {
        StringBuilder toString = new StringBuilder("STTY{mode=")
                .append(getCurrentMode());
        TTYSize size = updateAndGetTerminalSize();
        if (size != null) {
            toString.append(", size=")
                    .append(size);
        }
        return toString.append("}").toString();
    }

    // ---------- Private ---------

    /**
     * Return the cached terminal size, refreshing it via {@code stty size} if
     * the cache has expired. On refresh failure, the exception is recorded and
     * the previously known size (possibly null) is returned.
     */
    private TTYSize updateAndGetTerminalSize() {
        // A plain lock instead of AtomicReference.updateAndGet: the refresh
        // spawns a process and writes other fields, and updateAndGet requires
        // a side-effect-free function because it may re-apply it on contention.
        synchronized (termLock) {
            TTYSize current = termSize.get();
            if (termTime.get() >= clock.millis()) {
                return current;
            }
            try {
                TTYSize fresh = runSttyTerminalSize(runtime);
                termSize.set(fresh);
                return fresh;
            } catch (IOException e) {
                termException.set(e);
                return current;
            } finally {
                termTime.set(clock.millis() + SIZE_CACHE_MILLIS);
            }
        }
    }

    /**
     * The recorded terminal size failure, or a generic one if none was
     * recorded (UncheckedIOException rejects a null cause).
     */
    private IOException lastTermException() {
        IOException recorded = termException.get();
        return recorded != null ? recorded : new IOException("Unable to get TTY size");
    }

    // ------ Private Static ------

    // Sentinel for "no cached size": below any clock reading, so the next size
    // lookup always refreshes (a 0L sentinel fails to do so at the epoch).
    private static final long NO_CACHED_SIZE = Long.MIN_VALUE;

    // A size read at time T is reused while the clock reads <= T + this value.
    private static final long SIZE_CACHE_MILLIS = 499L;

    // Default output mode is COOKED.
    private static final AtomicReference<TTYMode> singletonCurrentMode = new AtomicReference<>(TTYMode.COOKED);

    // Serializes mode changes JVM-wide, so the tracked mode and the stty
    // commands actually run cannot get out of order.
    private static final Object singletonModeLock = new Object();

    /**
     * Run stty to switch to {@code mode} if it differs from the tracked mode,
     * then record it.
     *
     * @return The previously tracked mode.
     * @throws IOException If running stty failed. The tracked mode is then unchanged.
     */
    private static TTYMode getAndUpdateSttyMode(Runtime runtime, TTYMode mode) throws IOException {
        synchronized (singletonModeLock) {
            TTYMode old = singletonCurrentMode.get();
            if (mode != old) {
                runSetSttyMode(runtime, mode);
                singletonCurrentMode.set(mode);
            }
            return old;
        }
    }

    /**
     * Run the stty command for the given mode: cooked ({@code -raw echo}) for
     * {@link TTYMode#COOKED}, raw ({@code raw -echo}) for anything else.
     */
    private static void runSetSttyMode(Runtime runtime, TTYMode mode) throws IOException {
        if (mode == TTYMode.COOKED) {
            runInShell(runtime, "stty -raw echo </dev/tty");
        } else {
            runInShell(runtime, "stty raw -echo </dev/tty");
        }
    }

    /**
     * Run {@code stty size} and parse its "rows cols" output.
     *
     * @throws IOException If the command failed or its output could not be parsed.
     */
    private static TTYSize runSttyTerminalSize(Runtime runtime) throws IOException {
        Process p;
        try {
            p = runInShell(runtime, "stty size </dev/tty");
        } catch (IOException e) {
            throw new IOException("Unable to get TTY size: " + e.getMessage(), e);
        }

        String out = "";
        try (BufferedInputStream reader = new BufferedInputStream(p.getInputStream())) {
            out = new String(reader.readAllBytes(), UTF_8).strip();
            if (out.isEmpty()) {
                throw new IOException("No 'stty size' output.");
            }
            // Split on any whitespace run, not just a single space.
            String[] parts = out.split("\\s+");
            if (parts.length != 2) {
                throw new IOException("Unknown 'stty size' output: '" + out + "'");
            }
            int rows = Integer.parseInt(parts[0]);
            int cols = Integer.parseInt(parts[1]);
            return new TTYSize(rows, cols);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid stty size line '" + out + "'", e);
        }
    }

    /**
     * Run {@code cmd} with {@code /bin/sh -c} and wait for it to finish.
     *
     * @return The finished process, whose stdout is still unread.
     * @throws IOException If the command could not be started, was
     *                     interrupted, wrote anything to stderr, or exited
     *                     with a non-zero code.
     */
    private static Process runInShell(Runtime runtime, String cmd) throws IOException {
        Process p = runtime.exec(new String[]{"/bin/sh", "-c", cmd});

        int exitCode;
        try {
            exitCode = p.waitFor();
        } catch (InterruptedException ie) {
            // Don't leave the child running, and keep the interrupt visible to callers.
            p.destroy();
            Thread.currentThread().interrupt();
            throw new IOException(ie.getMessage(), ie);
        }

        try (BufferedInputStream reader = new BufferedInputStream(p.getErrorStream())) {
            String err = new String(reader.readAllBytes(), UTF_8).strip();
            if (!err.isEmpty()) {
                throw new IOException(err);
            }
        }
        // A failing command with empty stderr would otherwise pass silently.
        if (exitCode != 0) {
            throw new IOException("'" + cmd + "' exited with code " + exitCode);
        }
        return p;
    }
}