ProtoListBuilder.java

package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import net.morimekta.collect.UnmodifiableList;
import net.morimekta.proto.utils.ValueUtil;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import java.util.RandomAccess;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;

/**
 * A mutable {@link List} view of a repeated (non-map) field on a protobuf
 * message builder. The view holds no items of its own: every read and write
 * goes straight to the wrapped {@link Message.Builder}, with items converted
 * between builder values and Java values through {@link ValueUtil}.
 *
 * @param <T> The item type.
 */
public class ProtoListBuilder<T>
        extends AbstractList<T>
        implements RandomAccess {
    private transient final Message.Builder             builder;
    private transient final Descriptors.FieldDescriptor field;

    /**
     * Create a list view over a repeated field of a message builder.
     *
     * @param builder The message builder.
     * @param field   The repeated field.
     * @throws NullPointerException     If builder or field is null.
     * @throws IllegalArgumentException If the field is not repeated, or is a map field.
     */
    public ProtoListBuilder(Message.Builder builder, Descriptors.FieldDescriptor field) {
        requireNonNull(builder, "builder == null");
        requireNonNull(field, "field == null");
        if (!field.isRepeated() || field.isMapField()) {
            throw new IllegalArgumentException("Not a list type: " + field);
        }
        this.builder = builder;
        this.field = field;
    }

    @Override
    public int size() {
        return builder.getRepeatedFieldCount(field);
    }

    @Override
    public void clear() {
        builder.clearField(field);
    }

    @Override
    @SuppressWarnings("unchecked")
    public T get(int i) {
        return (T) ValueUtil.toJavaValue(field, builder.getRepeatedField(field, i));
    }

    /**
     * Replace the item at the given index.
     * <p>
     * Unlike the general {@link List#set(int, Object)} contract, an index equal
     * to {@link #size()} is accepted and appends the item instead.
     *
     * @param index   The index to set.
     * @param element The new item.
     * @return The previous item at the index, or null if the item was appended.
     * @throws IllegalArgumentException If index is negative or greater than size.
     */
    @Override
    public T set(int index, T element) {
        checkPositionIndex(index);
        Object protoValue = ValueUtil.toProtoValue(field, element);
        if (index == size()) {
            builder.addRepeatedField(field, protoValue);
            return null;
        }
        T old = get(index);
        builder.setRepeatedField(field, index, protoValue);
        return old;
    }

    /**
     * Insert an item at the given index, shifting later items up by one.
     *
     * @param index   The insert position, 0 to size inclusive.
     * @param element The item to insert.
     * @throws IllegalArgumentException If index is negative or greater than size.
     */
    @Override
    public void add(int index, T element) {
        checkPositionIndex(index);
        Object protoValue = ValueUtil.toProtoValue(field, element);
        if (index == size()) {
            builder.addRepeatedField(field, protoValue);
        } else {
            // Insert by rewriting the whole field from a copy of its raw values.
            List<Object> tmp = rawValues();
            tmp.add(index, protoValue);
            builder.setField(field, tmp);
        }
    }

    /**
     * Insert all items of a collection at the given index, in iteration order.
     *
     * @param index The insert position, 0 to size inclusive.
     * @param c     The items to insert.
     * @return True if any items were added, false if the collection was empty.
     * @throws IllegalArgumentException If index is negative or greater than size.
     */
    @Override
    public boolean addAll(int index, Collection<? extends T> c) {
        checkPositionIndex(index);
        if (c.isEmpty()) {
            return false;
        }

        // Convert (and thereby snapshot) the input first. This keeps existing and
        // new items in the same builder-side representation, and makes adding this
        // list to itself terminate instead of chasing its own growing size.
        List<Object> protoValues = new ArrayList<>(c.size());
        for (T item : c) {
            protoValues.add(ValueUtil.toProtoValue(field, item));
        }

        if (index == size()) {
            for (Object protoValue : protoValues) {
                builder.addRepeatedField(field, protoValue);
            }
        } else {
            // Use the raw builder values, not this (Java-valued) view, so the list
            // passed to setField is uniformly in builder representation.
            List<Object> tmp = rawValues();
            tmp.addAll(index, protoValues);
            builder.setField(field, tmp);
        }
        return true;
    }

    /**
     * Remove the item at the given index, shifting later items down by one.
     *
     * @param index The index to remove.
     * @return The removed item.
     * @throws IllegalArgumentException If index is negative, or not less than size.
     */
    @Override
    @SuppressWarnings("unchecked")
    public T remove(int index) {
        checkElementIndex(index);
        if (size() == 1) {
            var only = get(0);
            clear();
            return only;
        }
        List<Object> tmp = rawValues();
        Object old = tmp.remove(index);
        builder.setField(field, tmp);
        return (T) ValueUtil.toJavaValue(field, old);
    }

    /**
     * Sort the items with the given comparator (applied to the Java values), and
     * write the sorted items back to the builder.
     *
     * @param c The comparator to sort by.
     */
    @Override
    public void sort(Comparator<? super T> c) {
        List<T> newList = UnmodifiableList.asList(this).sortedBy(c);
        builder.setField(field, ValueUtil.toProtoValue(field, newList));
    }

    @Override
    public Iterator<T> iterator() {
        return new FieldIterator(0);
    }

    @Override
    public ListIterator<T> listIterator() {
        return new FieldIterator(0);
    }

    /**
     * @param index Initial cursor position, 0 to size inclusive.
     * @return A list iterator starting at the given position.
     * @throws IllegalArgumentException If index is negative or greater than size.
     */
    @Override
    public ListIterator<T> listIterator(int index) {
        checkPositionIndex(index);
        return new FieldIterator(index);
    }

    @Override
    public String toString() {
        return ValueUtil.asString(this);
    }

    /** A mutable copy of the field's current values, in builder (proto) representation. */
    @SuppressWarnings("unchecked")
    private List<Object> rawValues() {
        return new ArrayList<>((List<Object>) builder.getField(field));
    }

    /** Check that index is a valid insert / cursor position: 0 &lt;= index &lt;= size(). */
    private void checkPositionIndex(int index) {
        if (index < 0) {
            throw new IllegalArgumentException("index < 0");
        }
        if (index > size()) {
            throw new IllegalArgumentException("index > size(" + size() + ")");
        }
    }

    /** Check that index refers to an existing item: 0 &lt;= index &lt; size(). */
    private void checkElementIndex(int index) {
        if (index < 0) {
            throw new IllegalArgumentException("index < 0");
        }
        if (index >= size()) {
            throw new IllegalArgumentException("index >= size(" + size() + ")");
        }
    }

    /** List iterator reading and writing through the enclosing list view. */
    private class FieldIterator
            implements ListIterator<T> {
        // Index of the item last returned by next() or previous(), or -1 when
        // there is none (initially, and after remove() or add()).
        private int currentIndex;
        // Cursor position: index of the item the next call to next() returns.
        private int nextIndex;

        FieldIterator(int nextIndex) {
            this.currentIndex = -1;
            this.nextIndex = nextIndex;
        }

        @Override
        public boolean hasNext() {
            return nextIndex < size();
        }

        @Override
        public T next() {
            if (nextIndex >= size()) {
                throw new NoSuchElementException(nextIndex + " >= " + size());
            }
            currentIndex = nextIndex;
            ++nextIndex;
            return get(currentIndex);
        }

        @Override
        public boolean hasPrevious() {
            return nextIndex > 0;
        }

        @Override
        public T previous() {
            if (nextIndex <= 0) {
                throw new NoSuchElementException(nextIndex + " <= " + 0);
            }
            --nextIndex;
            currentIndex = nextIndex;
            return get(currentIndex);
        }

        @Override
        public int nextIndex() {
            return nextIndex;
        }

        @Override
        public int previousIndex() {
            return nextIndex - 1;
        }

        @Override
        public void remove() {
            if (currentIndex < 0) {
                throw new NoSuchElementException("No current item");
            }
            ProtoListBuilder.this.remove(currentIndex);
            // After next() the removed item sat just before the cursor, so the
            // cursor moves down; after previous() it sat at the cursor, which then
            // already points at the following item.
            if (currentIndex < nextIndex) {
                --nextIndex;
            }
            currentIndex = -1;
        }

        @Override
        public void set(T t) {
            if (currentIndex < 0) {
                throw new NoSuchElementException("No current item");
            }
            ProtoListBuilder.this.set(currentIndex, t);
        }

        @Override
        public void add(T t) {
            // Insert at the cursor, i.e. before the item next() would return, so
            // next() is unaffected and previous() returns the new item.
            ProtoListBuilder.this.add(nextIndex, t);
            ++nextIndex;
            currentIndex = -1;
        }
    }
}