FakeScheduledExecutor.java

package net.morimekta.testing.concurrent;

import net.morimekta.testing.concurrent.internal.CompletableScheduledFuture;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newFixedThreadPool;

/**
 * A scheduled executor that uses a fake clock as back-bone to the executor. To trigger scheduled tasks in the
 * executions, call {@link FakeClock#tick(long)} on the fake clock. Execution will be handled in a separate
 * thread using an internal executor service.
 */
public class FakeScheduledExecutor
        extends AbstractExecutorService
        implements ScheduledExecutorService, FakeClock.TimeListener {
    /**
     * Real-time pause (milliseconds) after handing a zero-delay task to the worker pool.
     * NOTE(review): presumably gives the worker thread a chance to run the task before the
     * scheduling call returns; this is best-effort and not a completion guarantee.
     */
    private static final long IMMEDIATE_SETTLE_MS = 3L;

    /**
     * Real-time pause (milliseconds) after a clock update dispatched at least one task.
     * NOTE(review): presumably the same best-effort purpose as {@link #IMMEDIATE_SETTLE_MS}.
     */
    private static final long TICK_SETTLE_MS = 10L;

    /**
     * Create a fake scheduled executor backed by a single worker thread.
     *
     * @param clock Clock to trigger executions.
     */
    public FakeScheduledExecutor(FakeClock clock) {
        this(clock, 1);
    }

    /**
     * Create a fake scheduled executor. The executor registers an internal
     * listener on the clock in order to be notified about time changes.
     *
     * @param clock      Clock to trigger executions.
     * @param maxThreads Max number of threads.
     */
    public FakeScheduledExecutor(FakeClock clock, int maxThreads) {
        this.scheduledTasks = new ArrayList<>();
        this.executor = newFixedThreadPool(maxThreads);
        this.clock = requireNonNull(clock, "clock == null");
        this.listener = new FakeTimeListener();
        this.clock.addListener(listener);
        this.nextId = new AtomicInteger();
    }

    // ---- FakeClock.TimeListener

    /**
     * Always throws; time updates are received through the internal listener
     * registered on the clock, not through this method.
     *
     * @param newNow The new current time value.
     * @throws UnsupportedOperationException Always.
     * @deprecated Use the {@link FakeClock#tick(long)} method to propagate time.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public void newCurrentTime(Instant newNow) {
        throw new UnsupportedOperationException("Not allowed to propagate time by calling this method");
    }

    /**
     * Delegates to the internal listener.
     *
     * @param now The current time when asking.
     * @return The delay until the next execution should happen.
     * @deprecated Should not be used.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public Duration getDelay(Instant now) {
        return listener.getDelay(now);
    }

    // ---- ScheduledExecutorService

    /**
     * Schedule a one-shot runnable. Implemented by wrapping the runnable in a
     * callable that returns {@code null}.
     *
     * @param runnable The task to run.
     * @param l        The delay before running, in {@code timeUnit} units.
     * @param timeUnit The unit of the delay.
     * @return Future for the scheduled task.
     */
    @Override
    public ScheduledFuture<?> schedule(Runnable runnable, long l, TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        return this.schedule(() -> {
            runnable.run();
            return null;
        }, l, timeUnit);
    }

    /**
     * Schedule a one-shot callable. With a zero delay the task is handed to the
     * worker pool immediately; otherwise it is queued until the clock listener
     * sees a time at or after the target instant.
     *
     * @param callable The task to call.
     * @param delay    The delay before calling, in {@code timeUnit} units. Must not be negative.
     * @param timeUnit The unit of the delay.
     * @param <V>      The result type.
     * @return Future completed with the callable result, or exceptionally with the exception it threw.
     * @throws IllegalArgumentException If the delay is negative.
     * @throws IllegalStateException    If the executor is shut down.
     */
    @Override
    public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit timeUnit) {
        requireNonNull(callable, "callable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (delay < 0) {
            throw new IllegalArgumentException("Invalid delay " + delay);
        }
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }

        var now = clock.instant();
        var nextExecution = new AtomicReference<>(now.plusMillis(timeUnit.toMillis(delay)));
        var future = new CompletableScheduledFuture<V>(clock, nextExecution);
        // One-shot tasks share a single reference for 'nextExecution' and 'nextRun'.
        var task = new ScheduledTask(0, 0, () -> {
            try {
                future.complete(callable.call());
            } catch (Exception e) {
                future.completeExceptionally(e);
            }
        }, future, nextExecution, nextExecution);
        if (delay == 0) {
            // Not queued: run right away on the worker pool.
            execute(task::runOnce);
            settle(IMMEDIATE_SETTLE_MS);
        } else {
            removeOnCancel(future, task);
            synchronized (scheduledTasks) {
                scheduledTasks.add(task);
            }
        }
        return future;
    }

    /**
     * Schedule a runnable to run at a fixed rate, measured from the initial run time.
     *
     * @param runnable     The task to run.
     * @param initialDelay Delay before the first run, in {@code timeUnit} units. Must not be negative.
     * @param rate         Period between runs, in {@code timeUnit} units. Must be at least 1.
     * @param timeUnit     The unit of the delay and rate.
     * @return Future for the repeating task.
     * @throws IllegalArgumentException If the initial delay or rate is invalid.
     * @throws IllegalStateException    If the executor is shut down.
     */
    @Override
    public ScheduledFuture<?> scheduleAtFixedRate(Runnable runnable, long initialDelay, long rate, TimeUnit timeUnit) {
        // The '1' is a placeholder for the unused 'delay' argument so it passes validation.
        validateArguments(runnable, initialDelay, 1, rate, timeUnit);
        var now = clock.instant();
        // 'nextExecution' stays at Instant.MAX while a run is dispatched, so the
        // clock listener does not trigger the task a second time.
        var nextExecution = new AtomicReference<>(Instant.MAX);
        var nextRun = new AtomicReference<>(now.plusMillis(timeUnit.toMillis(initialDelay)));
        var future = new CompletableScheduledFuture<Void>(clock, nextExecution);
        var task = new ScheduledTask(0, timeUnit.toMillis(rate), runnable, future, nextExecution, nextRun);
        removeOnCancel(future, task);
        synchronized (scheduledTasks) {
            scheduledTasks.add(task);
        }
        if (initialDelay == 0) {
            execute(() -> task.runWithRate(now));
            settle(IMMEDIATE_SETTLE_MS);
        } else {
            nextExecution.set(nextRun.get());
        }
        return future;
    }

    /**
     * Schedule a runnable to run repeatedly with a fixed delay between the end
     * of one run and the start of the next.
     *
     * @param runnable     The task to run.
     * @param initialDelay Delay before the first run, in {@code timeUnit} units. Must not be negative.
     * @param delay        Delay between runs, in {@code timeUnit} units. Must be at least 1.
     * @param timeUnit     The unit of the delays.
     * @return Future for the repeating task.
     * @throws IllegalArgumentException If the initial delay or delay is invalid.
     * @throws IllegalStateException    If the executor is shut down.
     */
    @Override
    public ScheduledFuture<?> scheduleWithFixedDelay(Runnable runnable,
                                                     long initialDelay,
                                                     long delay,
                                                     TimeUnit timeUnit) {
        // The '1' is a placeholder for the unused 'rate' argument so it passes validation.
        validateArguments(runnable, initialDelay, delay, 1, timeUnit);
        var now = clock.instant();
        var nextExecution = new AtomicReference<>(Instant.MAX);
        var nextRun = new AtomicReference<>(now.plusMillis(timeUnit.toMillis(initialDelay)));
        var future = new CompletableScheduledFuture<Void>(clock, nextExecution);
        var task = new ScheduledTask(timeUnit.toMillis(delay), 0, runnable, future, nextExecution, nextRun);
        removeOnCancel(future, task);
        synchronized (scheduledTasks) {
            scheduledTasks.add(task);
        }
        if (initialDelay == 0) {
            execute(task::runWithDelay);
            settle(IMMEDIATE_SETTLE_MS);
        } else {
            nextExecution.set(nextRun.get());
        }
        return future;
    }

    // ---- ExecutorService

    /**
     * Shut down the worker pool, then complete the futures of all queued tasks
     * with {@code null} and drop them.
     */
    @Override
    public void shutdown() {
        executor.shutdown();
        completePendingTasks();
    }

    /**
     * Shut down the worker pool immediately, then complete the futures of all
     * queued tasks with {@code null} and drop them.
     *
     * @return The runnables returned by the worker pool's {@code shutdownNow()}.
     */
    @Override
    public List<Runnable> shutdownNow() {
        var ret = executor.shutdownNow();
        completePendingTasks();
        return ret;
    }

    @Override
    public boolean isShutdown() {
        return executor.isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return executor.isTerminated();
    }

    @Override
    public boolean awaitTermination(long l, TimeUnit timeUnit) throws InterruptedException {
        return executor.awaitTermination(l, timeUnit);
    }

    // ---- Executor

    /**
     * Run the runnable on the internal worker pool.
     *
     * @param runnable The task to run.
     */
    @Override
    public void execute(Runnable runnable) {
        executor.execute(runnable);
    }

    // ---- Private

    /**
     * Sleep for the given real time. If interrupted, the interrupt flag is
     * restored before failing, so code further up the stack can still observe it.
     *
     * @param millis Real-time milliseconds to sleep.
     * @throws AssertionError If the thread is interrupted while sleeping.
     */
    private static void settle(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted", e);
        }
    }

    /**
     * Complete the future of every queued task with {@code null} and clear the queue.
     */
    private void completePendingTasks() {
        synchronized (scheduledTasks) {
            scheduledTasks.forEach(task -> task.future.complete(null));
            scheduledTasks.clear();
        }
    }

    /**
     * Shared argument validation for the repeating schedule methods.
     *
     * @throws NullPointerException     If the runnable or time unit is null.
     * @throws IllegalArgumentException If initial delay is negative, or delay or rate is below 1.
     * @throws IllegalStateException    If the executor is shut down.
     */
    private void validateArguments(Runnable runnable,
                                   long initialDelay,
                                   long delay,
                                   long rate,
                                   TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (initialDelay < 0) {
            throw new IllegalArgumentException("Invalid initial delay " + initialDelay);
        }
        if (delay < 1) {
            throw new IllegalArgumentException("Invalid delay " + delay);
        }
        if (rate < 1) {
            throw new IllegalArgumentException("Invalid rate " + rate);
        }
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
    }

    /**
     * Register a completion hook that removes the task from the queue when the
     * future ends up cancelled. The hook itself runs on the worker pool.
     */
    private void removeOnCancel(CompletableFuture<?> future, ScheduledTask task) {
        future.whenCompleteAsync((value, throwable) -> {
            if (future.isCancelled()) {
                synchronized (scheduledTasks) {
                    scheduledTasks.remove(task);
                }
            }
        }, executor);
    }

    /**
     * A queued task. The kind of task is encoded in the two period fields, as
     * used by the clock listener when dispatching: {@code delayMs > 0} is a
     * fixed-delay task, {@code rateMs > 0} is a fixed-rate task, and both zero
     * is a one-shot task.
     */
    private class ScheduledTask
            implements Comparable<ScheduledTask> {
        private final long                     delayMs;
        private final long                     rateMs;
        private final Runnable                 runnable;
        private final CompletableFuture<?>     future;
        /** When the listener should next trigger the task; {@link Instant#MAX} while a run is dispatched. */
        private final AtomicReference<Instant> nextExecution;
        /** The planned time of the next run; used for ordering tasks. */
        private final AtomicReference<Instant> nextRun;
        private final int                      id;

        private ScheduledTask(long delayMs,
                              long rateMs,
                              Runnable runnable,
                              CompletableFuture<?> future,
                              AtomicReference<Instant> nextExecution,
                              AtomicReference<Instant> nextRun) {
            this.delayMs = delayMs;
            this.rateMs = rateMs;
            this.id = nextId.incrementAndGet();

            this.runnable = runnable;
            this.nextExecution = nextExecution;
            this.nextRun = nextRun;
            this.future = future;
        }

        public boolean isCancelled() {
            return future.isCancelled();
        }

        /**
         * @param now The time to check against.
         * @return True if the task is not cancelled and its next execution time is at or before {@code now}.
         */
        public boolean shouldExecute(Instant now) {
            return !isCancelled() && !now.isBefore(nextExecution.get());
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ScheduledTask that = (ScheduledTask) o;
            return delayMs == that.delayMs &&
                   rateMs == that.rateMs &&
                   id == that.id;
        }

        @Override
        public int hashCode() {
            return Objects.hash(delayMs, rateMs, id);
        }

        /**
         * Order by planned next run time, then by creation id.
         */
        @Override
        public int compareTo(ScheduledTask task) {
            requireNonNull(task, "task == null");
            int c = nextRun.get().compareTo(task.nextRun.get());
            if (c != 0) {
                return c;
            }
            return Integer.compare(id, task.id);
        }

        /**
         * Run a one-shot task, unless its future is already done, and then
         * remove it from the queue.
         */
        private void runOnce() {
            if (future.isDone()) {
                return;
            }
            try {
                runnable.run();
            } finally {
                // Always drop the task, also when the runnable fails with
                // something its wrapper did not capture.
                synchronized (scheduledTasks) {
                    scheduledTasks.remove(this);
                }
            }
        }

        /**
         * Run a fixed-delay task once, then plan the next run {@code delayMs}
         * after the clock time read when the run finished. If the runnable
         * throws, the future is completed exceptionally and the task is removed.
         */
        private void runWithDelay() {
            if (future.isDone()) {
                return;
            }
            try {
                runnable.run();
            } catch (Exception e) {
                future.completeExceptionally(e);
                synchronized (scheduledTasks) {
                    scheduledTasks.remove(this);
                }
                return;
            }
            var nextRun = clock.instant().plusMillis(delayMs);
            nextExecution.set(nextRun);
            this.nextRun.set(nextRun);
        }

        /**
         * Run a fixed-rate task once for every period that has elapsed up to
         * {@code now}. If the runnable throws, the future is completed
         * exceptionally and the task is removed.
         *
         * @param now The time the run was triggered for.
         */
        private void runWithRate(Instant now) {
            if (future.isDone()) {
                return;
            }
            // Run the task for each time it should have been run in the
            // time the last 'tick' should have triggered.
            Instant nextRun = this.nextRun.get();
            while (!now.isBefore(nextRun)) {
                try {
                    runnable.run();
                } catch (Exception e) {
                    future.completeExceptionally(e);
                    synchronized (scheduledTasks) {
                        scheduledTasks.remove(this);
                    }
                    return;
                }
                nextRun = this.nextRun.updateAndGet(it -> it.plusMillis(rateMs));
            }
            nextExecution.set(nextRun);
        }
    }

    /**
     * The listener registered on the clock. Dispatches due tasks to the worker
     * pool when the clock reports a new time.
     */
    private class FakeTimeListener implements FakeClock.TimeListener {
        /**
         * Dispatch every queued task that is due at {@code now}, in task order.
         * Does nothing when the executor is shut down.
         *
         * @param now The new current time value.
         */
        @Override
        public void newCurrentTime(Instant now) {
            if (isShutdown()) {
                return;
            }

            AtomicBoolean scheduled = new AtomicBoolean();
            synchronized (scheduledTasks) {
                scheduledTasks.removeIf(ScheduledTask::isCancelled);
                scheduledTasks.stream()
                              .sorted()
                              .filter(task -> task.shouldExecute(now))
                              .forEach(task -> {
                                  // Park the task so it is not triggered again
                                  // until its run has planned the next time.
                                  task.nextExecution.set(Instant.MAX);
                                  scheduled.set(true);
                                  if (task.delayMs > 0) {
                                      execute(task::runWithDelay);
                                  } else if (task.rateMs > 0) {
                                      execute(() -> task.runWithRate(now));
                                  } else {
                                      execute(task::runOnce);
                                  }
                              });
            }
            if (scheduled.get()) {
                settle(TICK_SETTLE_MS);
            }
        }

        /**
         * @param now The current time when asking.
         * @return {@link Duration#ZERO} when no task has a planned execution,
         *         {@link FakeClock#MIN_DELAY} when the earliest one is already due,
         *         otherwise the time from {@code now} until the earliest one.
         */
        @Override
        public Duration getDelay(Instant now) {
            synchronized (scheduledTasks) {
                scheduledTasks.removeIf(ScheduledTask::isCancelled);
                var next = scheduledTasks
                        .stream()
                        .map(it -> it.nextExecution.get())
                        // Instant.MAX marks tasks that are currently parked.
                        .filter(it -> !Instant.MAX.equals(it))
                        .min(Instant::compareTo)
                        .orElse(Instant.MAX);
                if (next.equals(Instant.MAX)) {
                    return Duration.ZERO;
                }
                if (!now.isBefore(next)) {
                    return FakeClock.MIN_DELAY;
                }
                return Duration.between(now, next);
            }
        }
    }

    private final FakeClock           clock;
    private final FakeTimeListener    listener;
    private final List<ScheduledTask> scheduledTasks;
    private final ExecutorService     executor;
    private final AtomicInteger       nextId;
}