ProtoEnum.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.ProtocolMessageEnum;

import java.lang.reflect.Type;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;
import static net.morimekta.collect.UnmodifiableList.toList;
import static net.morimekta.collect.UnmodifiableMap.toMap;
import static net.morimekta.collect.util.LazyCachedSupplier.lazyCache;
import static net.morimekta.proto.utils.ReflectionUtil.getInstanceClass;
import static net.morimekta.proto.utils.ReflectionUtil.getSingleton;
import static net.morimekta.strings.EscapeUtil.javaEscape;

/**
 * The definition of a serializable enum.
 * <p>
 * Wraps a generated java proto enum class together with its
 * {@link Descriptors.EnumDescriptor}, and provides lazily built lookup
 * tables from number, name and value descriptor to enum constant. Instances
 * created through the static {@code getEnumDescriptor} methods are cached
 * per enum class.
 *
 * @param <E> The java enum type.
 */
public class ProtoEnum<E extends Enum<E> & ProtocolMessageEnum> implements Type {
    /**
     * The name of the UNRECOGNIZED enum entry.
     */
    public static final String UNRECOGNIZED = "UNRECOGNIZED";

    private final Class<E>                                          enumClass;
    private final Descriptors.EnumDescriptor                        protoDescriptor;
    private final Supplier<List<E>>                                 values;
    private final Supplier<Map<String, E>>                          nameMap;
    private final Supplier<Map<Integer, E>>                         idMap;
    private final Supplier<Map<Descriptors.EnumValueDescriptor, E>> valueMap;
    private final Supplier<Optional<E>>                             defaultValue;
    private final Supplier<Optional<E>>                             unrecognized;

    /**
     * @param enumClass      The enum type.
     * @param enumDescriptor The enum type descriptor.
     */
    protected ProtoEnum(Class<E> enumClass, Descriptors.EnumDescriptor enumDescriptor) {
        this.enumClass = requireNonNull(enumClass, "class == null");
        this.protoDescriptor = requireNonNull(enumDescriptor, "descriptor == null");

        // only v3 has unrecognized.
        this.unrecognized = lazyCache(
                () -> EnumSet.allOf(enumClass)
                             .stream()
                             .filter(val -> val.name().equals(UNRECOGNIZED))
                             .findFirst());
        // UNRECOGNIZED is excluded here, as it is not a declared value
        // and has no number of its own.
        this.values = lazyCache(
                () -> EnumSet.allOf(enumClass)
                             .stream()
                             .filter(val -> !val.name().equals(UNRECOGNIZED))
                             .collect(toList()));
        this.idMap = lazyCache(
                () -> allValues()
                        .stream()
                        .collect(toMap(ProtocolMessageEnum::getNumber)));
        // Name and descriptor lookups go via the number, so every value
        // descriptor maps onto the java enum constant with the same number.
        this.nameMap = lazyCache(
                () -> enumDescriptor.getValues()
                                    .stream()
                                    .collect(toMap(Descriptors.EnumValueDescriptor::getName,
                                                   v -> valueForNumber(v.getNumber()))));
        this.valueMap = lazyCache(
                () -> enumDescriptor.getValues()
                                    .stream()
                                    .collect(toMap(Function.identity(),
                                                   v -> valueForNumber(v.getNumber()))));
        // Only proto3 enums have an implicit default: the value with number 0,
        // falling back to UNRECOGNIZED if no such value exists.
        this.defaultValue = lazyCache(() -> {
            if (protoDescriptor.getFile().getSyntax() == Descriptors.FileDescriptor.Syntax.PROTO3) {
                return Optional.ofNullable(idMap.get().getOrDefault(0, getUnrecognized()));
            }
            return Optional.empty();
        });
    }

    /**
     * @return The full type name.
     */
    @Override
    public String getTypeName() {
        return protoDescriptor.getFullName();
    }

    /**
     * @return The enum type class.
     */
    public Class<E> getEnumClass() {
        return enumClass;
    }

    /**
     * @return The enum type descriptor.
     */
    public Descriptors.EnumDescriptor getProtoDescriptor() {
        return protoDescriptor;
    }

    /**
     * For proto3 enums this is the value with number 0, or the
     * UNRECOGNIZED value if no value has number 0.
     *
     * @return The default value for the enum, or empty if there is no default value.
     */
    public Optional<E> getDefaultValue() {
        return defaultValue.get();
    }

    /**
     * @return The UNRECOGNIZED value, or null if no such value exists.
     */
    public E getUnrecognized() {
        return unrecognized.get().orElse(null);
    }

    /**
     * @return List of all values in declared order, excluding UNRECOGNIZED.
     */
    public List<E> allValues() {
        return values.get();
    }

    /**
     * @param id Value to look up enum from.
     * @return Enum if found, null otherwise.
     */
    public E findByNumber(Integer id) {
        if (id == null) {
            return null;
        }
        return idMap.get().get(id);
    }

    /**
     * @param name Name to look up enum from.
     * @return Enum if found, null otherwise.
     */
    public E findByName(String name) {
        if (name == null) {
            return null;
        }
        return nameMap.get().get(name);
    }

    /**
     * @param enumValue The enum value descriptor.
     * @return The enum value matching the descriptor, or null if not found.
     */
    public E findByValue(Descriptors.EnumValueDescriptor enumValue) {
        if (enumValue == null) {
            return null;
        }
        return valueMap.get().get(enumValue);
    }

    /**
     * @param id Value to look up enum from.
     * @return The enum value.
     * @throws IllegalArgumentException If value not found.
     */
    public E valueForNumber(int id) {
        return Optional.ofNullable(idMap.get().get(id))
                       .orElseThrow(() -> new IllegalArgumentException(
                               "No " + getTypeName() + " value for number " + id));
    }

    /**
     * @param name Name to look up enum from.
     * @return The enum value.
     * @throws NullPointerException     If name is null.
     * @throws IllegalArgumentException If value not found.
     */
    public E valueForName(String name) {
        requireNonNull(name, "name == null");
        return Optional.ofNullable(nameMap.get().get(name))
                       .orElseThrow(() -> new IllegalArgumentException(
                               "No " + getTypeName() + " value for name '" + javaEscape(name) + "'"));
    }

    /**
     * @param value Value to look up enum from.
     * @return The enum value.
     * @throws NullPointerException     If value is null.
     * @throws IllegalArgumentException If value not found.
     */
    public E valueFor(Descriptors.EnumValueDescriptor value) {
        requireNonNull(value, "value == null");
        return Optional.ofNullable(valueMap.get().get(value))
                       .orElseThrow(() -> new IllegalArgumentException(
                               "No " + getTypeName() + " value for " + value.getFullName()));
    }

    // --- Object ---

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ProtoEnum)) {
            return false;
        }
        ProtoEnum<?> that = (ProtoEnum<?>) o;
        // The enum class is compared too, so two different enums that both
        // have an empty value list are not considered equal.
        return enumClass.equals(that.enumClass) &&
               allValues().equals(that.allValues());
    }

    @Override
    public int hashCode() {
        // Only use the properties compared in equals(), so that equal
        // instances have equal hash codes. Object.hashCode() is identity
        // based and must not be mixed in.
        return Objects.hash(enumClass, allValues());
    }

    @Override
    public String toString() {
        return getTypeName() + "{" +
               allValues().stream().map(e -> e.getValueDescriptor().getName() + "=" + e.getNumber())
                          .collect(Collectors.joining(",")) + "}";
    }

    // --- Static Utils ---

    /**
     * @param instance An enum value.
     * @param <E>      The enum value type.
     * @return The instance cast to the enum value type.
     * @throws NullPointerException     If instance is null.
     * @throws IllegalArgumentException If instance is not a proto enum value.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    E requireProtoEnum(Object instance) {
        requireNonNull(instance, "instance == null");
        if (instance instanceof Enum && instance instanceof ProtocolMessageEnum) {
            return (E) instance;
        }
        throw new IllegalArgumentException("Not a proto enum " + instance);
    }

    /**
     * @param type A java class.
     * @param <E>  The enum type.
     * @return The class cast to the proto enum class.
     * @throws NullPointerException     If type is null.
     * @throws IllegalArgumentException If type is not a proto enum class.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    Class<E> requireProtoEnumClass(Class<?> type) {
        requireNonNull(type, "type == null");
        if (isProtoEnumClass(type)) {
            return (Class<E>) type;
        }
        throw new IllegalArgumentException("Not a proto enum type: " +
                                           type.getName().replace('$', '.'));
    }

    /**
     * @param type A java class.
     * @return True if the class is a java proto enum class.
     */
    public static boolean isProtoEnumClass(Class<?> type) {
        if (type == null) {
            return false;
        }
        return Enum.class.isAssignableFrom(type) && ProtocolMessageEnum.class.isAssignableFrom(type);
    }

    /**
     * @param value An enum value.
     * @param <E>   The enum value type.
     * @return The enum descriptor helper.
     * @throws NullPointerException If value is null.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    ProtoEnum<E> getEnumDescriptor(E value) {
        requireNonNull(value, "value == null");
        var enumType = value.getDeclaringClass();
        return (ProtoEnum<E>) enumTypeFromClass.computeIfAbsent(
                enumType, t -> {
                    ProtoEnum<?> descriptor = new ProtoEnum<>(enumType, value.getDescriptorForType());
                    // Also register the reverse lookup, so that a later lookup
                    // by descriptor finds the same cached instance.
                    classFromDescriptor.put(value.getDescriptorForType(), enumType);
                    return descriptor;
                });
    }

    /**
     * @param protoDescriptor An enum descriptor.
     * @param <E>             The enum value type.
     * @return The enum descriptor helper.
     * @throws NullPointerException If protoDescriptor is null.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    ProtoEnum<E> getEnumDescriptor(Descriptors.EnumDescriptor protoDescriptor) {
        requireNonNull(protoDescriptor, "protoDescriptor == null");
        return (ProtoEnum<E>) enumTypeFromClass.computeIfAbsent(
                classFromDescriptor.computeIfAbsent(protoDescriptor, a -> getInstanceClass(protoDescriptor)),
                type -> new ProtoEnum<>((Class<E>) type, protoDescriptor));
    }


    /**
     * @param enumType An enum java class.
     * @param <E>      The enum value type.
     * @return The enum descriptor helper.
     * @throws NullPointerException     If enumType is null.
     * @throws IllegalArgumentException If enumType is not a proto enum class.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    ProtoEnum<E> getEnumDescriptorUnchecked(Class<?> enumType) {
        return getEnumDescriptor((Class<E>) requireProtoEnumClass(enumType));
    }

    /**
     * @param enumType An enum java class.
     * @param <E>      The enum value type.
     * @return The enum descriptor helper.
     * @throws NullPointerException     If enumType is null.
     * @throws IllegalArgumentException If enumType is not a proto enum class.
     */
    @SuppressWarnings("unchecked")
    public static <E extends Enum<E> & ProtocolMessageEnum>
    ProtoEnum<E> getEnumDescriptor(Class<E> enumType) {
        requireNonNull(enumType, "enumType == null");
        if (!isProtoEnumClass(enumType)) {
            throw new IllegalArgumentException("Not a proto enum type: " +
                                               enumType.getName().replace('$', '.'));
        }
        return (ProtoEnum<E>) enumTypeFromClass.computeIfAbsent(
                enumType,
                t -> {
                    Descriptors.EnumDescriptor protoDescriptor =
                            getSingleton(enumType, Descriptors.EnumDescriptor.class, "getDescriptor");
                    var descriptor = new ProtoEnum<>(enumType, protoDescriptor);
                    // Also register the reverse lookup, so that a later lookup
                    // by descriptor finds the same cached instance.
                    classFromDescriptor.put(protoDescriptor, enumType);
                    return descriptor;
                });
    }

    // Cache of helpers per java enum class.
    private static final Map<Class<?>, ProtoEnum<?>>               enumTypeFromClass
            = new ConcurrentHashMap<>();
    // Reverse lookup from proto enum descriptor to java enum class.
    private static final Map<Descriptors.EnumDescriptor, Class<?>> classFromDescriptor
            = new ConcurrentHashMap<>();
}