ImmediateScheduledExecutor.java

package net.morimekta.testing.concurrent;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
 * A scheduled executor that uses a fake clock as back-bone to the executor.
 * To trigger the executions, call {@link FakeClock#tick(long)} on the fake
 * clock. The executions will be handled in the thread calling the tick itself.
 * <p>
 * The executor does no synchronization of its own (it keeps its tasks in a
 * plain {@link ArrayList} and uses non-volatile state), so it should be driven
 * from a single thread.
 */
public class ImmediateScheduledExecutor
        implements ScheduledExecutorService, FakeClock.TimeListener {

    /**
     * A fake task / future instance. The task is run by the owning executor
     * once the fake clock reaches its trigger time.
     *
     * @param <V> The future return type.
     */
    public final class FakeTask<V> implements ScheduledFuture<V>, Runnable {
        private final Instant     triggersAt;
        private final Callable<V> callable;
        private final int         id;
        private       V           result;
        private       Throwable   except;

        private boolean cancelled;
        private boolean done;

        private FakeTask(Instant triggersAt,
                         Callable<V> callable) {
            this.triggersAt = triggersAt;
            this.callable = requireNonNull(callable, "callable == null");
            this.id = nextId.incrementAndGet();

            this.cancelled = false;
            this.done = false;
            this.result = null;
            this.except = null;
        }

        /**
         * Equality is based on the immutable identity of the task only. The
         * mutable {@code cancelled} / {@code done} state is deliberately left
         * out so the hash code stays stable over the task's lifetime.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FakeTask<?> fakeTask = (FakeTask<?>) o;
            return id == fakeTask.id &&
                   triggersAt.equals(fakeTask.triggersAt) &&
                   callable.equals(fakeTask.callable);
        }

        @Override
        public int hashCode() {
            return Objects.hash(triggersAt, callable, id);
        }

        /**
         * Orders fake tasks by trigger time, then by creation order (id).
         * Other {@link Delayed} instances are compared by remaining delay.
         */
        @Override
        public int compareTo(Delayed delayed) {
            requireNonNull(delayed, "delayed == null");
            if (delayed == this) {
                return 0;
            }
            if (delayed instanceof FakeTask) {
                int c = triggersAt.compareTo(((FakeTask<?>) delayed).triggersAt);
                if (0 != c) {
                    return c;
                }
                return Integer.compare(id, ((FakeTask<?>) delayed).id);
            }
            long diff = getDelay(MILLISECONDS) - delayed.getDelay(MILLISECONDS);
            return (diff < 0) ? -1 : (diff > 0) ? 1 : 0;
        }

        /**
         * @param timeUnit Unit for the returned delay.
         * @return Remaining delay (in millisecond precision) until the task
         *         triggers according to the fake clock, or 0 if already due.
         */
        @Override
        public long getDelay(TimeUnit timeUnit) {
            requireNonNull(timeUnit, "timeUnit == null");
            Instant now = clock.instant();
            if (now.isBefore(triggersAt)) {
                return timeUnit.convert(Duration.between(now, triggersAt).toMillis(), MILLISECONDS);
            }
            return 0L;
        }

        /**
         * Cancel the task if it has not yet completed, removing it from the
         * executor's schedule.
         *
         * @param mayInterruptIfRunning Ignored.
         * @return True if the task is cancelled after this call.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (!done) {
                cancelled = true;
                done = true;
                scheduledTasks.remove(this);
            }
            return cancelled;
        }

        @Override
        public boolean isCancelled() {
            return cancelled;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        /**
         * Advance the fake clock until this task triggers, then report its
         * outcome.
         *
         * @return The task result.
         * @throws InterruptedException If the task was cancelled.
         * @throws ExecutionException If the task threw an exception.
         */
        @Override
        public V get() throws InterruptedException, ExecutionException {
            if (!done) {
                advanceOrFlush(getDelay(MILLISECONDS));
            }
            return report();
        }

        /**
         * Advance the fake clock by at most the given timeout, then report the
         * task outcome.
         *
         * @param l Maximum time to advance the clock.
         * @param timeUnit Unit of {@code l}.
         * @return The task result.
         * @throws InterruptedException If the task was cancelled.
         * @throws ExecutionException If the task threw an exception.
         * @throws TimeoutException If the task did not complete within the timeout.
         */
        @Override
        public V get(long l, TimeUnit timeUnit) throws InterruptedException, ExecutionException, TimeoutException {
            requireNonNull(timeUnit, "timeUnit == null");
            if (!done) {
                long maxWait = timeUnit.toMillis(l);
                advanceOrFlush(Math.min(maxWait, getDelay(MILLISECONDS)));
                if (!done) {
                    throw new TimeoutException("Timed out after " + maxWait + " millis");
                }
            }
            return report();
        }

        /**
         * Run the task, capturing either its result or the thrown exception,
         * and mark it as done.
         */
        @Override
        public void run() {
            V result = null;
            try {
                result = callable.call();
            } catch (Exception e) {
                this.except = e;
            }
            this.result = result;
            this.done = true;
        }

        /** Tick the clock forward, or just run already-due tasks when there is nothing to wait for. */
        private void advanceOrFlush(long millis) {
            if (millis > 0) {
                clock.tick(millis);
            } else {
                // Same approach as awaitTermination: never tick by zero (or a
                // negative amount), just run whatever is already due.
                listener.newCurrentTime(clock.instant());
            }
        }

        /** Translate the task's final state into the Future return / exception. */
        private V report() throws InterruptedException, ExecutionException {
            if (cancelled) {
                throw new InterruptedException("Task cancelled");
            }
            if (except != null) {
                throw new ExecutionException(except.getMessage(), except);
            }
            return result;
        }
    }

    /**
     * A fake recurring task / future. Each execution is a separate
     * {@link FakeTask}; the currently scheduled one is held in {@code next}.
     */
    public final class FakeRecurringTask implements ScheduledFuture<Void> {
        private final long                         delay;
        private final Runnable                     callable;
        private final AtomicReference<FakeTask<?>> next;
        AtomicLong nextExecution;

        private boolean cancelled;

        private FakeRecurringTask(long delay,
                                  long initialDelay,
                                  TimeUnit timeUnit,
                                  Runnable callable,
                                  AtomicReference<FakeTask<?>> first) {
            // The fake clock works in milliseconds. A sub-millisecond period
            // would truncate to 0 and make the task re-trigger at the same
            // instant forever, so the period is clamped to at least 1 ms.
            this.delay = Math.max(1L, timeUnit.toMillis(delay));
            this.nextExecution = new AtomicLong(clock.millis() + timeUnit.toMillis(initialDelay));
            this.callable = callable;
            this.next = first;
            this.cancelled = false;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FakeRecurringTask that = (FakeRecurringTask) o;
            return delay == that.delay &&
                   cancelled == that.cancelled &&
                   callable.equals(that.callable) &&
                   nextExecution.get() == that.nextExecution.get();
        }

        @Override
        public int hashCode() {
            // Hash the value, not the AtomicLong (which has identity hashCode),
            // to stay consistent with equals().
            return Objects.hash(delay, callable, nextExecution.get(), cancelled);
        }

        /**
         * Orders recurring tasks by next execution time, then by the id of the
         * currently scheduled execution. Other {@link Delayed} instances are
         * compared by remaining delay.
         */
        @Override
        public int compareTo(Delayed delayed) {
            requireNonNull(delayed, "delayed == null");
            if (delayed == this) {
                return 0;
            }
            if (delayed instanceof FakeRecurringTask) {
                FakeRecurringTask task = (FakeRecurringTask) delayed;
                int c = Long.compare(nextExecution.get(), task.nextExecution.get());
                if (c != 0) {
                    return c;
                }
                return Integer.compare(next.get().id, task.next.get().id);
            }
            long diff = getDelay(MILLISECONDS) - delayed.getDelay(MILLISECONDS);
            return (diff < 0) ? -1 : (diff > 0) ? 1 : 0;
        }

        /**
         * Cancel all future executions. A cancel issued from inside the
         * recurring runnable itself also prevents rescheduling.
         *
         * @param mayInterruptIfRunning Ignored.
         * @return True (the task is always cancelled after this call).
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (!cancelled) {
                cancelled = true;
                scheduledTasks.remove(next.get());
            }
            return cancelled;
        }

        @Override
        public boolean isCancelled() {
            return cancelled;
        }

        @Override
        public boolean isDone() {
            return cancelled;
        }

        /**
         * @return Never returns normally.
         * @throws InterruptedException If the task was cancelled.
         * @throws IllegalStateException Otherwise, as recurring tasks never complete.
         */
        @Override
        public Void get() throws InterruptedException {
            if (cancelled) {
                throw new InterruptedException("Task cancelled");
            }
            throw new IllegalStateException("Cannot wait for fake recurring tasks");
        }

        @Override
        public Void get(long l, TimeUnit timeUnit) throws InterruptedException {
            requireNonNull(timeUnit, "timeUnit == null");
            return get();
        }

        /**
         * @param timeUnit Unit for the returned delay.
         * @return Delay until the next execution according to the fake clock.
         *         May be negative if the execution is overdue.
         */
        @Override
        public long getDelay(TimeUnit timeUnit) {
            requireNonNull(timeUnit, "timeUnit == null");
            long now = clock.millis();
            long realDelay = nextExecution.get() - now;
            return timeUnit.convert(realDelay, MILLISECONDS);
        }

        /** Run once, then schedule the next run {@code delay} ms after this one. */
        void runWithDelay() {
            runCallableOnce();
            if (cancelled || isShutdown()) {
                // Cancelled from within the runnable, or executor shut down.
                return;
            }
            nextExecution.set(clock.millis() + delay);
            next.set(schedule(this::runWithDelay, delay, MILLISECONDS));
        }

        /** Run once per elapsed period (catching up if ticks skipped periods), then reschedule. */
        void runWithRate() {
            long now = clock.millis();

            // Run the task for each time it should have been run in the
            // time the last 'tick' should have triggered.
            long nextRun = nextExecution.get();
            while (nextRun <= now && !cancelled) {
                runCallableOnce();
                nextRun = nextExecution.addAndGet(delay);
            }
            if (cancelled || isShutdown()) {
                // Cancelled from within the runnable, or executor shut down.
                return;
            }
            long remaining = nextRun - now;

            next.set(schedule(this::runWithRate, remaining, MILLISECONDS));
        }

        /** Run the recurring runnable, ignoring failures so later executions still happen. */
        private void runCallableOnce() {
            try {
                callable.run();
            } catch (Exception ignored) {
                // Deliberately keep going: a failing run does not stop the recurrence.
            }
        }
    }

    private final FakeClock         clock;
    private final List<FakeTask<?>> scheduledTasks;
    private final AtomicInteger     nextId = new AtomicInteger();
    private final FakeTimeListener  listener;

    private boolean shutdownCalled = false;

    /**
     * Create a fake scheduled executor.
     *
     * @param clock Clock to trigger executions.
     */
    public ImmediateScheduledExecutor(FakeClock clock) {
        this.scheduledTasks = new ArrayList<>();
        this.clock = requireNonNull(clock, "clock == null");
        this.listener = new FakeTimeListener();
        this.clock.addListener(listener);
    }

    // ---- FakeClock.TimeListener

    /**
     * Always throws; time must be propagated through the fake clock.
     *
     * @param newNow The new current time value.
     * @deprecated Use the {@link FakeClock#tick(long)} method to propagate time.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public void newCurrentTime(Instant newNow) {
        throw new UnsupportedOperationException("Not allowed to propagate time by calling this method");
    }

    /**
     * Delegates to the internal clock listener.
     *
     * @param now The current time when asking.
     * @return The delay until the next execution should happen.
     * @deprecated Should not be used.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public Duration getDelay(Instant now) {
        return listener.getDelay(now);
    }

    // ---- ScheduledExecutorService

    @Override
    public FakeTask<?> schedule(Runnable runnable, long l, TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        return this.schedule(() -> {
            runnable.run();
            return null;
        }, l, timeUnit);
    }

    /**
     * Schedule a callable to run once the fake clock has advanced by the
     * given delay.
     *
     * @throws IllegalStateException If the executor is shut down.
     * @throws IllegalArgumentException If the delay is negative.
     */
    @Override
    public <V> FakeTask<V> schedule(Callable<V> callable, long delay, TimeUnit timeUnit) {
        requireNonNull(callable, "callable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (delay < 0) {
            throw new IllegalArgumentException("Unable to schedule tasks in the past");
        }

        var now = clock.instant();
        var triggersAt = now.plusMillis(timeUnit.toMillis(delay));
        FakeTask<V> task = new FakeTask<>(triggersAt, callable);
        scheduledTasks.add(task);
        // Keep the list ordered so the head is always the next task to run.
        Collections.sort(scheduledTasks);
        return task;
    }

    @Override
    public FakeRecurringTask scheduleAtFixedRate(Runnable runnable, long initialDelay, long period, TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (initialDelay < 0 || period < 1) {
            throw new IllegalArgumentException("Invalid initial delay or period: " + initialDelay + " / " + period);
        }
        AtomicReference<FakeTask<?>> first = new AtomicReference<>();
        FakeRecurringTask recurring = new FakeRecurringTask(period, initialDelay, timeUnit, runnable, first);
        first.set(schedule(recurring::runWithRate, initialDelay, timeUnit));
        return recurring;
    }

    @Override
    public FakeRecurringTask scheduleWithFixedDelay(Runnable runnable,
                                                    long initialDelay,
                                                    long delay,
                                                    TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (initialDelay < 0 || delay < 1) {
            throw new IllegalArgumentException("Invalid initial delay or intermediate delay: " + initialDelay + " / " + delay);
        }
        AtomicReference<FakeTask<?>> first = new AtomicReference<>();
        FakeRecurringTask recurring = new FakeRecurringTask(delay, initialDelay, timeUnit, runnable, first);
        first.set(schedule(recurring::runWithDelay, initialDelay, timeUnit));
        return recurring;
    }

    // ---- ExecutorService

    @Override
    public void shutdown() {
        this.shutdownCalled = true;
    }

    /**
     * Shut down and drop all scheduled tasks.
     *
     * @return The dropped tasks that were neither cancelled nor done.
     */
    @Override
    public List<Runnable> shutdownNow() {
        this.shutdownCalled = true;
        List<Runnable> result = scheduledTasks.stream()
                                              .filter(t -> !t.isCancelled())
                                              .filter(t -> !t.isDone())
                                              .collect(Collectors.toList());
        scheduledTasks.clear();
        return result;
    }

    @Override
    public boolean isShutdown() {
        return shutdownCalled;
    }

    @Override
    public boolean isTerminated() {
        return shutdownCalled && scheduledTasks.isEmpty();
    }

    /**
     * Advance the fake clock task by task, running every task due within the
     * given timeout.
     *
     * @return True if no scheduled tasks remain.
     * @throws IllegalStateException If shutdown has not been called.
     */
    @Override
    public boolean awaitTermination(long l, TimeUnit timeUnit) {
        requireNonNull(timeUnit, "timeUnit == null");
        if (!shutdownCalled) {
            throw new IllegalStateException("Shutdown not triggered");
        }
        var now = clock.instant();
        var until = now.plusMillis(timeUnit.toMillis(l));

        FakeTask<?> next;
        while ((next = getNextTask()) != null) {
            now = clock.instant();
            var delay = next.getDelay(MILLISECONDS);
            if (delay > 0 && (now.plusMillis(delay)).isAfter(until)) {
                break;
            }

            if (delay > 0) {
                clock.tick(delay);
            } else {
                listener.newCurrentTime(now);
            }
        }

        return scheduledTasks.isEmpty();
    }

    @Override
    public <T> FakeTask<T> submit(Callable<T> callable) {
        return schedule(callable, 0, MILLISECONDS);
    }

    @Override
    public <T> FakeTask<T> submit(Runnable runnable, T t) {
        requireNonNull(runnable, "runnable == null");
        return schedule(() -> {
            runnable.run();
            return t;
        }, 0, MILLISECONDS);
    }

    @Override
    public FakeTask<?> submit(Runnable runnable) {
        return schedule(runnable, 0, MILLISECONDS);
    }

    /**
     * Submit all callables with zero delay. The returned futures complete when
     * the executor next runs due tasks, not before this method returns.
     *
     * @throws IllegalStateException If the executor is shut down.
     * @throws IllegalArgumentException If the collection is empty.
     */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> collection) {
        requireNonNull(collection, "collection == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Empty invoke collection");
        }

        List<Future<T>> results = new ArrayList<>();
        for (Callable<T> c : collection) {
            results.add(submit(c));
        }
        return results;
    }

    /** Same as {@link #invokeAll(Collection)}; the timeout is only validated. */
    @Override
    public <T> List<Future<T>> invokeAll(
            Collection<? extends Callable<T>> collection,
            long l,
            TimeUnit timeUnit) {
        requireNonNull(timeUnit, "timeUnit == null");
        if (l < 0) {
            throw new IllegalArgumentException("Negative timeout: " + l);
        }
        return invokeAll(collection);
    }

    /**
     * Call the callables directly in the calling thread, in iteration order,
     * returning the first successful result.
     *
     * @throws ExecutionException If all callables fail; later failures are suppressed on it.
     */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> collection) throws ExecutionException {
        requireNonNull(collection, "collection == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Empty invoke collection");
        }

        ExecutionException ex = null;
        for (Callable<T> c : collection) {
            try {
                return c.call();
            } catch (Exception e) {
                if (ex == null) {
                    ex = new ExecutionException("All " + collection.size() + " tasks failed, first exception", e);
                } else {
                    ex.addSuppressed(e);
                }
            }
        }
        throw ex;
    }

    /** Same as {@link #invokeAny(Collection)}; the timeout is only validated. */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> collection,
                           long l,
                           TimeUnit timeUnit) throws ExecutionException {
        requireNonNull(timeUnit, "timeUnit == null");
        if (l < 0) {
            throw new IllegalArgumentException("Negative timeout: " + l);
        }
        return invokeAny(collection);
    }

    @Override
    public void execute(Runnable runnable) {
        schedule(runnable, 0, MILLISECONDS);
    }

    /** @return The head of the (sorted) schedule, or null if empty. */
    private FakeTask<?> getNextTask() {
        if (scheduledTasks.isEmpty()) {
            return null;
        }
        return scheduledTasks.get(0);
    }

    /** The listener registered on the fake clock; runs tasks as time advances. */
    private final class FakeTimeListener implements FakeClock.TimeListener {
        /**
         * @param now The current time.
         * @return Time until the earliest not-done task, {@link FakeClock#MIN_DELAY}
         *         if one is already due, or {@link Duration#ZERO} if none are scheduled.
         */
        @Override
        public Duration getDelay(Instant now) {
            var next = scheduledTasks
                    .stream()
                    .filter(it -> !it.isDone())
                    .map(it -> it.triggersAt)
                    .min(Instant::compareTo)
                    .orElse(Instant.MAX);
            if (next.equals(Instant.MAX)) {
                return Duration.ZERO;
            }
            if (!now.isBefore(next)) {
                return FakeClock.MIN_DELAY;
            }
            return Duration.between(now, next);
        }

        /**
         * Drop done tasks from the head of the schedule and run every task
         * due at or before {@code now}, including tasks scheduled while running.
         */
        @Override
        public void newCurrentTime(Instant now) {
            FakeTask<?> next;
            while ((next = getNextTask()) != null) {
                if (next.isDone()) {
                    scheduledTasks.remove(next);
                } else if (!now.isBefore(next.triggersAt)) {
                    // Remove before running so a reschedule from within the
                    // task does not see (or remove) itself.
                    scheduledTasks.remove(next);
                    next.run();
                } else {
                    break;
                }
            }
        }
    }
}