FakeClock.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.testing.concurrent;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE_TIME;
import static java.util.Objects.requireNonNull;
import static net.morimekta.strings.Displayable.displayableDuration;

/**
 * Fake clock implementation for testing.
 * <p>
 * The clock only moves when one of the {@code tick} methods is called, and
 * always in whole milliseconds. Registered {@link TimeListener}s are notified
 * after each step while ticking, and may limit how large each step is via
 * {@link TimeListener#getDelay(Instant)}. Clocks returned from
 * {@link #withZone(ZoneId)} share the current time and listeners with the
 * clock they were created from.
 */
public class FakeClock
        extends Clock {
    /** The smallest step the clock will advance while ticking. */
    public static final Duration MIN_DELAY = Duration.ofMillis(1);

    /**
     * Interface for listening to time changes.
     */
    @FunctionalInterface
    public interface TimeListener {
        /**
         * Get the delay until this listener next needs to be notified. While
         * ticking, a single step of the clock will not go further than the
         * smallest delay reported by any listener. If something is scheduled to
         * happen immediately, {@link FakeClock#MIN_DELAY} should be returned.
         * <p>
         * The default implementation returns {@link Duration#ZERO}, which puts
         * no limit on the step size.
         *
         * @param now The current time when asking.
         * @return The time to delay. NULL values, {@link Duration#ZERO} and negative
         *         durations will be ignored. Positive durations are truncated to
         *         whole milliseconds, but never to less than {@link FakeClock#MIN_DELAY}.
         */
        default Duration getDelay(Instant now) {
            return Duration.ZERO;
        }

        /**
         * Called after each step the clock advances while ticking.
         *
         * @param now The new current time value.
         */
        void newCurrentTime(Instant now);
    }

    /**
     * Create a fake clock instance using the actual current time, truncated
     * to millis, in the UTC zone.
     */
    public FakeClock() {
        this(systemUTC.instant().truncatedTo(ChronoUnit.MILLIS));
    }

    /**
     * @param millis Current time millis for new fake clock.
     * @return The created fake clock.
     */
    public static FakeClock forCurrentTimeMillis(long millis) {
        return forInstant(Instant.ofEpochMilli(millis));
    }

    /**
     * @param instant Current time for new fake clock. Will be truncated to MILLIS.
     * @return The created fake clock.
     * @throws NullPointerException If instant is null.
     */
    public static FakeClock forInstant(Instant instant) {
        return new FakeClock(requireNonNull(instant, "instant == null").truncatedTo(ChronoUnit.MILLIS));
    }

    /**
     * Tick the fake clock the given number of milliseconds.
     *
     * @param tickMs Milliseconds to move the clock. Must be at least 1.
     * @throws IllegalArgumentException If tickMs is less than 1.
     */
    public void tick(final long tickMs) {
        if (tickMs < minSkip) {
            throw new IllegalArgumentException("Invalid tick ms: " + tickMs);
        }
        tickInternal(tickMs);
    }

    /**
     * Tick the clock a certain duration into the future. The duration is
     * truncated to whole milliseconds, but the clock always moves at least
     * one millisecond.
     *
     * @param duration The duration of time to skip. Must be positive.
     * @throws NullPointerException If duration is null.
     * @throws IllegalArgumentException If duration is zero or negative.
     */
    public void tick(Duration duration) {
        requireNonNull(duration, "duration == null");
        if (duration.isNegative() || duration.isZero()) {
            throw new IllegalArgumentException("Invalid tick: " + displayableDuration(duration));
        }
        tickInternal(max(duration.toMillis(), 1));
    }

    /**
     * Tick the clock a given number of any time units. The amount is
     * converted to whole milliseconds, but the clock always moves at least
     * one millisecond.
     *
     * @param time The amount of time. Must be at least 1.
     * @param unit The unit of time.
     * @throws NullPointerException If unit is null.
     * @throws IllegalArgumentException If time is less than 1.
     */
    public void tick(long time, TimeUnit unit) {
        requireNonNull(unit, "unit == null");
        if (time < minSkip) {
            throw new IllegalArgumentException("Invalid tick: " + time + " " + unit);
        }
        tickInternal(max(unit.toMillis(time), 1));
    }

    /**
     * Add listener to the clock. Note that if a listener is added while a
     * {@code tick} is ongoing, it may only start listening after the current
     * tick increment is completed. Adding the same listener twice has no effect.
     *
     * @param listener Listener to be added to the clocks time progress.
     */
    public void addListener(TimeListener listener) {
        requireNonNull(listener, "listener == null");
        synchronized (listeners) {
            // contains + add must be atomic to avoid duplicate registrations.
            if (!listeners.contains(listener)) {
                listeners.add(listener);
            }
        }
    }

    /**
     * @param listener Listener to be removed from getting time progress.
     */
    public void removeListener(TimeListener listener) {
        requireNonNull(listener, "listener == null");
        synchronized (listeners) {
            listeners.remove(listener);
        }
    }

    /**
     * @return An immutable snapshot of the current listeners.
     */
    public List<TimeListener> getListeners() {
        synchronized (listeners) {
            return List.copyOf(listeners);
        }
    }

    // -- Clock

    @Override
    public ZoneId getZone() {
        return zoneId;
    }

    /**
     * @param zoneId The zone for the returned clock.
     * @return A clock in the given zone sharing current time, tick state and
     *         listeners with this clock, or this clock if the zone is the same.
     */
    @Override
    public FakeClock withZone(ZoneId zoneId) {
        requireNonNull(zoneId, "zoneId == null");
        if (this.zoneId.equals(zoneId)) {
            return this;
        }
        return new FakeClock(currentInstant, tickUntilInstant, zoneId, listeners, inTick);
    }

    @Override
    public Instant instant() {
        return currentInstant.get();
    }

    // -- Object

    @Override
    public String toString() {
        return "FakeClock{" +
               "@" + ISO_LOCAL_DATE_TIME.format(LocalDateTime.ofInstant(currentInstant.get(), zoneId)) +
               (zoneId.equals(systemUTC.getZone()) ? "" : ", zoneId=" + zoneId.getId()) +
               '}';
    }

    /**
     * Two fake clocks are equal if they currently show the same instant and
     * have the same zone. Note that this changes as the clocks are ticked.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        FakeClock fakeClock = (FakeClock) o;
        return currentInstant.get().equals(fakeClock.currentInstant.get()) &&
               zoneId.equals(fakeClock.zoneId);
    }

    @Override
    public int hashCode() {
        return Objects.hash(getClass(), currentInstant.get(), zoneId);
    }

    // -----------------------

    /**
     * Move the target time forward by the given number of millis, then advance
     * the clock until it reaches the target, unless another call is already
     * doing so.
     *
     * @param tickMs Milliseconds to add to the target time, at least 1.
     */
    private void tickInternal(long tickMs) {
        tickUntilInstant.updateAndGet(until -> until.plusMillis(tickMs));
        // Loop rather than advancing once: another call may add time after the
        // active ticker finished its last step but before it cleared 'inTick'.
        // That call would then return without advancing, and its time would be
        // lost unless it is picked up by this re-check.
        while (tickUntilInstant.get().isAfter(currentInstant.get())) {
            if (inTick.getAndSet(true)) {
                // Avoid recursion (e.g. a listener calling tick). Just let the
                // call currently holding 'inTick' take care of the extra time.
                return;
            }
            try {
                advanceToTickUntil();
            } finally {
                inTick.set(false);
            }
        }
    }

    /**
     * Move the current time forward until it reaches the target time. Must only
     * be called while holding {@code inTick}. With no listeners the clock jumps
     * straight to the target. Otherwise it moves in steps limited by the delays
     * the listeners report, and notifies every listener after each step.
     */
    private void advanceToTickUntil() {
        if (listeners.isEmpty()) {
            currentInstant.set(tickUntilInstant.get());
            return;
        }
        // The target is re-read on every step, as listeners may call tick()
        // and extend it while being notified.
        while (tickUntilInstant.get().isAfter(currentInstant.get())) {
            final Instant oldCurrent = currentInstant.get();
            final long skipMax = tickUntilInstant.get().toEpochMilli() - oldCurrent.toEpochMilli();
            final var tickingListeners = getListeners();
            final long skipDelay = tickingListeners
                    .stream()
                    .map(it -> it.getDelay(oldCurrent))
                    .filter(Objects::nonNull)
                    .filter(it -> !(it.isZero() || it.isNegative()))
                    // A positive delay shorter than one milli must still limit the
                    // step, so round it up instead of letting it truncate to zero.
                    .mapToLong(it -> max(minSkip, it.toMillis()))
                    .min()
                    .orElse(skipMax);
            final long skip = max(minSkip, min(skipMax, skipDelay));
            final Instant newCurrent = currentInstant.updateAndGet(i -> i.plusMillis(skip));
            tickingListeners.forEach(l -> l.newCurrentTime(newCurrent));
        }
    }

    /** The smallest step, in millis, the clock advances while ticking. */
    private static final long  minSkip   = MIN_DELAY.toMillis();
    private static final Clock systemUTC = Clock.systemUTC();

    /** The time the clock currently shows. */
    private final AtomicReference<Instant> currentInstant;
    /** The time the clock is advancing towards; ahead of currentInstant only while ticking. */
    private final AtomicReference<Instant> tickUntilInstant;
    private final ZoneId                   zoneId;
    private final List<TimeListener>       listeners;
    /** Set while a call is advancing the clock, so only one call advances it at a time. */
    private final AtomicBoolean            inTick;

    private FakeClock(Instant now) {
        this(new AtomicReference<>(now),
             new AtomicReference<>(now),
             systemUTC.getZone(),
             Collections.synchronizedList(new ArrayList<>()),
             new AtomicBoolean());
    }

    // All state references are shared as-is, so clocks created via withZone()
    // tick together with the clock they were created from.
    private FakeClock(AtomicReference<Instant> currentInstant,
                      AtomicReference<Instant> tickUntilInstant,
                      ZoneId zoneId,
                      List<TimeListener> listeners,
                      AtomicBoolean inTick) {
        this.currentInstant = currentInstant;
        this.tickUntilInstant = tickUntilInstant;
        this.zoneId = zoneId;
        this.listeners = listeners;
        this.inTick = inTick;
    }
}