ProvidenceModelConverter.java

package net.morimekta.providence.jax.rs;

import com.fasterxml.jackson.databind.type.SimpleType;
import io.swagger.v3.core.converter.AnnotatedType;
import io.swagger.v3.core.converter.ModelConverter;
import io.swagger.v3.core.converter.ModelConverterContext;
import io.swagger.v3.oas.models.media.ArraySchema;
import io.swagger.v3.oas.models.media.BooleanSchema;
import io.swagger.v3.oas.models.media.IntegerSchema;
import io.swagger.v3.oas.models.media.MapSchema;
import io.swagger.v3.oas.models.media.NumberSchema;
import io.swagger.v3.oas.models.media.ObjectSchema;
import io.swagger.v3.oas.models.media.Schema;
import io.swagger.v3.oas.models.media.StringSchema;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageVariant;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PDeclaredDescriptor;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PInterfaceDescriptor;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PRequirement;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.providence.jax.rs.schema.AllOfSchema;
import net.morimekta.providence.jax.rs.schema.CompactObjectSchema;
import net.morimekta.providence.jax.rs.schema.OneOfSchema;
import net.morimekta.providence.serializer.json.JsonCompactibleDescriptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Set;
import java.util.stream.Collectors;

import static net.morimekta.util.collect.UnmodifiableList.listOf;
import static net.morimekta.util.collect.UnmodifiableSet.setOf;

/**
 * Model converter for Providence generated messages and enums. This class
 * requires {@link io.swagger.v3.core} to work. Set it up by adding
 * dependencies to your <code>pom.xml</code>:
 *
 * <pre>{@code
 *   <dependency>
 *       <groupId>org.glassfish.jersey.containers</groupId>
 *       <artifactId>jersey-container-jetty-servlet</artifactId>
 *       <version>${jersey.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>org.glassfish.jersey.containers</groupId>
 *       <artifactId>jersey-container-jetty-http</artifactId>
 *       <version>${jersey.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>org.glassfish.jersey.inject</groupId>
 *       <artifactId>jersey-hk2</artifactId>
 *       <version>${jersey.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>io.swagger.core.v3</groupId>
 *       <artifactId>swagger-core</artifactId>
 *       <version>${swagger.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>io.swagger.core.v3</groupId>
 *       <artifactId>swagger-models</artifactId>
 *       <version>${swagger.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>io.swagger.core.v3</groupId>
 *       <artifactId>swagger-integration</artifactId>
 *       <version>${swagger.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>io.swagger.core.v3</groupId>
 *       <artifactId>swagger-jaxrs2</artifactId>
 *       <version>${swagger.version}</version>
 *   </dependency>
 *   <dependency>
 *       <groupId>net.morimekta.providence</groupId>
 *       <artifactId>providence-jax-rs</artifactId>
 *       <version>${providence.version}</version>
 *   </dependency>
 * }</pre>
 *
 * Then register an <code>OpenApiResource</code> that uses this converter from
 * your <code>ResourceConfig</code> implementation, for example:
 *
 * <pre>{@code
 * public class TestConfig extends ResourceConfig {
 *     public TestConfig() {
 *         setApplicationName("providence-test");
 *         register(ProvidenceFeature.class);
 *         addOpenApiSpec();
 *     }
 *
 *     private void addOpenApiSpec() {
 *         OpenAPI oas = new OpenAPI();
 *         Info info = new Info()
 *                 .title("OpenAPI Test Service")
 *                 .description("Testing / Example service for providence + openapi.")
 *                 .contact(new Contact().email("test@example.com"))
 *                 .version("1.0.1");
 *
 *         oas.info(info);
 *         oas.servers(listOf(
 *                 new Server().url("https://morimekta.net/test").description("Testing Service"),
 *                 new Server().url("https://morimekta.net/prod").description("Production Service")));
 *
 *         OpenApiResource openApiResource = new OpenApiResource();
 *         SwaggerConfiguration oasConfig = new SwaggerConfiguration()
 *                 .openAPI(oas)
 *                 .prettyPrint(true)
 *                 .modelConverterClasses(setOf(ProvidenceModelConverter.class.getName()))
 *                 .resourcePackages(Set.of("net.morimekta.api.resources"));
 *         openApiResource.setOpenApiConfiguration(oasConfig);
 *         register(openApiResource);
 *     }
 * }
 * }</pre>
 */
public class ProvidenceModelConverter implements ModelConverter {
    private static final Logger LOGGER = LoggerFactory.getLogger(ProvidenceModelConverter.class);

    /** Prefix prepended to a type's qualified name to form its schema {@code $ref}. */
    private final String  $refPrefix;
    /** When true, unions are modelled as a oneOf of single-property objects. */
    private final boolean complexUnions;
    /** When true, JSON-compactible messages additionally accept the compact schema form. */
    private final boolean allowCompact;
    /** When true, enums additionally accept their numeric id, not just the value name. */
    private final boolean allowEnumId;

    /** Simple model converter for OpenAPI. */
    @SuppressWarnings("unused")
    public ProvidenceModelConverter() {
        this("#/components/schemas/", false, false, false);
    }

    /**
     * Create a model converter with explicit options.
     *
     * @param $refPrefix    Prefix for generated {@code $ref} values, e.g. {@code "#/components/schemas/"}.
     * @param complexUnions Model each union as a oneOf of objects, each having exactly one
     *                      required property and no additional properties.
     * @param allowCompact  For JSON-compactible messages, allow the compact form as an alternative
     *                      (oneOf) to the regular object form.
     * @param allowEnumId   For enums, allow the numeric id as an alternative (oneOf) to the name.
     */
    public ProvidenceModelConverter(@Nonnull String $refPrefix, boolean complexUnions, boolean allowCompact, boolean allowEnumId) {
        this.$refPrefix = $refPrefix;
        this.complexUnions = complexUnions;
        this.allowCompact = allowCompact;
        this.allowEnumId = allowEnumId;
    }

    /**
     * Resolve a schema for Providence messages and enums. The type may be given as a
     * generated class, as a Jackson {@link SimpleType} wrapping such a class, or as a
     * Providence declared descriptor. Anything else, or a type whose resolution fails
     * (the failure is logged), is delegated to the next converter in the chain.
     *
     * @param type    The type to resolve.
     * @param context The converter context where named models are defined.
     * @param chain   The remaining converters.
     * @return A {@code $ref} schema for Providence types, otherwise the result of the
     *         next converter, or null if there is none.
     */
    @Override
    @SuppressWarnings("rawtypes")
    public Schema resolve(AnnotatedType type, ModelConverterContext context, Iterator<ModelConverter> chain) {
        try {
            Schema schema = null;
            if (type.getType() instanceof Class) {
                schema = resolveGeneratedClass(context, (Class<?>) type.getType());
            } else if (type.getType() instanceof SimpleType) {
                schema = resolveGeneratedClass(context, ((SimpleType) type.getType()).getRawClass());
            } else if (type.getType() instanceof PDeclaredDescriptor) {
                PDeclaredDescriptor descriptor = (PDeclaredDescriptor) type.getType();
                if (descriptor.getType() == PType.MESSAGE) {
                    schema = resolveMessage(context, (PMessageDescriptor<?>) descriptor);
                } else if (descriptor.getType() == PType.ENUM) {
                    schema = resolveEnum(context, (PEnumDescriptor) descriptor);
                }
            }
            if (schema != null) {
                return schema;
            }
        } catch (Exception e) {
            LOGGER.error("Failed to resolve type: {}", e.getMessage(), e);
        }

        if (chain.hasNext()) {
            return chain.next().resolve(type, context, chain);
        } else {
            return null;
        }
    }

    /**
     * Resolve a Providence generated message or enum class through its static
     * {@code kDescriptor} field.
     *
     * @param context  The converter context.
     * @param rawClass The class to inspect.
     * @return The {@code $ref} schema, or null if the class is neither a message nor an enum.
     * @throws ReflectiveOperationException If the {@code kDescriptor} field is missing or not accessible.
     */
    @SuppressWarnings("rawtypes")
    private Schema resolveGeneratedClass(ModelConverterContext context, Class<?> rawClass)
            throws ReflectiveOperationException {
        if (PMessage.class.isAssignableFrom(rawClass)) {
            Field kDescriptor = rawClass.getDeclaredField("kDescriptor");
            return resolveMessage(context, (PMessageDescriptor<?>) kDescriptor.get(null));
        }
        // NOTE: must be Class#isAssignableFrom; `rawClass.isInstance(PEnumValue.class)` tests the
        // Class object itself and never matches a generated enum class.
        if (PEnumValue.class.isAssignableFrom(rawClass)) {
            Field kDescriptor = rawClass.getDeclaredField("kDescriptor");
            return resolveEnum(context, (PEnumDescriptor<?>) kDescriptor.get(null));
        }
        return null;
    }

    /**
     * Define the named model for an enum (once) and return a {@code $ref} to it.
     * The model is a string enum of the value names, or, when {@code allowEnumId} is set,
     * a oneOf of that and a number enum of the value ids.
     *
     * @param context    The converter context.
     * @param descriptor The enum descriptor.
     * @return A {@code $ref} schema pointing at the enum model.
     */
    @SuppressWarnings("rawtypes")
    private Schema resolveEnum(ModelConverterContext context, PEnumDescriptor<?> descriptor) {
        if (context.getDefinedModels().containsKey(descriptor.getQualifiedName())) {
            return new Schema().$ref($refPrefix + descriptor.getQualifiedName());
        }

        StringSchema schema = new StringSchema();
        Schema outSchema = schema;
        schema.setName(descriptor.getQualifiedName());
        for (PEnumValue value : descriptor.getValues()) {
            schema.addEnumItem(value.asString());
        }

        if (allowEnumId) {
            NumberSchema idSchema = new NumberSchema();
            idSchema.setName(descriptor.getQualifiedName());
            for (PEnumValue value : descriptor.getValues()) {
                idSchema.addEnumItem(new BigDecimal(value.asInteger()));
            }
            outSchema = new OneOfSchema().oneOf(listOf(schema, idSchema)).name(descriptor.getQualifiedName());
        }

        context.defineModel(descriptor.getQualifiedName(), outSchema);
        return new Schema().$ref($refPrefix + descriptor.getQualifiedName());
    }

    /**
     * Define the named model for a message (once) and return a {@code $ref} to it.
     * Fields whose name or type name starts with {@code "__"} are skipped. Structs
     * implementing an interface are modelled as allOf(interface ref, own fields).
     *
     * @param context    The converter context.
     * @param descriptor The message descriptor.
     * @return A {@code $ref} schema pointing at the message model.
     */
    @SuppressWarnings("rawtypes")
    private Schema resolveMessage(ModelConverterContext context, PMessageDescriptor<?> descriptor) {
        if (context.getDefinedModels().containsKey(descriptor.getQualifiedName())) {
            return new Schema().$ref($refPrefix + descriptor.getQualifiedName());
        }
        // Register a placeholder before visiting fields, so recursive references back to this
        // message hit the check above and become a $ref instead of recursing forever.
        context.defineModel(descriptor.getQualifiedName(),
                            new ObjectSchema().name(descriptor.getQualifiedName()));

        ObjectSchema objectSchema = new ObjectSchema();
        Schema       outSchema    = objectSchema;
        outSchema.setName(descriptor.getQualifiedName());

        if (complexUnions && descriptor.getVariant() == PMessageVariant.UNION) {
            // One alternative per field: an object with exactly that single required property.
            OneOfSchema unionSchema = new OneOfSchema();
            unionSchema.setName(descriptor.getQualifiedName());

            for (PField field : descriptor.getFields()) {
                if (field.getName().startsWith("__")) continue;
                if (field.getDescriptor().getName().startsWith("__")) continue;

                objectSchema = new ObjectSchema();
                objectSchema.setName(descriptor.getQualifiedName() + "." + field.getName());
                objectSchema.addProperties(field.getName(), resolveProperty(context, field.getDescriptor()));
                objectSchema.addRequiredItem(field.getName());
                objectSchema.additionalProperties(false);

                unionSchema.getOneOf().add(objectSchema);
            }

            outSchema = unionSchema;
        } else {
            PInterfaceDescriptor iFace        = descriptor.getImplementing();
            Set<String>          iFaceMethods = setOf();
            if (iFace != null && descriptor.getVariant() == PMessageVariant.STRUCT) {
                // Interface fields are described by the interface model; only add the rest here.
                outSchema = new AllOfSchema()
                        .allOf(listOf(new Schema().$ref($refPrefix + iFace.getQualifiedName()),
                                      objectSchema))
                        .name(descriptor.getQualifiedName());
                iFaceMethods = Arrays.stream(iFace.getFields()).map(PField::getName).collect(Collectors.toSet());
                resolveMessage(context, iFace);
            }

            for (PField field : descriptor.getFields()) {
                if (field.getName().startsWith("__")) continue;
                if (field.getDescriptor().getName().startsWith("__")) continue;
                if (iFaceMethods.contains(field.getName())) continue;

                objectSchema.addProperties(field.getName(), resolveProperty(context, field.getDescriptor()));
                if (field.getRequirement() == PRequirement.REQUIRED) {
                    objectSchema.addRequiredItem(field.getName());
                }
            }
        }

        if (allowCompact &&
            descriptor instanceof JsonCompactibleDescriptor &&
            ((JsonCompactibleDescriptor) descriptor).isJsonCompactible()) {
            // Compact form: one item per field, at least as many items as there are required fields.
            CompactObjectSchema compact  = new CompactObjectSchema();
            int                 minItems = 0;
            for (PField field : descriptor.getFields()) {
                if (field.getRequirement() == PRequirement.REQUIRED) {
                    ++minItems;
                }
                compact.getItems().add(resolveProperty(context, field.getDescriptor()));
            }
            if (minItems > 0) {
                compact.setMinItems(minItems);
            }

            outSchema = new OneOfSchema().oneOf(listOf(outSchema, compact)).name(descriptor.getQualifiedName());
        }

        // Replaces the placeholder registered above with the complete model.
        context.defineModel(descriptor.getQualifiedName(), outSchema);
        return new Schema().$ref($refPrefix + descriptor.getQualifiedName());
    }

    /**
     * Build the schema for a field type. Messages and enums become {@code $ref}s to
     * named models; lists and sets become arrays (sets with unique items); maps become
     * map schemas keyed by string with the value type as additional properties.
     *
     * @param context    The converter context.
     * @param descriptor The field type descriptor.
     * @return The schema for the type.
     * @throws IllegalArgumentException For any type not handled here.
     */
    @SuppressWarnings("rawtypes")
    private Schema resolveProperty(ModelConverterContext context, PDescriptor descriptor) {
        switch (descriptor.getType()) {
            case ENUM: {
                return resolveEnum(context, (PEnumDescriptor) descriptor);
            }
            case MESSAGE: {
                return resolveMessage(context, (PMessageDescriptor<?>) descriptor);
            }
            case LIST: {
                PList list = (PList) descriptor;
                return new ArraySchema().items(resolveProperty(context, list.itemDescriptor()));
            }
            case SET: {
                PSet set = (PSet) descriptor;
                return new ArraySchema().items(resolveProperty(context, set.itemDescriptor())).uniqueItems(true);
            }
            case MAP: {
                PMap map = (PMap) descriptor;
                resolveProperty(context, map.keyDescriptor());  // just in case the key type needs to be defined.
                return new MapSchema()
                        .additionalProperties(resolveProperty(context, map.itemDescriptor()));
            }
            case BOOL:
                return new BooleanSchema();
            case BYTE:
                // NOTE(review): OpenAPI defines format "byte" for base64-encoded strings, not
                // integers; kept as-is for output compatibility — confirm intent.
                return new IntegerSchema().format("byte");
            case I16:  // no Json-Schema format defined for 16-bit integer.
            case I32:
                return new IntegerSchema();
            case I64:
                return new IntegerSchema().format("int64");
            case DOUBLE:
                return new NumberSchema().format("double");
            case BINARY:
                // NOTE(review): the OpenAPI-defined format for base64 strings is "byte".
                return new StringSchema().format("base64");
            case STRING:
                return new StringSchema();
            default:
                throw new IllegalArgumentException("Unhandled OpenAPI type: " + descriptor.getQualifiedName());
        }
    }
}