ReflectionUtil.java

/*
 * Copyright 2022 Proto Utils Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.utils;

import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import net.morimekta.strings.NamingUtil;

import java.io.File;
import java.lang.reflect.InaccessibleObjectException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static net.morimekta.strings.NamingUtil.Format.PASCAL;

/**
 * Utilities regarding reflection.
 */
public final class ReflectionUtil {
    /**
     * Get a singleton value using a static getter method on the type.
     * <p>
     * The getter is looked up among the type's declared methods first, then among
     * its public (including inherited) methods. If the getter is not accessible
     * from this class, access is temporarily enabled with
     * {@link Method#setAccessible(boolean)} for the duration of the call.
     *
     * @param type            The type to call getter on.
     * @param out             The expected output type.
     * @param singletonGetter The singleton getter method name.
     * @param <T>             The output type.
     * @return The singleton.
     * @throws IllegalArgumentException If the getter does not exist, is not static,
     *                                  cannot be made accessible, fails when invoked,
     *                                  returns null, or returns a value not assignable
     *                                  to {@code out}.
     */
    public static <T> T getSingleton(Class<?> type, Class<T> out, String singletonGetter) {
        String getterRef = type.getSimpleName() + "." + singletonGetter + "()";
        try {
            Method getter = getMethod(type, singletonGetter);
            if (!Modifier.isStatic(getter.getModifiers())) {
                throw new IllegalArgumentException("Non-static singleton getter: " + getterRef);
            }
            Object instance = invokeStaticGetter(getter, getterRef);
            if (instance == null) {
                throw new IllegalArgumentException("Null singleton value: " + getterRef);
            }
            if (!out.isAssignableFrom(instance.getClass())) {
                throw new IllegalArgumentException(
                        "Invalid singleton value: " +
                        instance.getClass().getSimpleName() + " " +
                        getterRef + " not assignable to " + out.getName());
            }
            return out.cast(instance);
        } catch (NoSuchMethodException e) {
            // Keep the cause so the reflective lookup failure is visible to callers.
            throw new IllegalArgumentException("No such singleton getter: " + getterRef, e);
        } catch (SecurityException | IllegalAccessException | InvocationTargetException e) {
            throw new IllegalArgumentException("Invalid singleton getter: " + type.getSimpleName(), e);
        }
    }

    /**
     * Get the class of the instance for the generic proto descriptor.
     * <p>
     * The class name is derived from the descriptor's file options the same way
     * protoc's Java generator names classes ({@code java_package} falling back to
     * the proto package, {@code java_multiple_files}, {@code java_outer_classname}).
     *
     * @param descriptor The protobuf descriptor.
     * @return The instance class.
     * @throws IllegalArgumentException If no generated class could be loaded for the
     *                                  descriptor, or the proto file name is not recognized.
     */
    public static Class<?> getInstanceClass(Descriptors.EnumDescriptor descriptor) {
        return getInstanceClassInternal(descriptor);
    }

    /**
     * Get the class of the instance for the generic proto descriptor.
     * <p>
     * The class name is derived from the descriptor's file options the same way
     * protoc's Java generator names classes ({@code java_package} falling back to
     * the proto package, {@code java_multiple_files}, {@code java_outer_classname}).
     *
     * @param descriptor The protobuf descriptor.
     * @return The instance class.
     * @throws IllegalArgumentException If no generated class could be loaded for the
     *                                  descriptor, or the proto file name is not recognized.
     */
    public static Class<?> getInstanceClass(Descriptors.Descriptor descriptor) {
        return getInstanceClassInternal(descriptor);
    }

    // -------------------------------------------------------------------------
    // -------------------------------------------------------------------------
    // ------                                                             ------
    // ------                    PRIVATE METHODS                          ------
    // ------                                                             ------
    // -------------------------------------------------------------------------
    // -------------------------------------------------------------------------

    /**
     * Invoke a static no-argument getter, temporarily enabling reflective access
     * if this class cannot access it directly.
     */
    private static Object invokeStaticGetter(Method getter, String getterRef)
            throws IllegalAccessException, InvocationTargetException {
        if (getter.canAccess(null)) {
            return getter.invoke(null);
        }
        // TODO: This is needed to access message types in unnamed modules.
        try {
            // It will be blocked if it is not public in some way, either the class or the method.
            getter.setAccessible(true);
        } catch (InaccessibleObjectException e) {
            throw new IllegalArgumentException("Inaccessible singleton " + getterRef, e);
        }
        try {
            return getter.invoke(null);
        } finally {
            // Restore the access flag even when the getter itself throws.
            getter.setAccessible(false);
        }
    }

    /** Load the generated class for the descriptor using this class's class loader. */
    private static Class<?> getInstanceClassInternal(Descriptors.GenericDescriptor descriptor) {
        String className = getInstanceClassName(descriptor);
        try {
            return ReflectionUtil.class.getClassLoader().loadClass(className);
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("No generated class for " + descriptor.getFullName(), e);
        }
    }

    /**
     * Build the binary class name (nested classes separated by {@code $}) that
     * protoc's Java generator uses for the given message or enum descriptor.
     */
    private static String getInstanceClassName(Descriptors.GenericDescriptor descriptor) {
        Objects.requireNonNull(descriptor, "descriptor == null");
        Descriptors.FileDescriptor fileDescriptor = descriptor.getFile();
        DescriptorProtos.FileOptions fileOptions = fileDescriptor.getOptions();
        StringBuilder nameBuilder = new StringBuilder();

        // protoc uses the proto package when no java_package option is given.
        String javaPackage = fileOptions.hasJavaPackage()
                ? fileOptions.getJavaPackage()
                : fileDescriptor.getPackage();
        if (!javaPackage.isEmpty()) {
            nameBuilder.append(javaPackage).append('.');
        }
        // Without java_multiple_files, every top-level type is nested in the outer class.
        if (!fileOptions.getJavaMultipleFiles()) {
            nameBuilder.append(getOuterClassName(fileDescriptor)).append('$');
        }

        Descriptors.Descriptor containing = null;
        if (descriptor instanceof Descriptors.Descriptor) {
            containing = ((Descriptors.Descriptor) descriptor).getContainingType();
        } else if (descriptor instanceof Descriptors.EnumDescriptor) {
            containing = ((Descriptors.EnumDescriptor) descriptor).getContainingType();
        }
        handleContainingType(containing, nameBuilder);
        // Message and enum class names are the proto names verbatim.
        nameBuilder.append(descriptor.getName());
        return nameBuilder.toString();
    }

    /**
     * Append the chain of containing message names, outermost first, each
     * followed by {@code $}. Does nothing when {@code containing} is null.
     */
    private static void handleContainingType(Descriptors.Descriptor containing, StringBuilder nameBuilder) {
        if (containing == null) {
            return;
        }
        handleContainingType(containing.getContainingType(), nameBuilder);
        nameBuilder.append(containing.getName()).append('$');
    }

    /**
     * Resolve the outer (file) class name.
     * <p>
     * This is {@code java_outer_classname} if set. Otherwise it is the camel-cased
     * file base name, with {@code OuterClass} appended if that name collides with
     * a message, enum or service declared in the file.
     */
    private static String getOuterClassName(Descriptors.FileDescriptor fileDescriptor) {
        DescriptorProtos.FileOptions fileOptions = fileDescriptor.getOptions();
        if (fileOptions.hasJavaOuterClassname()) {
            return fileOptions.getJavaOuterClassname();
        }
        Matcher nameMatcher = PROTO_NAME.matcher(fileDescriptor.getName());
        if (!nameMatcher.matches()) {
            throw new IllegalArgumentException(
                    "Unrecognizable proto file for class: " + fileDescriptor.getName());
        }
        String outerClassName = underscoresToCamelCase(nameMatcher.group("name"));
        if (hasConflictingClassName(fileDescriptor, outerClassName)) {
            outerClassName += "OuterClass";
        }
        return outerClassName;
    }

    /** True if any enum, service or (possibly nested) message in the file has exactly this name. */
    private static boolean hasConflictingClassName(Descriptors.FileDescriptor file, String className) {
        return file.getEnumTypes().stream().anyMatch(e -> e.getName().equals(className))
               || file.getServices().stream().anyMatch(s -> s.getName().equals(className))
               || file.getMessageTypes().stream().anyMatch(m -> messageHasConflictingClassName(m, className));
    }

    /** True if the message, or any enum or message nested in it, has exactly this name. */
    private static boolean messageHasConflictingClassName(Descriptors.Descriptor message, String className) {
        if (message.getName().equals(className)) {
            return true;
        }
        return message.getEnumTypes().stream().anyMatch(e -> e.getName().equals(className))
               || message.getNestedTypes().stream().anyMatch(n -> messageHasConflictingClassName(n, className));
    }

    /**
     * Camel-case a file base name with the same rules as protoc's Java generator.
     * ASCII letters are kept, and a lower-case letter is capitalized at the start or
     * after a digit or separator. Digits are kept. Any other character is dropped
     * and acts as a word separator.
     */
    private static String underscoresToCamelCase(String input) {
        StringBuilder result = new StringBuilder(input.length());
        boolean capNextLetter = true;
        for (int i = 0; i < input.length(); ++i) {
            char c = input.charAt(i);
            if (c >= 'a' && c <= 'z') {
                result.append(capNextLetter ? Character.toUpperCase(c) : c);
                capNextLetter = false;
            } else if (c >= 'A' && c <= 'Z') {
                result.append(c);
                capNextLetter = false;
            } else if (c >= '0' && c <= '9') {
                result.append(c);
                capNextLetter = true;
            } else {
                capNextLetter = true;
            }
        }
        return result.toString();
    }

    /**
     * Find a getter by name: declared (any visibility) on the type first, then
     * public methods including inherited ones.
     */
    private static Method getMethod(Class<?> type, String singletonGetter) throws NoSuchMethodException {
        try {
            return type.getDeclaredMethod(singletonGetter);
        } catch (NoSuchMethodException e) {
            return type.getMethod(singletonGetter);
        }
    }

    private ReflectionUtil() {
    }

    /**
     * Extracts the base name (without directories and extension) from a proto file
     * descriptor name. Descriptor file names always use '/' as the directory
     * separator regardless of platform, so the platform file or path separator
     * must not be used here.
     */
    private static final Pattern PROTO_NAME = Pattern.compile(
            "^(?:.*/)?(?<name>[^/]*)[.](?:protobuf|proto|pb)$", Pattern.CASE_INSENSITIVE);
}