ProtoMapBuilder.java
/*
* Copyright 2022 Proto Utils Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.proto;
import com.google.protobuf.Descriptors;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Message;
import net.morimekta.proto.utils.FieldUtil;
import net.morimekta.proto.utils.ValueUtil;
import java.util.AbstractCollection;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.utils.FieldUtil.getDefaultTypeValue;
import static net.morimekta.proto.utils.ValueUtil.asString;
import static net.morimekta.proto.utils.ValueUtil.toJavaValue;
import static net.morimekta.proto.utils.ValueUtil.toProtoValue;
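/**
* A mutable {@link Map} view over a map field on a message builder. Reads and
* writes go straight through to the builder, and null keys or values are
* rejected with a {@link NullPointerException}.
*
* @param <K> The map key type.
* @param <V> The map value type.
*/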
@SuppressWarnings("unchecked")
public class ProtoMapBuilder<K, V> implements Map<K, V> {
/** The message builder whose map field backs this map; every read and write goes through it. */
private transient final Message.Builder builder;
/** Descriptor of the wrapped map field on {@link #builder}. */
private transient final Descriptors.FieldDescriptor field;
/** Prototype entry for the map field; new entries are derived from it via {@code toBuilder()}. */
private transient final MapEntry<Object, Object> defaultEntry;
/** Descriptor of the entry key, used when converting keys to and from their proto form. */
private transient final Descriptors.FieldDescriptor keyType;
/** Descriptor of the entry value, used when converting values to and from their proto form. */
private transient final Descriptors.FieldDescriptor valueType;
/**
 * Create a map view over a map field of a message builder.
 *
 * @param builder The message builder containing the map field.
 * @param field The map field descriptor.
 * @throws NullPointerException If the builder or the field is null.
 * @throws IllegalArgumentException If the field is not a map field.
 */
public ProtoMapBuilder(Message.Builder builder, Descriptors.FieldDescriptor field) {
    // Fail fast: a null builder would otherwise only blow up on the first map operation.
    requireNonNull(builder, "builder == null");
    requireNonNull(field, "field == null");
    if (!field.isRepeated() || !field.isMapField()) {
        throw new IllegalArgumentException("Not a map field: " + field);
    }
    this.builder = builder;
    this.field = field;
    this.keyType = FieldUtil.getMapKeyDescriptor(field);
    this.valueType = FieldUtil.getMapValueDescriptor(field);
    // Prototype entry from which every new map entry is built.
    this.defaultEntry = MapEntry.newDefaultInstance(field.getMessageType(),
                                                    keyType.getLiteType(),
                                                    getDefaultTypeValue(keyType),
                                                    valueType.getLiteType(),
                                                    getDefaultTypeValue(valueType));
}
/**
 * @return The number of entries the builder reports for the wrapped map field.
 */
@Override
public int size() {
    final int entryCount = builder.getRepeatedFieldCount(field);
    return entryCount;
}
/**
 * @return True when {@link #size()} reports zero entries.
 */
@Override
public boolean isEmpty() {
    final int entryCount = size();
    return entryCount == 0;
}
/**
 * Check whether an entry exists for the given key.
 *
 * @param key The key to look for; must not be null.
 * @return True if an entry with an equal (proto-converted) key is present.
 * @throws NullPointerException If the key is null.
 */
@Override
public boolean containsKey(Object key) {
    requireNonNull(key, "key == null");
    var protoKey = toProtoValue(keyType, key);
    final int position = indexOfInternal(protoKey);
    return position >= 0;
}
/**
 * Check whether any entry holds the given value.
 *
 * @param value The value to look for; must not be null.
 * @return True if at least one entry has an equal (proto-converted) value.
 * @throws NullPointerException If the value is null.
 */
@Override
public boolean containsValue(Object value) {
    requireNonNull(value, "value == null");
    var wanted = toProtoValue(valueType, value);
    // Walk the entries in field order and stop at the first match.
    for (int position = 0; position < size(); ++position) {
        if (Objects.equals(wanted, getMapEntry(position).getValue())) {
            return true;
        }
    }
    return false;
}
/**
 * Look up the value stored for a key.
 *
 * @param key The key to look up; must not be null.
 * @return The stored value converted with {@code toJavaValue}. When no entry
 *         matches, {@code toJavaValue} is handed {@code null} - presumably
 *         yielding {@code null}; confirm against {@code ValueUtil}.
 * @throws NullPointerException If the key is null.
 */
@Override
public V get(Object key) {
    requireNonNull(key, "key == null");
    var protoKey = toProtoValue(keyType, key);
    var protoValue = getInternal(protoKey);
    return (V) toJavaValue(valueType, protoValue);
}
/**
 * @return A new key view backed by this map; it reads the builder on each access.
 */
@Override
public Set<K> keySet() {
    final Set<K> keyView = new KeySet();
    return keyView;
}
/**
 * @return A new value view backed by this map; it reads the builder on each access.
 */
@Override
public Collection<V> values() {
    final Collection<V> valueView = new ValueCollection();
    return valueView;
}
/**
 * @return A new entry view backed by this map; it reads the builder on each access.
 */
@Override
public Set<Entry<K, V>> entrySet() {
    final Set<Entry<K, V>> entryView = new EntrySet();
    return entryView;
}
// ---- Value Write
/**
 * Store a value for a key, replacing the existing entry in place when the
 * key is already present and appending a new entry otherwise.
 *
 * @param key The key; must not be null.
 * @param value The value; must not be null.
 * @return The previously stored value, or null if the key was not present.
 * @throws NullPointerException If the key or the value is null.
 */
@Override
public V put(K key, V value) {
    requireNonNull(key, "key == null");
    requireNonNull(value, "value == null");
    var protoKey = toProtoValue(keyType, key);
    var protoValue = toProtoValue(valueType, value);
    final int position = indexOfInternal(protoKey);
    if (position < 0) {
        // Unknown key: append a fresh entry to the field.
        builder.addRepeatedField(field, makeProtoEntry(protoKey, protoValue));
        return null;
    }
    // Known key: remember the old entry, then overwrite it at the same index.
    var previous = getMapEntry(position);
    builder.setRepeatedField(field, position, makeProtoEntry(protoKey, protoValue));
    return (V) toJavaValue(valueType, previous.getValue());
}
/**
 * Remove the entry stored for a key.
 *
 * @param key The key to remove; must not be null.
 * @return The removed value converted with {@code toJavaValue}. When the key
 *         was absent, {@code toJavaValue} is handed {@code null} - presumably
 *         yielding {@code null}; confirm against {@code ValueUtil}.
 * @throws NullPointerException If the key is null.
 */
@Override
public V remove(Object key) {
    requireNonNull(key, "key == null");
    var protoKey = toProtoValue(keyType, key);
    var idx = indexOfInternal(protoKey);
    Object old = null;
    if (idx >= 0) {
        old = getMapEntry(idx).getValue();
        // Only rewrite the field when there actually is something to drop;
        // a miss previously rebuilt and re-set the whole entry list for nothing.
        removeInternal(protoKey);
    }
    return (V) toJavaValue(valueType, old);
}
/**
 * Copy every entry of the given map into this map via {@link #put(Object, Object)}.
 *
 * @param map The entries to add; null keys or values are rejected by {@code put}.
 */
@Override
public void putAll(Map<? extends K, ? extends V> map) {
    for (var entry : map.entrySet()) {
        put(entry.getKey(), entry.getValue());
    }
}
/**
 * Remove all entries by clearing the wrapped field on the builder.
 */
@Override
public void clear() {
    this.builder.clearField(this.field);
}
// ---- Object
/**
 * Compare this map with another object by content.
 *
 * @param o The object to compare with.
 * @return True if {@code o} is a {@link Map} of the same size whose every
 *         entry has an equal counterpart in this map.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (!(o instanceof Map)) {
        return false;
    }
    var map = (Map<?, ?>) o;
    if (map.size() != size()) {
        return false;
    }
    try {
        for (var entry : map.entrySet()) {
            var otherKey = entry.getKey();
            if (otherKey == null) {
                // This map rejects null keys, so it can never match one;
                // answer false instead of tripping the null check in the lookup.
                return false;
            }
            // Single scan per entry: locate the key once and read the value there.
            var idx = indexOfInternal(toProtoValue(keyType, otherKey));
            if (idx < 0) {
                return false;
            }
            var mine = toJavaValue(valueType, getMapEntry(idx).getValue());
            if (!Objects.equals(entry.getValue(), mine)) {
                return false;
            }
        }
    } catch (ClassCastException ignored) {
        // A key of an incompatible type cannot be present; like
        // java.util.AbstractMap#equals, report inequality instead of failing.
        return false;
    }
    return true;
}
/**
 * Hash code following the {@link Map#hashCode()} contract: the sum over all
 * entries of {@code hash(key) ^ hash(value)}, computed on the converted
 * (java) keys and values. This keeps the hash consistent with
 * {@link #equals(Object)}, which considers any map with equal content equal.
 *
 * @return The content based hash code.
 */
@Override
public int hashCode() {
    int hash = 0;
    for (int i = 0; i < size(); ++i) {
        var entry = getMapEntry(i);
        var key = toJavaValue(keyType, entry.getKey());
        var value = toJavaValue(valueType, entry.getValue());
        hash += Objects.hashCode(key) ^ Objects.hashCode(value);
    }
    return hash;
}
/**
 * @return The string form produced by {@code ValueUtil.asString} for this map.
 */
@Override
public String toString() {
    final String rendered = ValueUtil.asString(this);
    return rendered;
}
// ------
/**
 * Find the proto value stored for a proto key.
 *
 * @param protoKey The key in its proto form.
 * @return The proto value of the first matching entry, or null if none matches.
 */
private Object getInternal(Object protoKey) {
    final int position = indexOfInternal(protoKey);
    if (position < 0) {
        return null;
    }
    return getMapEntry(position).getValue();
}
/**
 * Drop every entry whose key equals the given proto key by rewriting the
 * whole field with the remaining entries.
 *
 * @param protoKey The key in its proto form.
 */
private void removeInternal(Object protoKey) {
    Stream<MapEntry<Object, Object>> survivors = mapEntryStream().filter(candidate -> {
        final boolean sameKey = Objects.equals(protoKey, candidate.getKey());
        return !sameKey;
    });
    builder.setField(field, survivors.collect(Collectors.toList()));
}
/**
 * Locate the first entry with the given proto key.
 *
 * @param protoKey The key in its proto form.
 * @return The index of the first matching entry in the field, or -1 if absent.
 */
private int indexOfInternal(Object protoKey) {
    int position = 0;
    while (position < size()) {
        if (Objects.equals(protoKey, getMapEntry(position).getKey())) {
            return position;
        }
        ++position;
    }
    return -1;
}
/**
 * @return A sequential stream over the field's map entries, in field order.
 */
private Stream<MapEntry<Object, Object>> mapEntryStream() {
    final Spliterator<MapEntry<Object, Object>> source = new MapEntrySpliterator();
    return StreamSupport.stream(source, false);
}
/**
 * Read the map entry at an index of the wrapped field.
 *
 * @param idx The index into the field's entries.
 * @return The element the builder holds at that index, cast to a map entry.
 */
private MapEntry<Object, Object> getMapEntry(int idx) {
    final Object raw = builder.getRepeatedField(field, idx);
    return (MapEntry<Object, Object>) raw;
}
/**
 * Build a map entry for the wrapped field from the default entry prototype.
 *
 * @param key The key in its proto form.
 * @param value The value in its proto form.
 * @return The newly built entry.
 */
private MapEntry<Object, Object> makeProtoEntry(Object key, Object value) {
    var entryBuilder = defaultEntry.toBuilder();
    entryBuilder.setKey(key);
    entryBuilder.setValue(value);
    return entryBuilder.build();
}
/**
 * Key view of the enclosing map. Iteration and size read straight from the
 * enclosing map, so the view always reflects the builder's current content.
 */
private class KeySet
        extends AbstractSet<K> {
    @Override
    public Iterator<K> iterator() {
        return new KeyIterator();
    }

    @Override
    public int size() {
        return ProtoMapBuilder.this.size();
    }

    // ---- Object

    /**
     * Sum of the key hash codes, as required by {@link Set#hashCode()}, so
     * that this view hashes like any other set with the same elements.
     */
    @Override
    public int hashCode() {
        int hash = 0;
        for (var key : this) {
            hash += Objects.hashCode(key);
        }
        return hash;
    }

    /**
     * Set equality: only another {@link Set} of the same size containing all
     * keys is equal. Accepting arbitrary collections (e.g. lists) would break
     * the symmetry required of {@code equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Set)) {
            return false;
        }
        var other = (Set<?>) o;
        if (other.size() != size()) {
            return false;
        }
        try {
            for (var key : this) {
                if (!other.contains(key)) {
                    return false;
                }
            }
        } catch (ClassCastException ignored) {
            // The other set cannot hold keys of this type; like
            // java.util.AbstractSet#equals, treat that as "not equal".
            return false;
        }
        return true;
    }

    /**
     * Render as {@code [a, b, c]}. A flag (not the buffer length) decides on
     * the separator, so keys rendering as an empty string keep their comma.
     */
    @Override
    public String toString() {
        var out = new StringBuilder("[");
        boolean first = true;
        for (var key : this) {
            if (!first) {
                out.append(", ");
            }
            first = false;
            out.append(key);
        }
        return out.append("]").toString();
    }
}
/**
 * Iterator over the keys of the enclosing map, walking the field's entries by
 * index and converting each key with {@code toJavaValue}.
 */
private class KeyIterator
        implements Iterator<K> {
    /** Index of the entry the next call to {@link #next()} will read. */
    private int cursor = 0;
    /** Proto key of the last returned entry; null when there is nothing to remove. */
    private Object lastKey = null;

    @Override
    public boolean hasNext() {
        return cursor < size();
    }

    @Override
    public K next() {
        if (cursor >= size()) {
            throw new NoSuchElementException(cursor + " >= " + size());
        }
        lastKey = getMapEntry(cursor).getKey();
        ++cursor;
        return (K) toJavaValue(keyType, lastKey);
    }

    @Override
    public void remove() {
        if (lastKey == null) {
            throw new IllegalStateException("No current element");
        }
        removeInternal(lastKey);
        lastKey = null;
        // Step back one index to account for the removed entry.
        --cursor;
    }
}
/**
 * Value view of the enclosing map. Iteration and size read straight from the
 * enclosing map, so the view always reflects the builder's current content.
 */
private class ValueCollection
        extends AbstractCollection<V> {
    @Override
    public Iterator<V> iterator() {
        return new ValueIterator();
    }

    @Override
    public int size() {
        return ProtoMapBuilder.this.size();
    }

    // ---- Object

    /** Delegates to the enclosing map's hash code. */
    @Override
    public int hashCode() {
        return ProtoMapBuilder.this.hashCode();
    }

    /**
     * Equal to any collection of the same size that contains every value of
     * this view. Duplicate counts are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Collection)) {
            return false;
        }
        var other = (Collection<?>) o;
        if (other.size() != size()) {
            return false;
        }
        for (var value : this) {
            if (!other.contains(value)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Render as {@code [a, b, c]}. A flag (not the buffer length) decides on
     * the separator, so values rendering as an empty string - e.g. an empty
     * string value in first position - no longer swallow the following comma.
     */
    @Override
    public String toString() {
        var out = new StringBuilder("[");
        boolean first = true;
        for (var value : this) {
            if (!first) {
                out.append(", ");
            }
            first = false;
            out.append(value);
        }
        return out.append("]").toString();
    }
}
/**
 * Iterator over the values of the enclosing map, walking the field's entries
 * by index and converting each value with {@code toJavaValue}.
 */
private class ValueIterator
        implements Iterator<V> {
    /** Index of the entry the next call to {@link #next()} will read. */
    private int nextIndex = 0;
    /** Proto key of the last returned entry; null when there is nothing to remove. */
    private Object currentKey = null;

    @Override
    public boolean hasNext() {
        return nextIndex < size();
    }

    @Override
    public V next() {
        if (nextIndex >= size()) {
            throw new NoSuchElementException("" + nextIndex + " >= " + size());
        }
        var entry = getMapEntry(nextIndex);
        currentKey = entry.getKey();
        ++nextIndex;
        return (V) toJavaValue(valueType, entry.getValue());
    }

    /**
     * Remove the entry of the last returned value.
     *
     * @throws IllegalStateException If {@link #next()} has not been called
     *         yet, or the current element was already removed - the exception
     *         type {@link Iterator#remove()} specifies.
     */
    @Override
    public void remove() {
        if (currentKey == null) {
            throw new IllegalStateException("No current element");
        }
        removeInternal(currentKey);
        currentKey = null;
        // Step back one index to account for the removed entry.
        --nextIndex;
    }
}
/**
 * Entry view of the enclosing map. Iteration and size read straight from the
 * enclosing map, so the view always reflects the builder's current content.
 */
private class EntrySet
        extends AbstractSet<Entry<K, V>> {
    @Override
    public Iterator<Entry<K, V>> iterator() {
        return new EntryIterator(ProtoMapBuilder.this);
    }

    @Override
    public int size() {
        return ProtoMapBuilder.this.size();
    }

    // ---- Object

    /**
     * Sum of the entry hash codes, as required by {@link Set#hashCode()}, so
     * that this view hashes like any other set with the same entries.
     */
    @Override
    public int hashCode() {
        int hash = 0;
        for (var entry : this) {
            hash += entry.hashCode();
        }
        return hash;
    }

    /**
     * Set equality: only another {@link Set} of the same size containing all
     * entries is equal. Accepting arbitrary collections (e.g. lists) would
     * break the symmetry required of {@code equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Set)) {
            return false;
        }
        var other = (Set<?>) o;
        if (other.size() != size()) {
            return false;
        }
        try {
            for (var entry : this) {
                if (!other.contains(entry)) {
                    return false;
                }
            }
        } catch (ClassCastException ignored) {
            // The other set cannot hold map entries; like
            // java.util.AbstractSet#equals, treat that as "not equal".
            return false;
        }
        return true;
    }

    /** Render as {@code [k1=v1, k2=v2]}. */
    @Override
    public String toString() {
        var out = new StringBuilder("[");
        boolean first = true;
        for (var entry : this) {
            if (!first) {
                out.append(", ");
            }
            first = false;
            out.append(entry);
        }
        return out.append("]").toString();
    }
}
/**
 * Iterator over the entries of a {@link ProtoMapBuilder}, walking the field's
 * entries by index. Returned entries write through to the map on
 * {@code setValue}.
 *
 * @param <K> The map key type.
 * @param <V> The map value type.
 */
private static class EntryIterator<K, V>
        implements Iterator<Entry<K, V>> {
    /** The map being iterated; never reassigned. */
    private final ProtoMapBuilder<K, V> map;
    /** Index of the entry the next call to {@link #next()} will read. */
    private int nextIndex = 0;
    /** Proto key of the last returned entry; null when there is nothing to remove. */
    private Object currentKey = null;

    private EntryIterator(ProtoMapBuilder<K, V> map) {
        this.map = map;
    }

    @Override
    public boolean hasNext() {
        return nextIndex < map.size();
    }

    @Override
    public Entry<K, V> next() {
        if (nextIndex >= map.size()) {
            throw new NoSuchElementException("" + nextIndex + " >= " + map.size());
        }
        var entry = map.getMapEntry(nextIndex);
        currentKey = entry.getKey();
        ++nextIndex;
        return new EntryEntry(
                (K) toJavaValue(map.keyType, entry.getKey()),
                (V) toJavaValue(map.valueType, entry.getValue()));
    }

    /**
     * Remove the last returned entry from the map.
     *
     * @throws IllegalStateException If {@link #next()} has not been called
     *         yet, or the current element was already removed - the exception
     *         type {@link Iterator#remove()} specifies.
     */
    @Override
    public void remove() {
        if (currentKey == null) {
            throw new IllegalStateException("No current element");
        }
        map.removeInternal(currentKey);
        currentKey = null;
        // Step back one index to account for the removed entry.
        --nextIndex;
    }

    /**
     * Entry snapshot whose {@code setValue} also stores the new value in the
     * iterated map.
     */
    private class EntryEntry
            extends AbstractMap.SimpleEntry<K, V> {
        public EntryEntry(K key, V value) {
            super(key, value);
        }

        /**
         * Write the value through to the map, then update this entry.
         *
         * @return The value this entry held before the call.
         * @throws NullPointerException If the value is null.
         */
        @Override
        public V setValue(V value) {
            map.put(getKey(), requireNonNull(value, "value == null"));
            return super.setValue(value);
        }
    }
}
/**
 * Non-splitting spliterator over the map entries of the wrapped field,
 * visiting them by index in field order.
 */
private class MapEntrySpliterator
        implements Spliterator<MapEntry<Object, Object>> {
    /** Index of the next entry to hand to a consumer. */
    private int cursor;

    private MapEntrySpliterator() {
        cursor = 0;
    }

    @Override
    public boolean tryAdvance(Consumer<? super MapEntry<Object, Object>> consumer) {
        if (cursor >= size()) {
            return false;
        }
        consumer.accept(getMapEntry(cursor));
        ++cursor;
        return true;
    }

    /** Splitting is not supported; always returns null. */
    @Override
    public Spliterator<MapEntry<Object, Object>> trySplit() {
        return null;
    }

    /** Number of entries not yet visited, based on the current size. */
    @Override
    public long estimateSize() {
        final int remaining = size() - cursor;
        return remaining;
    }

    @Override
    public int characteristics() {
        final int flags = Spliterator.ORDERED
                          | Spliterator.SIZED
                          | Spliterator.DISTINCT
                          | Spliterator.NONNULL;
        return flags;
    }
}
}