ImmediateScheduledExecutor.java
/*
* Copyright (c) 2020, Stein Eldar Johnsen
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.testing.concurrent;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
/**
 * A scheduled executor that uses a fake clock as back-bone to the executor.
 * To trigger the executions, call {@link FakeClock#tick(long)} on the fake
 * clock. The executions will be handled in the thread of the tick itself.
 * <p>
 * Time is tracked with millisecond granularity: delays are converted with
 * {@link TimeUnit#toMillis(long)}, so sub-millisecond parts of a delay are
 * dropped. Recurring tasks therefore require a period / delay of at least
 * one millisecond.
 * <p>
 * NOTE(review): no synchronization is used anywhere in this class; it looks
 * intended for use from a single (test) thread — confirm before sharing it.
 */
@SuppressWarnings("PreferJavaTimeOverload")
public class ImmediateScheduledExecutor
        implements ScheduledExecutorService, FakeClock.TimeListener {
    /**
     * A fake one-shot task / future instance. It is run by the executor when
     * the fake clock reaches its trigger time.
     *
     * @param <V> The future return type.
     */
    public final class FakeTask<V> implements ScheduledFuture<V>, Runnable {
        private final Instant triggersAt;
        private final Callable<V> callable;
        private final int id;
        private V result;
        private Throwable except;
        private boolean cancelled;
        private boolean done;

        private FakeTask(Instant triggersAt,
                         Callable<V> callable) {
            this.triggersAt = triggersAt;
            this.callable = requireNonNull(callable, "callable == null");
            // Sequence number from the executor; breaks ties between tasks
            // with the same trigger time so they keep submission order.
            this.id = nextId.incrementAndGet();
            this.cancelled = false;
            this.done = false;
            this.result = null;
            this.except = null;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FakeTask<?> fakeTask = (FakeTask<?>) o;
            return triggersAt.equals(fakeTask.triggersAt) &&
                   id == fakeTask.id &&
                   cancelled == fakeTask.cancelled &&
                   done == fakeTask.done &&
                   callable.equals(fakeTask.callable);
        }

        @Override
        public int hashCode() {
            return Objects.hash(triggersAt, callable, id, cancelled, done);
        }

        /**
         * Order by trigger time, then by submission order for other fake
         * tasks; other {@link Delayed} instances are compared by remaining delay.
         *
         * @param delayed The delayed to compare to.
         * @return Negative, zero or positive as per {@link Comparable}.
         */
        @Override
        public int compareTo(Delayed delayed) {
            requireNonNull(delayed, "delayed == null");
            if (delayed == this) {
                return 0;
            }
            if (delayed instanceof FakeTask) {
                FakeTask<?> other = (FakeTask<?>) delayed;
                int c = triggersAt.compareTo(other.triggersAt);
                if (0 != c) {
                    return c;
                }
                return Integer.compare(id, other.id);
            }
            long diff = getDelay(MILLISECONDS) - delayed.getDelay(MILLISECONDS);
            return (diff < 0) ? -1 : (diff > 0) ? 1 : 0;
        }

        /**
         * Get the remaining delay until the task triggers, never negative.
         * Partial milliseconds are rounded up, so ticking the fake clock by the
         * returned number of milliseconds always reaches the trigger time.
         *
         * @param timeUnit The unit to return the delay in.
         * @return The remaining delay, or 0 if the trigger time has passed.
         */
        @Override
        public long getDelay(TimeUnit timeUnit) {
            requireNonNull(timeUnit, "timeUnit == null");
            Instant now = clock.instant();
            if (!now.isBefore(triggersAt)) {
                return 0L;
            }
            Duration remaining = Duration.between(now, triggersAt);
            long millis = remaining.toMillis();
            if (remaining.getNano() % 1_000_000 != 0) {
                // Truncating would report 0 for a not-yet-due task, which makes
                // tick(getDelay()) a no-op and callers loop or return early.
                millis++;
            }
            return timeUnit.convert(millis, MILLISECONDS);
        }

        /**
         * Cancel the task if it has not completed, and remove it from the
         * executor's schedule.
         *
         * @param mayInterruptIfRunning Ignored; tasks run synchronously on tick.
         * @return True if the task is cancelled.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (!done) {
                cancelled = true;
                done = true;
                scheduledTasks.remove(this);
            }
            return cancelled;
        }

        @Override
        public boolean isCancelled() {
            return cancelled;
        }

        @Override
        public boolean isDone() {
            return done;
        }

        /**
         * Get the task result. If the task is not done, the fake clock is
         * ticked by the remaining delay, which is expected to trigger the task.
         *
         * @return The task result.
         * @throws InterruptedException If the task was cancelled. Note that this
         *         differs from the {@link Future} contract, which specifies
         *         {@link java.util.concurrent.CancellationException}.
         * @throws ExecutionException If the task threw an exception.
         */
        @Override
        public V get() throws InterruptedException, ExecutionException {
            if (!done) {
                clock.tick(getDelay(MILLISECONDS));
            }
            return report();
        }

        /**
         * Get the task result, ticking the fake clock by at most the given
         * timeout.
         *
         * @param l The max time to tick.
         * @param timeUnit The unit of the timeout.
         * @return The task result.
         * @throws InterruptedException If the task was cancelled.
         * @throws ExecutionException If the task threw an exception.
         * @throws TimeoutException If the task is still not done after ticking.
         */
        @Override
        public V get(long l, TimeUnit timeUnit) throws InterruptedException, ExecutionException, TimeoutException {
            requireNonNull(timeUnit, "timeUnit == null");
            if (!done) {
                long maxWait = timeUnit.toMillis(l);
                clock.tick(Math.min(maxWait, getDelay(MILLISECONDS)));
                if (!done) {
                    throw new TimeoutException("Timed out after " + maxWait + " millis");
                }
            }
            return report();
        }

        /**
         * Run the task and record its result or exception. Only
         * {@link Exception}s are captured; {@link Error}s propagate to the caller.
         */
        @Override
        public void run() {
            V value = null;
            try {
                value = callable.call();
            } catch (Exception e) {
                this.except = e;
            }
            this.result = value;
            this.done = true;
        }

        /** Report the recorded outcome: cancellation, failure or result. */
        private V report() throws InterruptedException, ExecutionException {
            if (cancelled) {
                throw new InterruptedException("Task cancelled");
            }
            if (except != null) {
                throw new ExecutionException(except.getMessage(), except);
            }
            return result;
        }
    }

    /**
     * A fake recurring task / future. Each execution is backed by a one-shot
     * {@link FakeTask} which, when run, schedules the next one.
     */
    public final class FakeRecurringTask implements ScheduledFuture<Void> {
        private final long delay;
        private final Runnable runnable;
        private final AtomicReference<FakeTask<?>> next;
        /** Epoch millis of the next planned execution. */
        final AtomicLong nextExecution;
        private boolean cancelled;

        private FakeRecurringTask(long delay,
                                  long initialDelay,
                                  TimeUnit timeUnit,
                                  Runnable runnable,
                                  AtomicReference<FakeTask<?>> first) {
            this.delay = timeUnit.toMillis(delay);
            this.nextExecution = new AtomicLong(clock.millis() + timeUnit.toMillis(initialDelay));
            this.runnable = runnable;
            this.next = first;
            this.cancelled = false;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            FakeRecurringTask that = (FakeRecurringTask) o;
            return delay == that.delay &&
                   cancelled == that.cancelled &&
                   runnable.equals(that.runnable) &&
                   nextExecution.get() == that.nextExecution.get();
        }

        @Override
        public int hashCode() {
            // Hash the value, not the AtomicLong (identity hash), to stay
            // consistent with equals().
            return Objects.hash(delay, runnable, nextExecution.get(), cancelled);
        }

        /**
         * Order by next execution time, then by the id of the pending one-shot
         * task; other {@link Delayed} instances are compared by remaining delay.
         *
         * @param delayed The delayed to compare to.
         * @return Negative, zero or positive as per {@link Comparable}.
         */
        @Override
        public int compareTo(Delayed delayed) {
            requireNonNull(delayed, "delayed == null");
            if (delayed == this) {
                return 0;
            }
            if (delayed instanceof FakeRecurringTask) {
                FakeRecurringTask task = (FakeRecurringTask) delayed;
                int c = Long.compare(nextExecution.get(), task.nextExecution.get());
                if (c != 0) {
                    return c;
                }
                return Integer.compare(next.get().id, task.next.get().id);
            }
            long diff = getDelay(MILLISECONDS) - delayed.getDelay(MILLISECONDS);
            return (diff < 0) ? -1 : (diff > 0) ? 1 : 0;
        }

        /**
         * Cancel the recurring task and remove its pending execution.
         *
         * @param mayInterruptIfRunning Ignored.
         * @return True, as the task is always cancelled after this call.
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            if (!cancelled) {
                cancelled = true;
                scheduledTasks.remove(next.get());
            }
            return cancelled;
        }

        @Override
        public boolean isCancelled() {
            return cancelled;
        }

        /** A recurring task is only ever done by being cancelled. */
        @Override
        public boolean isDone() {
            return cancelled;
        }

        /**
         * Waiting on a fake recurring task is not supported.
         *
         * @return Never returns normally.
         * @throws InterruptedException If the task was cancelled.
         * @throws IllegalStateException If the task is not cancelled.
         */
        @Override
        public Void get() throws InterruptedException {
            if (cancelled) {
                throw new InterruptedException("Task cancelled");
            }
            throw new IllegalStateException("Cannot wait for fake recurring tasks");
        }

        @Override
        public Void get(long l, TimeUnit timeUnit) throws InterruptedException {
            requireNonNull(timeUnit, "timeUnit == null");
            return get();
        }

        /**
         * Get the delay until the next planned execution. May be negative if
         * the planned time has passed without the clock triggering it.
         *
         * @param timeUnit The unit to return the delay in.
         * @return The delay until the next execution.
         */
        @Override
        public long getDelay(TimeUnit timeUnit) {
            requireNonNull(timeUnit, "timeUnit == null");
            long now = clock.millis();
            long realDelay = nextExecution.get() - now;
            return timeUnit.convert(realDelay, MILLISECONDS);
        }

        /**
         * Run one fixed-delay execution, then schedule the next one
         * {@code delay} ms later. Nothing is rescheduled if the task was
         * cancelled (e.g. by the runnable itself) or the executor is shut down.
         */
        void runWithDelay() {
            if (cancelled) {
                return;
            }
            runSafely();
            if (cancelled || isShutdown()) {
                return;
            }
            next.set(schedule(this::runWithDelay, delay, MILLISECONDS));
            nextExecution.set(clock.millis() + delay);
        }

        /**
         * Run all fixed-rate executions that are due up to now, then schedule
         * the next one. Nothing more is run or rescheduled once the task is
         * cancelled, and nothing is rescheduled after executor shutdown.
         */
        void runWithRate() {
            long now = clock.millis();
            // Run the task once for each period that has elapsed up to now, so
            // a single large tick catches up on all missed executions.
            long nextRun = nextExecution.get();
            while (!cancelled && nextRun <= now) {
                runSafely();
                nextRun = nextExecution.addAndGet(delay);
            }
            if (cancelled || isShutdown()) {
                return;
            }
            long untilNextRun = nextRun - now;
            next.set(schedule(this::runWithRate, untilNextRun, MILLISECONDS));
        }

        /**
         * Run the wrapped runnable once. Exceptions are deliberately ignored
         * so a failing execution does not stop later executions.
         */
        private void runSafely() {
            try {
                runnable.run();
            } catch (Exception ignored) {
                // Intentionally ignored: keep the recurring task going.
            }
        }
    }

    private final FakeClock clock;
    /** Pending one-shot tasks, kept sorted by trigger time then submission order. */
    private final List<FakeTask<?>> scheduledTasks;
    private final AtomicInteger nextId = new AtomicInteger();
    private final FakeTimeListener listener;
    private boolean shutdownCalled = false;

    /**
     * Create a fake scheduled executor.
     *
     * @param clock Clock to trigger executions.
     */
    public ImmediateScheduledExecutor(FakeClock clock) {
        this.scheduledTasks = new ArrayList<>();
        this.clock = requireNonNull(clock, "clock == null");
        this.listener = new FakeTimeListener();
        this.clock.addListener(listener);
    }

    // ---- FakeClock.TimeListener

    /**
     * @param newNow The new current time value.
     * @deprecated Use the {@link FakeClock#tick(long)} method to propagate time.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public void newCurrentTime(Instant newNow) {
        throw new UnsupportedOperationException("Not allowed to propagate time by calling this method");
    }

    /**
     * @param now The current time when asking.
     * @return The delay until the next execution should happen.
     * @deprecated Should not be used.
     */
    @Override
    @Deprecated(since = "v5.2.0", forRemoval = true)
    public Duration getDelay(Instant now) {
        return listener.getDelay(now);
    }

    // ---- ScheduledExecutorService

    @Override
    public FakeTask<?> schedule(Runnable runnable, long l, TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        return this.schedule(() -> {
            runnable.run();
            return null;
        }, l, timeUnit);
    }

    /**
     * Schedule a one-shot task to run when the fake clock has advanced by
     * {@code delay} (truncated to whole milliseconds).
     *
     * @throws IllegalStateException If the executor is shut down.
     * @throws IllegalArgumentException If the delay is negative.
     */
    @Override
    public <V> FakeTask<V> schedule(Callable<V> callable, long delay, TimeUnit timeUnit) {
        requireNonNull(callable, "callable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (delay < 0) {
            throw new IllegalArgumentException("Unable to schedule tasks in the past");
        }
        var now = clock.instant();
        var triggersAt = now.plusMillis(timeUnit.toMillis(delay));
        FakeTask<V> task = new FakeTask<>(triggersAt, callable);
        scheduledTasks.add(task);
        Collections.sort(scheduledTasks);
        return task;
    }

    /**
     * Schedule a fixed-rate recurring task.
     *
     * @throws IllegalArgumentException If the initial delay is negative, or the
     *         period is less than one millisecond.
     */
    @Override
    public FakeRecurringTask scheduleAtFixedRate(Runnable runnable, long initialDelay, long period, TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (initialDelay < 0 || period < 1) {
            throw new IllegalArgumentException("Invalid initial delay or period: " + initialDelay + " / " + period);
        }
        // The fake clock counts whole millis; a period truncating to 0 ms would
        // make runWithRate() loop forever within a single tick.
        if (timeUnit.toMillis(period) < 1) {
            throw new IllegalArgumentException("Period less than 1 millisecond: " + period + " " + timeUnit);
        }
        AtomicReference<FakeTask<?>> first = new AtomicReference<>();
        FakeRecurringTask recurring = new FakeRecurringTask(period, initialDelay, timeUnit, runnable, first);
        first.set(schedule(recurring::runWithRate, initialDelay, timeUnit));
        return recurring;
    }

    /**
     * Schedule a fixed-delay recurring task.
     *
     * @throws IllegalArgumentException If the initial delay is negative, or the
     *         delay is less than one millisecond.
     */
    @Override
    public FakeRecurringTask scheduleWithFixedDelay(Runnable runnable,
                                                    long initialDelay,
                                                    long delay,
                                                    TimeUnit timeUnit) {
        requireNonNull(runnable, "runnable == null");
        requireNonNull(timeUnit, "timeUnit == null");
        if (initialDelay < 0 || delay < 1) {
            throw new IllegalArgumentException("Invalid initial delay or intermediate delay: "
                                               + initialDelay
                                               + " / "
                                               + delay);
        }
        // A delay truncating to 0 ms would reschedule the task at the current
        // instant, re-triggering it forever within the same tick.
        if (timeUnit.toMillis(delay) < 1) {
            throw new IllegalArgumentException("Delay less than 1 millisecond: " + delay + " " + timeUnit);
        }
        AtomicReference<FakeTask<?>> first = new AtomicReference<>();
        FakeRecurringTask recurring = new FakeRecurringTask(delay, initialDelay, timeUnit, runnable, first);
        first.set(schedule(recurring::runWithDelay, initialDelay, timeUnit));
        return recurring;
    }

    // ---- ExecutorService

    @Override
    public void shutdown() {
        this.shutdownCalled = true;
    }

    /**
     * Shut down and drop all pending tasks.
     *
     * @return The pending tasks that were neither cancelled nor done.
     */
    @Override
    public List<Runnable> shutdownNow() {
        this.shutdownCalled = true;
        List<Runnable> result = scheduledTasks.stream()
                                              .filter(t -> !t.isCancelled())
                                              .filter(t -> !t.isDone())
                                              .collect(Collectors.toList());
        scheduledTasks.clear();
        return result;
    }

    @Override
    public boolean isShutdown() {
        return shutdownCalled;
    }

    @Override
    public boolean isTerminated() {
        return shutdownCalled && scheduledTasks.isEmpty();
    }

    /**
     * Tick the fake clock forward, task by task, running pending tasks until
     * none remain or the next one lies beyond the timeout.
     *
     * @return True if no scheduled tasks remain.
     * @throws IllegalStateException If shutdown has not been called.
     */
    @Override
    public boolean awaitTermination(long l, TimeUnit timeUnit) {
        requireNonNull(timeUnit, "timeUnit == null");
        if (!shutdownCalled) {
            throw new IllegalStateException("Shutdown not triggered");
        }
        var now = clock.instant();
        var until = now.plusMillis(timeUnit.toMillis(l));
        FakeTask<?> next;
        while ((next = getNextTask()) != null) {
            now = clock.instant();
            var delay = next.getDelay(MILLISECONDS);
            if (delay > 0 && now.plusMillis(delay).isAfter(until)) {
                break;
            }
            if (delay > 0) {
                clock.tick(delay);
            } else {
                // Already due: run it without moving the clock.
                listener.newCurrentTime(now);
            }
        }
        return scheduledTasks.isEmpty();
    }

    @Override
    public <T> FakeTask<T> submit(Callable<T> callable) {
        return schedule(callable, 0, MILLISECONDS);
    }

    @Override
    public <T> FakeTask<T> submit(Runnable runnable, T t) {
        requireNonNull(runnable, "runnable == null");
        return schedule(() -> {
            runnable.run();
            return t;
        }, 0, MILLISECONDS);
    }

    @Override
    public FakeTask<?> submit(Runnable runnable) {
        return schedule(runnable, 0, MILLISECONDS);
    }

    /**
     * Submit all callables. The returned futures are not run until the
     * fake clock is ticked (or their {@code get()} is called).
     *
     * @throws IllegalStateException If the executor is shut down.
     * @throws IllegalArgumentException If the collection is empty.
     */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> collection) {
        requireNonNull(collection, "collection == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Empty invoke collection");
        }
        List<Future<T>> results = new ArrayList<>();
        for (Callable<T> c : collection) {
            results.add(submit(c));
        }
        return results;
    }

    /** Same as {@link #invokeAll(Collection)}; the timeout is only validated. */
    @Override
    public <T> List<Future<T>> invokeAll(
            Collection<? extends Callable<T>> collection,
            long l,
            TimeUnit timeUnit) {
        requireNonNull(timeUnit, "timeUnit == null");
        if (l < 0) {
            throw new IllegalArgumentException("Negative timeout: " + l);
        }
        return invokeAll(collection);
    }

    /**
     * Call the callables synchronously in iteration order and return the
     * first successful result.
     *
     * @throws ExecutionException If all callables failed; the first failure
     *         is the cause, later ones are suppressed.
     */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> collection) throws ExecutionException {
        requireNonNull(collection, "collection == null");
        if (isShutdown()) {
            throw new IllegalStateException("Executor is shut down");
        }
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Empty invoke collection");
        }
        ExecutionException ex = null;
        for (Callable<T> c : collection) {
            try {
                return c.call();
            } catch (Exception e) {
                if (ex == null) {
                    ex = new ExecutionException("All " + collection.size() + " tasks failed, first exception", e);
                } else {
                    ex.addSuppressed(e);
                }
            }
        }
        throw ex;
    }

    /** Same as {@link #invokeAny(Collection)}; the timeout is only validated. */
    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> collection,
                           long l,
                           TimeUnit timeUnit) throws ExecutionException {
        requireNonNull(timeUnit, "timeUnit == null");
        if (l < 0) {
            throw new IllegalArgumentException("Negative timeout: " + l);
        }
        return invokeAny(collection);
    }

    @Override
    @SuppressWarnings("FutureReturnValueIgnored")
    public void execute(Runnable runnable) {
        schedule(runnable, 0, MILLISECONDS);
    }

    /** @return The earliest scheduled task, or null if there is none. */
    private FakeTask<?> getNextTask() {
        if (scheduledTasks.isEmpty()) {
            return null;
        }
        return scheduledTasks.get(0);
    }

    /** Listener that runs due tasks when the fake clock moves. */
    private final class FakeTimeListener implements FakeClock.TimeListener {
        /**
         * @param now The current time.
         * @return {@link Duration#ZERO} if nothing is pending,
         *         {@link FakeClock#MIN_DELAY} if a task is already due,
         *         otherwise the time until the earliest pending task.
         */
        @Override
        public Duration getDelay(Instant now) {
            Instant next = scheduledTasks.stream()
                                         .filter(it -> !it.isDone())
                                         .map(it -> it.triggersAt)
                                         .min(Instant::compareTo)
                                         .orElse(Instant.MAX);
            if (next.equals(Instant.MAX)) {
                return Duration.ZERO;
            }
            if (!now.isBefore(next)) {
                return FakeClock.MIN_DELAY;
            }
            return Duration.between(now, next);
        }

        /**
         * Run every task due at or before {@code now}, in schedule order. Tasks
         * scheduled by a running task are picked up in the same pass if due.
         *
         * @param now The new current time.
         */
        @Override
        public void newCurrentTime(Instant now) {
            FakeTask<?> next;
            while ((next = getNextTask()) != null) {
                if (next.isDone()) {
                    scheduledTasks.remove(0);
                } else if (!now.isBefore(next.triggersAt)) {
                    // Remove before running, so a task rescheduling itself
                    // (recurring tasks) does not see its own stale entry.
                    scheduledTasks.remove(0);
                    next.run();
                } else {
                    break;
                }
            }
        }
    }
}