ProtoMapBuilder.java

package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Message;
import net.morimekta.proto.utils.FieldUtil;
import net.morimekta.proto.utils.ValueUtil;

import java.util.AbstractCollection;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.utils.FieldUtil.getDefaultTypeValue;
import static net.morimekta.proto.utils.ValueUtil.asString;
import static net.morimekta.proto.utils.ValueUtil.toJavaValue;
import static net.morimekta.proto.utils.ValueUtil.toProtoValue;

/**
 * A live {@link Map} view of a map field on a protobuf message builder.
 * <p>
 * Every read goes straight to the wrapped builder and every write is applied to it immediately,
 * so this map and its {@link #keySet()}, {@link #values()} and {@link #entrySet()} views always
 * reflect the current builder state. Entries are kept in the order of the underlying repeated
 * map-entry field, and key lookups are linear scans over that field. Null keys and values are
 * rejected with a {@link NullPointerException}.
 *
 * @param <K> The map key type.
 * @param <V> The map value type.
 */
@SuppressWarnings("unchecked")
public class ProtoMapBuilder<K, V> implements Map<K, V> {
    private transient final Message.Builder             builder;
    private transient final Descriptors.FieldDescriptor field;
    private transient final MapEntry<Object, Object>    defaultEntry;
    private transient final Descriptors.FieldDescriptor keyType;
    private transient final Descriptors.FieldDescriptor valueType;

    /**
     * Create a map view over a map field of a message builder.
     *
     * @param builder The message builder containing the map field.
     * @param field   The map field descriptor.
     * @throws IllegalArgumentException If the field is not a map field.
     */
    public ProtoMapBuilder(Message.Builder builder, Descriptors.FieldDescriptor field) {
        if (!field.isRepeated() || !field.isMapField()) {
            throw new IllegalArgumentException("Not a map field: " + field);
        }
        this.builder = builder;
        this.field = field;
        this.keyType = FieldUtil.getMapKeyDescriptor(field);
        this.valueType = FieldUtil.getMapValueDescriptor(field);
        // Template used to build new map entries with the correct key/value wire types.
        this.defaultEntry = MapEntry.newDefaultInstance(field.getMessageType(),
                                                        keyType.getLiteType(),
                                                        getDefaultTypeValue(keyType),
                                                        valueType.getLiteType(),
                                                        getDefaultTypeValue(valueType));
    }

    /**
     * @return The number of entries, i.e. the element count of the repeated map field.
     */
    @Override
    public int size() {
        return builder.getRepeatedFieldCount(field);
    }

    @Override
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * @throws NullPointerException If the key is null.
     */
    @Override
    public boolean containsKey(Object key) {
        requireNonNull(key, "key == null");
        return indexOfInternal(toProtoValue(keyType, key)) >= 0;
    }

    /**
     * @throws NullPointerException If the value is null.
     */
    @Override
    public boolean containsValue(Object value) {
        requireNonNull(value, "value == null");
        var protoValue = toProtoValue(valueType, value);
        return mapEntryStream()
                .anyMatch(e -> Objects.equals(protoValue, e.getValue()));
    }

    /**
     * @return The value mapped to the key, or null if the key is not present.
     * @throws NullPointerException If the key is null.
     */
    @Override
    public V get(Object key) {
        requireNonNull(key, "key == null");
        var protoValue = getInternal(toProtoValue(keyType, key));
        if (protoValue == null) {
            return null;
        }
        return (V) toJavaValue(valueType, protoValue);
    }

    @Override
    public Set<K> keySet() {
        return new KeySet();
    }

    @Override
    public Collection<V> values() {
        return new ValueCollection();
    }

    @Override
    public Set<Entry<K, V>> entrySet() {
        return new EntrySet();
    }

    // ---- Value Write

    /**
     * Set the value for a key. An existing entry is replaced in place (keeping its position in
     * the repeated field); a new key is appended at the end.
     *
     * @return The previous value for the key, or null if there was none.
     * @throws NullPointerException If the key or value is null.
     */
    @Override
    public V put(K key, V value) {
        requireNonNull(key, "key == null");
        requireNonNull(value, "value == null");
        var protoKey = toProtoValue(keyType, key);
        var protoValue = toProtoValue(valueType, value);
        var idx = indexOfInternal(protoKey);
        if (idx >= 0) {
            var old = getMapEntry(idx);
            builder.setRepeatedField(field, idx, makeProtoEntry(protoKey, protoValue));
            return (V) toJavaValue(valueType, old.getValue());
        } else {
            builder.addRepeatedField(field, makeProtoEntry(protoKey, protoValue));
            return null;
        }
    }

    /**
     * Remove the entry for a key, if present.
     *
     * @return The removed value, or null if the key was not present.
     * @throws NullPointerException If the key is null.
     */
    @Override
    public V remove(Object key) {
        requireNonNull(key, "key == null");
        var protoKey = toProtoValue(keyType, key);
        var old = getInternal(protoKey);
        if (old == null) {
            // Nothing to remove: leave the builder untouched instead of re-setting the field.
            return null;
        }
        removeInternal(protoKey);
        return (V) toJavaValue(valueType, old);
    }

    @Override
    public void putAll(Map<? extends K, ? extends V> map) {
        map.forEach(this::put);
    }

    @Override
    public void clear() {
        builder.clearField(field);
    }

    // ---- Object

    /**
     * Compare with any {@link Map} by content, as specified by {@link Map#equals(Object)}.
     * Values are compared in their Java representation (as returned by {@link #get(Object)}).
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Map)) {
            return false;
        }
        var map = (Map<?, ?>) o;
        if (map.size() != size()) {
            return false;
        }
        for (var entry : map.entrySet()) {
            var key = entry.getKey();
            var value = entry.getValue();
            // This map never holds null keys or values, so a map that does cannot be equal.
            // Checking here also avoids the NPE that the key lookup would throw.
            if (key == null || value == null) {
                return false;
            }
            var protoValue = getInternal(toProtoValue(keyType, key));
            if (protoValue == null) {
                return false;
            }
            if (!Objects.equals(value, toJavaValue(valueType, protoValue))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hash code as specified by {@link Map#hashCode()}: the sum over all entries of
     * {@code hash(key) ^ hash(value)}. The Java-side key and value representations are used, so
     * this map hashes equal to any other {@link Map} it {@link #equals(Object) equals}.
     */
    @Override
    public int hashCode() {
        return mapEntryStream()
                .mapToInt(e -> Objects.hashCode(toJavaValue(keyType, e.getKey())) ^
                               Objects.hashCode(toJavaValue(valueType, e.getValue())))
                .sum();
    }

    @Override
    public String toString() {
        return ValueUtil.asString(this);
    }

    // ------

    /**
     * @param protoKey Key in its proto representation.
     * @return The proto value for the key, or null if the key is not present.
     */
    private Object getInternal(Object protoKey) {
        var idx = indexOfInternal(protoKey);
        if (idx >= 0) {
            return getMapEntry(idx).getValue();
        }
        return null;
    }

    /**
     * Rewrite the repeated field without the entry for the given proto key. The generic builder
     * API has no single-element removal for repeated fields, so the field is rebuilt.
     */
    private void removeInternal(Object protoKey) {
        var filteredList = mapEntryStream()
                .filter(entry -> !Objects.equals(protoKey, entry.getKey()))
                .collect(Collectors.toList());
        builder.setField(field, filteredList);
    }

    /**
     * @param protoKey Key in its proto representation.
     * @return Index of the entry with that key in the repeated field, or -1 if absent.
     */
    private int indexOfInternal(Object protoKey) {
        final int count = size();
        for (int i = 0; i < count; ++i) {
            if (Objects.equals(protoKey, getMapEntry(i).getKey())) {
                return i;
            }
        }
        return -1;
    }

    /** Sequential stream over the raw map entries, read lazily from the builder. */
    private Stream<MapEntry<Object, Object>> mapEntryStream() {
        return StreamSupport.stream(new MapEntrySpliterator(), false);
    }

    private MapEntry<Object, Object> getMapEntry(int idx) {
        return (MapEntry<Object, Object>) builder.getRepeatedField(field, idx);
    }

    private MapEntry<Object, Object> makeProtoEntry(Object key, Object value) {
        return defaultEntry.toBuilder()
                           .setKey(key)
                           .setValue(value)
                           .build();
    }

    /** Count the elements of the collection that are equal to the given value. */
    private static int occurrences(Collection<?> collection, Object value) {
        int count = 0;
        for (var item : collection) {
            if (Objects.equals(item, value)) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Live key view. Equality, hash code and string form follow the {@link Set} contract
     * via {@link AbstractSet}.
     */
    private class KeySet
            extends AbstractSet<K> {
        @Override
        public Iterator<K> iterator() {
            return new KeyIterator();
        }

        @Override
        public int size() {
            return ProtoMapBuilder.this.size();
        }

        @Override
        public void clear() {
            // Clear the field once instead of removing (and rebuilding) entry by entry.
            ProtoMapBuilder.this.clear();
        }
    }

    /**
     * Live value view. Values may repeat, so equality is a multiset comparison against any
     * {@link Collection}: same size and the same number of occurrences of every value.
     * Note that this cannot be symmetric with {@link java.util.List#equals(Object)} or
     * {@link Set#equals(Object)}, which only accept their own collection kinds.
     */
    private class ValueCollection
            extends AbstractCollection<V> {
        @Override
        public Iterator<V> iterator() {
            return new ValueIterator();
        }

        @Override
        public int size() {
            return ProtoMapBuilder.this.size();
        }

        @Override
        public void clear() {
            ProtoMapBuilder.this.clear();
        }

        // ---- Object

        /** Order-independent sum of the value hash codes, consistent with {@link #equals}. */
        @Override
        public int hashCode() {
            int hash = 0;
            for (var value : this) {
                hash += Objects.hashCode(value);
            }
            return hash;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Collection)) {
                return false;
            }
            var coll = (Collection<?>) o;
            if (coll.size() != size()) {
                return false;
            }
            // With equal sizes, matching counts for every value here means the same multiset.
            for (var value : this) {
                if (occurrences(this, value) != occurrences(coll, value)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Live entry view. Equality, hash code and string form follow the {@link Set} contract
     * via {@link AbstractSet}.
     */
    private class EntrySet
            extends AbstractSet<Entry<K, V>> {
        @Override
        public Iterator<Entry<K, V>> iterator() {
            return new EntryIterator();
        }

        @Override
        public int size() {
            return ProtoMapBuilder.this.size();
        }

        @Override
        public void clear() {
            ProtoMapBuilder.this.clear();
        }
    }

    /**
     * Index-based iterator over the repeated map field, shared by all three views. It reads
     * the builder on every step, and {@link #remove()} writes through to the builder.
     *
     * @param <T> The element type produced from each map entry.
     */
    private abstract class FieldIterator<T>
            implements Iterator<T> {
        private int    nextIndex  = 0;
        private Object currentKey = null;

        /** Convert the raw map entry at the cursor into the element returned by {@link #next()}. */
        protected abstract T convert(MapEntry<Object, Object> entry);

        @Override
        public boolean hasNext() {
            return nextIndex < size();
        }

        @Override
        public T next() {
            if (nextIndex >= size()) {
                throw new NoSuchElementException("" + nextIndex + " >= " + size());
            }
            var entry = getMapEntry(nextIndex);
            currentKey = entry.getKey();
            ++nextIndex;
            return convert(entry);
        }

        /**
         * @throws IllegalStateException If next() has not been called, or the current element
         *                               was already removed.
         */
        @Override
        public void remove() {
            if (currentKey == null) {
                throw new IllegalStateException("No current element");
            }
            removeInternal(currentKey);
            currentKey = null;
            // Later entries shifted one position down after the removal.
            --nextIndex;
        }
    }

    private class KeyIterator
            extends FieldIterator<K> {
        @Override
        protected K convert(MapEntry<Object, Object> entry) {
            return (K) toJavaValue(keyType, entry.getKey());
        }
    }

    private class ValueIterator
            extends FieldIterator<V> {
        @Override
        protected V convert(MapEntry<Object, Object> entry) {
            return (V) toJavaValue(valueType, entry.getValue());
        }
    }

    private class EntryIterator
            extends FieldIterator<Entry<K, V>> {
        @Override
        protected Entry<K, V> convert(MapEntry<Object, Object> entry) {
            return new EntryEntry(
                    (K) toJavaValue(keyType, entry.getKey()),
                    (V) toJavaValue(valueType, entry.getValue()));
        }
    }

    /** Entry snapshot whose {@link #setValue(Object)} also writes the value to the map. */
    private class EntryEntry
            extends AbstractMap.SimpleEntry<K, V> {
        private EntryEntry(K key, V value) {
            super(key, value);
        }

        @Override
        public V setValue(V value) {
            ProtoMapBuilder.this.put(getKey(), requireNonNull(value, "value == null"));
            return super.setValue(value);
        }
    }

    /** Non-splitting spliterator reading map entries from the builder one index at a time. */
    private class MapEntrySpliterator
            implements Spliterator<MapEntry<Object, Object>> {
        private int nextIndex;

        private MapEntrySpliterator() {
            this.nextIndex = 0;
        }

        @Override
        public boolean tryAdvance(Consumer<? super MapEntry<Object, Object>> consumer) {
            if (nextIndex < size()) {
                consumer.accept(getMapEntry(nextIndex));
                ++nextIndex;
                return true;
            }
            return false;
        }

        @Override
        public Spliterator<MapEntry<Object, Object>> trySplit() {
            return null;
        }

        @Override
        public long estimateSize() {
            return size() - nextIndex;
        }

        @Override
        public int characteristics() {
            return Spliterator.DISTINCT |
                   Spliterator.NONNULL |
                   Spliterator.ORDERED |
                   Spliterator.SIZED;
        }
    }
}