ProtoMessage.java

package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;

import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.utils.ReflectionUtil.getInstanceClass;
import static net.morimekta.proto.utils.ReflectionUtil.getSingleton;

/**
 * A wrapper around a proto message for emulating java conveniences.
 * <p>
 * Field values and their {@link Optional} variants are looked up lazily and
 * memoized per wrapper instance. The per-instance caches are plain
 * {@link IdentityHashMap}s without any synchronization, so a single wrapper
 * instance should not be used from several threads without external locking.
 * The static lookup helpers are backed by {@link ConcurrentHashMap}s.
 */
public class ProtoMessage
        extends ProtoMessageOrBuilder {
    /** Memoized results of {@code getInternal}, keyed by field descriptor identity. */
    private final Map<Descriptors.FieldDescriptor, Object>           valueCache;
    /** Memoized results of {@code optionalInternal}, keyed by field descriptor identity. */
    private final Map<Descriptors.FieldDescriptor, Optional<Object>> optionalCache;

    /**
     * @param message The proto message or builder.
     */
    public ProtoMessage(MessageOrBuilder message) {
        super(message);
        // Size both caches for the number of fields declared on the message type.
        var fields = getDescriptor().getFields().size();
        this.valueCache = new IdentityHashMap<>(fields);
        this.optionalCache = new IdentityHashMap<>(fields);
    }

    /**
     * NOTE(review): the constructor accepts any {@link MessageOrBuilder}, but
     * the result of the super call is cast to {@link Message} here. Presumably
     * this wrapper is only meant to hold built messages; confirm against
     * {@code ProtoMessageOrBuilder}.
     *
     * @return The wrapped instance as a {@link Message}.
     */
    @Override
    public Message getMessage() {
        return (Message) super.getMessage();
    }

    /**
     * Get the value for the field, using the cached value when there is one.
     * The value is produced by {@code getInternal} on first access. As
     * {@link Map#computeIfAbsent} does not store {@code null} results, a
     * {@code null} value is looked up again on each call.
     *
     * @param field The field to get the value for.
     * @param <T>   The value type expected by the caller. The cast is unchecked.
     * @return The field value.
     * @throws NullPointerException If the field is null.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T get(Descriptors.FieldDescriptor field) {
        requireNonNull(field, "field == null");
        return (T) valueCache.computeIfAbsent(field, this::getInternal);
    }

    /**
     * Get the optional value for the field, using the cached optional when
     * there is one. The optional is produced by {@code optionalInternal} on
     * first access.
     *
     * @param field The field to get the optional value for.
     * @param <T>   The value type expected by the caller. The cast is unchecked.
     * @return The optional field value.
     * @throws NullPointerException If the field is null.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> Optional<T> optional(Descriptors.FieldDescriptor field) {
        requireNonNull(field, "field == null");
        return (Optional<T>) optionalCache.computeIfAbsent(field, this::optionalInternal);
    }

    // --------------------------------------
    // --------        STATIC        --------
    // --------------------------------------

    /**
     * @param descriptor Message descriptor.
     * @return A new builder for the message.
     * @throws NullPointerException If the descriptor is null.
     */
    public static Message.Builder newBuilder(Descriptors.Descriptor descriptor) {
        return getDefaultInstance(descriptor).toBuilder();
    }

    /**
     * @param type A message type.
     * @return A new builder for the message.
     * @throws NullPointerException     If the type is null.
     * @throws IllegalArgumentException If the type is not a {@link Message} type.
     */
    public static Message.Builder newBuilder(Class<?> type) {
        return getDefaultInstance(type).toBuilder();
    }

    /**
     * @param descriptor Message descriptor.
     * @return The default instance for the type.
     * @throws NullPointerException If the descriptor is null.
     */
    public static Message getDefaultInstance(Descriptors.Descriptor descriptor) {
        return getDefaultInstance(getMessageClass(descriptor));
    }

    /**
     * Get the default instance for a message class. The instance is looked up
     * through the static {@code getDefaultInstance} method of the class on
     * first request, and cached after that.
     *
     * @param type A message type.
     * @return The default instance for the type.
     * @throws NullPointerException     If the type is null.
     * @throws IllegalArgumentException If the type is not a {@link Message} type.
     */
    public static Message getDefaultInstance(Class<?> type) {
        requireNonNull(type, "type == null");
        if (!Message.class.isAssignableFrom(type)) {
            throw new IllegalArgumentException("Not a typed message: " + type.getSimpleName());
        }
        // Only a single map is involved here, so computing inside the map is safe.
        return defaultInstanceMap.computeIfAbsent(
                type, t -> getSingleton(t, Message.class, "getDefaultInstance"));
    }

    /**
     * Get the message class for a descriptor. The class is resolved on first
     * request and cached; the reverse (class to descriptor) mapping is
     * registered at the same time.
     *
     * @param descriptor A message descriptor.
     * @param <T>        The message type.
     * @return The message type class.
     * @throws NullPointerException If the descriptor is null.
     */
    @SuppressWarnings("unchecked")
    public static <T extends Message> Class<T> getMessageClass(Descriptors.Descriptor descriptor) {
        requireNonNull(descriptor, "descriptor == null");
        Class<?> type = descriptorClassMap.get(descriptor);
        if (type == null) {
            // Resolve outside of any map computation. Updating classDescriptorMap
            // from within descriptorClassMap.computeIfAbsent(...) while
            // getMessageDescriptor does the opposite can make two threads wait
            // on each other's bin lock forever.
            Class<?> resolved = getInstanceClass(descriptor);
            Class<?> raced = descriptorClassMap.putIfAbsent(descriptor, resolved);
            if (raced == null) {
                // This thread registered the mapping, so it also registers the reverse.
                classDescriptorMap.put(resolved, descriptor);
                type = resolved;
            } else {
                // Another thread registered first; use its value.
                type = raced;
            }
        }
        return (Class<T>) type;
    }

    /**
     * Get the message descriptor for a message class. The descriptor is looked
     * up through the static {@code getDescriptor} method of the class on first
     * request and cached; the reverse (descriptor to class) mapping is
     * registered at the same time.
     *
     * @param type The message type class.
     * @return The message descriptor for the type.
     * @throws NullPointerException     If the type is null.
     * @throws IllegalArgumentException If the type is not a {@link Message} type.
     */
    public static Descriptors.Descriptor getMessageDescriptor(Class<?> type) {
        requireNonNull(type, "type == null");
        if (!Message.class.isAssignableFrom(type)) {
            throw new IllegalArgumentException("Not a typed message: " + type.getSimpleName());
        }
        Descriptors.Descriptor descriptor = classDescriptorMap.get(type);
        if (descriptor == null) {
            // Resolve outside of any map computation, see getMessageClass for
            // the lock ordering problem this avoids.
            Descriptors.Descriptor resolved = getSingleton(type,
                                                           Descriptors.Descriptor.class,
                                                           "getDescriptor");
            Descriptors.Descriptor raced = classDescriptorMap.putIfAbsent(type, resolved);
            if (raced == null) {
                // This thread registered the mapping, so it also registers the reverse.
                descriptorClassMap.put(resolved, type);
                descriptor = resolved;
            } else {
                // Another thread registered first; use its value.
                descriptor = raced;
            }
        }
        return descriptor;
    }

    // ---------------------------------------
    // --------        PRIVATE        --------
    // ---------------------------------------

    /** Message class to descriptor lookup, filled on demand. */
    private static final Map<Class<?>, Descriptors.Descriptor> classDescriptorMap = new ConcurrentHashMap<>();
    /** Descriptor to message class lookup, filled on demand. */
    private static final Map<Descriptors.Descriptor, Class<?>> descriptorClassMap = new ConcurrentHashMap<>();
    /** Message class to default instance lookup, filled on demand. */
    private static final Map<Class<?>, Message>                defaultInstanceMap = new ConcurrentHashMap<>();
}