ProtoMap.java

package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.MapEntry;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.proto.utils.FieldUtil;
import net.morimekta.proto.utils.ValueUtil;

import java.util.AbstractCollection;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;

import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.utils.ValueUtil.toJavaValue;
import static net.morimekta.proto.utils.ValueUtil.toProtoValue;

/**
 * A map wrapping a proto message map field.
 *
 * @param <K> The map key type.
 * @param <V> The map value type.
 */
/**
 * A read-only {@link Map} view over a proto message map field.
 * <p>
 * Entries are read from the wrapped message on every call; nothing is copied
 * up front. Keys and values are converted between their proto and java
 * representations with {@link ValueUtil#toJavaValue} and
 * {@link ValueUtil#toProtoValue}. Lookups by key or by value scan the
 * underlying repeated field linearly. All mutating methods throw
 * {@link UnsupportedOperationException}.
 * <p>
 * {@link #equals(Object)} and {@link #hashCode()} follow the {@link Map}
 * contract, so a {@code ProtoMap} is equal to (and hashes like) any other map
 * holding the same mappings.
 *
 * @param <K> The map key type.
 * @param <V> The map value type.
 */
public class ProtoMap<K, V> implements Map<K, V> {
    private final transient MessageOrBuilder            message;
    private final transient Descriptors.FieldDescriptor field;
    private final transient Descriptors.FieldDescriptor keyType;
    private final transient Descriptors.FieldDescriptor valueType;

    /**
     * Create a map view over a map field of a message.
     *
     * @param message The message containing the map.
     * @param field   The map field descriptor.
     * @throws NullPointerException     If message or field is null.
     * @throws IllegalArgumentException If field is not a map field.
     */
    public ProtoMap(MessageOrBuilder message, Descriptors.FieldDescriptor field) {
        requireNonNull(message, "message == null");
        requireNonNull(field, "field == null");
        if (!field.isMapField()) {
            throw new IllegalArgumentException("Not a map field: " + field);
        }
        this.message = message;
        this.field = field;
        this.keyType = FieldUtil.getMapKeyDescriptor(field);
        this.valueType = FieldUtil.getMapValueDescriptor(field);
    }

    // ---- Map

    @Override
    public int size() {
        return message.getRepeatedFieldCount(field);
    }

    @Override
    public boolean isEmpty() {
        return size() == 0;
    }

    @Override
    public boolean containsKey(Object o) {
        return indexOfProto(toProtoValue(keyType, o)) >= 0;
    }

    @Override
    public boolean containsValue(Object o) {
        var protoValue = toProtoValue(valueType, o);
        for (int i = 0; i < size(); ++i) {
            if (Objects.equals(getMapEntry(i).getValue(), protoValue)) {
                return true;
            }
        }
        return false;
    }

    @Override
    @SuppressWarnings("unchecked")
    public V get(Object key) {
        var idx = indexOfProto(toProtoValue(keyType, key));
        if (idx >= 0) {
            // The value must be converted with the *value* descriptor; the key
            // descriptor gives wrong results when key and value types differ.
            return (V) toJavaValue(valueType, getMapEntry(idx).getValue());
        }
        return null;
    }

    @Override
    public Set<K> keySet() {
        return new KeySet();
    }

    @Override
    public Collection<V> values() {
        return new ValueCollection();
    }

    @Override
    public Set<Entry<K, V>> entrySet() {
        return new EntrySet();
    }

    // ---- Unsupported

    @Override
    public V put(K k, V v) {
        throw new UnsupportedOperationException("Unmodifiable map");
    }

    @Override
    public V remove(Object o) {
        throw new UnsupportedOperationException("Unmodifiable map");
    }

    @Override
    public void putAll(Map<? extends K, ? extends V> map) {
        throw new UnsupportedOperationException("Unmodifiable map");
    }

    @Override
    public void clear() {
        throw new UnsupportedOperationException("Unmodifiable map");
    }

    // ---- Object

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Map)) {
            return false;
        }
        var map = (Map<?, ?>) o;
        if (map.size() != size()) {
            return false;
        }
        for (var entry : entrySet()) {
            var key = entry.getKey();
            if (!map.containsKey(key)) {
                return false;
            }
            if (!Objects.equals(entry.getValue(), map.get(key))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Content-based hash code as specified by {@link Map#hashCode()}: the sum
     * of the hash codes of all entries. This must agree with
     * {@link #equals(Object)}, which compares mappings with any other map.
     *
     * @return The map hash code.
     */
    @Override
    public int hashCode() {
        int hash = 0;
        for (var entry : entrySet()) {
            hash += entry.hashCode();
        }
        return hash;
    }

    @Override
    public String toString() {
        return ValueUtil.asString(this);
    }

    // ------

    /**
     * Find the index of the map entry whose proto key equals the given key.
     *
     * @param protoKey The key in its proto representation.
     * @return The entry index, or -1 if not present.
     */
    private int indexOfProto(Object protoKey) {
        for (int i = 0; i < size(); ++i) {
            var entry = getMapEntry(i);
            if (Objects.equals(protoKey, entry.getKey())) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Read the raw map entry at the given index of the repeated map field.
     *
     * @param idx The entry index.
     * @return The proto map entry.
     */
    @SuppressWarnings("unchecked")
    private MapEntry<Object, Object> getMapEntry(int idx) {
        return (MapEntry<Object, Object>) message.getRepeatedField(field, idx);
    }

    /**
     * Live key view. {@code equals}, {@code hashCode} and {@code toString}
     * come from {@link AbstractSet} / {@link AbstractCollection}, which
     * implement the {@link Set} contract.
     */
    private class KeySet
            extends AbstractSet<K> {
        @Override
        public Iterator<K> iterator() {
            return new KeyIterator();
        }

        @Override
        public int size() {
            return ProtoMap.this.size();
        }
    }

    /**
     * Live value view. Values may repeat, so equality with another collection
     * is checked as multiset equality (same elements with same counts,
     * regardless of order).
     */
    private class ValueCollection
            extends AbstractCollection<V> {
        @Override
        public Iterator<V> iterator() {
            return new ValueIterator();
        }

        @Override
        public int size() {
            return ProtoMap.this.size();
        }

        // ---- Object

        /**
         * Sum of element hash codes: order-independent and consistent with the
         * multiset comparison in {@link #equals(Object)}.
         */
        @Override
        public int hashCode() {
            int hash = 0;
            for (var value : this) {
                hash += Objects.hashCode(value);
            }
            return hash;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Collection)) {
                return false;
            }
            var coll = (Collection<?>) o;
            if (coll.size() != size()) {
                return false;
            }
            // Consume one matching element per value so duplicates are
            // counted; plain contains() would treat [1, 1, 2] == [1, 2, 2].
            var remaining = new ArrayList<Object>(coll);
            for (var value : this) {
                if (!remaining.remove(value)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Live entry view. {@code equals}, {@code hashCode} and {@code toString}
     * come from {@link AbstractSet} / {@link AbstractCollection}, which
     * implement the {@link Set} contract.
     */
    private class EntrySet
            extends AbstractSet<Entry<K, V>> {
        @Override
        public Iterator<Entry<K, V>> iterator() {
            return new EntryIterator();
        }

        @Override
        public int size() {
            return ProtoMap.this.size();
        }
    }

    /**
     * Iterator walking the map field entries by index, converting each raw
     * proto entry into the element type of the owning view.
     *
     * @param <T> The element type produced by the iterator.
     */
    private abstract class IndexedEntryIterator<T>
            implements Iterator<T> {
        private int nextIndex = 0;

        @Override
        public boolean hasNext() {
            return nextIndex < size();
        }

        @Override
        public T next() {
            if (nextIndex >= size()) {
                throw new NoSuchElementException("" + nextIndex + " >= " + size());
            }
            return convert(getMapEntry(nextIndex++));
        }

        /** Convert a raw proto map entry to the element returned by {@link #next()}. */
        protected abstract T convert(MapEntry<Object, Object> entry);
    }

    private class KeyIterator
            extends IndexedEntryIterator<K> {
        @Override
        @SuppressWarnings("unchecked")
        protected K convert(MapEntry<Object, Object> entry) {
            return (K) toJavaValue(keyType, entry.getKey());
        }
    }

    private class ValueIterator
            extends IndexedEntryIterator<V> {
        @Override
        @SuppressWarnings("unchecked")
        protected V convert(MapEntry<Object, Object> entry) {
            return (V) toJavaValue(valueType, entry.getValue());
        }
    }

    private class EntryIterator
            extends IndexedEntryIterator<Entry<K, V>> {
        @Override
        @SuppressWarnings("unchecked")
        protected Entry<K, V> convert(MapEntry<Object, Object> entry) {
            return new AbstractMap.SimpleImmutableEntry<>(
                    (K) toJavaValue(keyType, entry.getKey()),
                    (V) toJavaValue(valueType, entry.getValue()));
        }
    }
}