ProtoMessage.java

/*
 * Copyright 2022 Proto Utils Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;

import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import static java.util.Objects.requireNonNull;
import static net.morimekta.proto.utils.ReflectionUtil.getInstanceClass;
import static net.morimekta.proto.utils.ReflectionUtil.getSingleton;

/**
 * A wrapper around a proto message for emulating java conveniences.
 * <p>
 * Field reads through {@link #get(Descriptors.FieldDescriptor)} and
 * {@link #optional(Descriptors.FieldDescriptor)} are memoized per instance
 * in identity-keyed maps, so a repeated read of the same field descriptor
 * instance returns the previously computed result. The per-instance caches
 * are plain {@link IdentityHashMap}s with no synchronization of their own.
 * <p>
 * The static helpers translate between generated message classes, their
 * descriptors and their default instances, and keep the results in
 * class-wide concurrent maps.
 */
public class ProtoMessage
        extends ProtoMessageOrBuilder {
    /** Memoized results of {@link #get(Descriptors.FieldDescriptor)}, keyed by descriptor identity. */
    private final Map<Descriptors.FieldDescriptor, Object>           valueCache;
    /** Memoized results of {@link #optional(Descriptors.FieldDescriptor)}, keyed by descriptor identity. */
    private final Map<Descriptors.FieldDescriptor, Optional<Object>> optionalCache;

    /**
     * Wrap a message and set up the per-field caches, sized from the number
     * of fields declared on the message descriptor.
     * <p>
     * NOTE(review): the parameter type admits a builder, but
     * {@link #getMessage()} casts the wrapped value to {@link Message};
     * presumably only messages are meant to be wrapped here - confirm
     * against callers.
     *
     * @param message The proto message or builder.
     */
    public ProtoMessage(MessageOrBuilder message) {
        super(message);
        var fields = getDescriptor().getFields().size();
        this.valueCache = new IdentityHashMap<>(fields);
        this.optionalCache = new IdentityHashMap<>(fields);
    }

    /**
     * @return The value returned by the superclass' {@code getMessage()},
     *         cast to {@link Message}.
     * @throws ClassCastException If that value is not a {@link Message}.
     */
    @Override
    public Message getMessage() {
        return (Message) super.getMessage();
    }

    /**
     * Get the value for a field, computing it via {@code getInternal} on the
     * first request and serving it from the instance cache afterwards. A
     * {@code null} result is returned as is and is not cached.
     *
     * @param field The field descriptor to read.
     * @param <T>   The type the caller expects; the cast is unchecked.
     * @return The field value.
     * @throws NullPointerException If the field is null.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T get(Descriptors.FieldDescriptor field) {
        requireNonNull(field, "field == null");
        return (T) valueCache.computeIfAbsent(field, this::getInternal);
    }

    /**
     * Get the optional value for a field, computing it via
     * {@code optionalInternal} on the first request and serving it from the
     * instance cache afterwards.
     *
     * @param field The field descriptor to read.
     * @param <T>   The type the caller expects; the cast is unchecked.
     * @return The optional field value.
     * @throws NullPointerException If the field is null.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> Optional<T> optional(Descriptors.FieldDescriptor field) {
        requireNonNull(field, "field == null");
        return (Optional<T>) optionalCache.computeIfAbsent(field, this::optionalInternal);
    }

    // --------------------------------------
    // --------        STATIC        --------
    // --------------------------------------

    /**
     * Create a builder from the default instance of the described message.
     *
     * @param descriptor Message descriptor.
     * @return A new builder for the message.
     */
    public static Message.Builder newBuilder(Descriptors.Descriptor descriptor) {
        return getDefaultInstance(descriptor).toBuilder();
    }

    /**
     * Create a builder from the default instance of the message type.
     *
     * @param type A message type.
     * @return A new builder for the message.
     * @throws IllegalArgumentException If the type is not a {@link Message} class.
     */
    public static Message.Builder newBuilder(Class<?> type) {
        return getDefaultInstance(type).toBuilder();
    }

    /**
     * Resolve the message class for the descriptor and return the default
     * instance of that class.
     *
     * @param descriptor Message descriptor.
     * @return The default instance for the type.
     */
    public static Message getDefaultInstance(Descriptors.Descriptor descriptor) {
        return getDefaultInstance(getMessageClass(descriptor));
    }

    /**
     * Get the default instance of a message class. The instance is looked up
     * through the class' {@code getDefaultInstance} singleton accessor the
     * first time and then kept in a class-wide cache.
     *
     * @param type A message type.
     * @return The default instance for the type.
     * @throws NullPointerException     If the type is null.
     * @throws IllegalArgumentException If the type is not a {@link Message} class.
     */
    public static Message getDefaultInstance(Class<?> type) {
        requireNonNull(type, "type == null");
        if (Message.class.isAssignableFrom(type)) {
            // Only touches defaultInstanceMap, so computing under the map's
            // bin lock cannot cross-lock with the other caches.
            return defaultInstanceMap.computeIfAbsent(
                    type, t -> getSingleton(type, Message.class, "getDefaultInstance"));
        }
        throw new IllegalArgumentException("Not a typed message: " + type.getSimpleName());
    }

    /**
     * Get the message class for a descriptor. On a cache miss the class is
     * resolved with {@code getInstanceClass} and registered in both the
     * descriptor-to-class and the class-to-descriptor cache.
     *
     * @param descriptor A message descriptor.
     * @param <T>        The message type.
     * @return The message type class.
     * @throws NullPointerException If the descriptor is null.
     */
    @SuppressWarnings("unchecked")
    public static <T extends Message> Class<T> getMessageClass(Descriptors.Descriptor descriptor) {
        requireNonNull(descriptor, "descriptor == null");
        Class<?> type = descriptorClassMap.get(descriptor);
        if (type == null) {
            // Resolve and register OUTSIDE of computeIfAbsent: updating
            // classDescriptorMap while holding a descriptorClassMap bin lock
            // can deadlock against getMessageDescriptor(), which takes the
            // same two locks in the opposite order. The lookup may run more
            // than once under a race; putIfAbsent keeps the first published
            // mapping.
            Class<?> resolved = getInstanceClass(descriptor);
            classDescriptorMap.put(resolved, descriptor);
            Class<?> raced = descriptorClassMap.putIfAbsent(descriptor, resolved);
            type = raced != null ? raced : resolved;
        }
        return (Class<T>) type;
    }

    /**
     * Get the descriptor for a message class. On a cache miss the descriptor
     * is looked up through the class' {@code getDescriptor} singleton
     * accessor and registered in both the class-to-descriptor and the
     * descriptor-to-class cache.
     *
     * @param type The message type class.
     * @return The message descriptor for the type.
     * @throws NullPointerException     If the type is null.
     * @throws IllegalArgumentException If the type is not a {@link Message} class.
     */
    public static Descriptors.Descriptor getMessageDescriptor(Class<?> type) {
        requireNonNull(type, "type == null");
        if (Message.class.isAssignableFrom(type)) {
            Descriptors.Descriptor descriptor = classDescriptorMap.get(type);
            if (descriptor == null) {
                // Resolve and register OUTSIDE of computeIfAbsent: updating
                // descriptorClassMap while holding a classDescriptorMap bin
                // lock can deadlock against getMessageClass(), which takes
                // the same two locks in the opposite order.
                Descriptors.Descriptor resolved = getSingleton(type,
                                                               Descriptors.Descriptor.class,
                                                               "getDescriptor");
                descriptorClassMap.put(resolved, type);
                Descriptors.Descriptor raced = classDescriptorMap.putIfAbsent(type, resolved);
                descriptor = raced != null ? raced : resolved;
            }
            return descriptor;
        }
        throw new IllegalArgumentException("Not a typed message: " + type.getSimpleName());
    }

    // ---------------------------------------
    // --------        PRIVATE        --------
    // ---------------------------------------

    /** Message class to descriptor; filled by both lookup directions. */
    private static final Map<Class<?>, Descriptors.Descriptor> classDescriptorMap = new ConcurrentHashMap<>();
    /** Descriptor to message class; filled by both lookup directions. */
    private static final Map<Descriptors.Descriptor, Class<?>> descriptorClassMap = new ConcurrentHashMap<>();
    /** Message class to its default instance. */
    private static final Map<Class<?>, Message>                defaultInstanceMap = new ConcurrentHashMap<>();
}