ProgramRegistry.java

package net.morimekta.providence.reflect;

import net.morimekta.providence.descriptor.PDeclaredDescriptor;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.descriptor.PUnionDescriptor;
import net.morimekta.providence.descriptor.PValueProvider;
import net.morimekta.providence.reflect.contained.CProgram;
import net.morimekta.providence.reflect.model.ProgramDeclaration;
import net.morimekta.providence.types.TypeReference;
import net.morimekta.providence.types.WritableTypeRegistry;
import net.morimekta.util.collect.UnmodifiableList;

import javax.annotation.Nonnull;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * This is a registry for a single program. This means it also handles how
 * includes work, and contains meta info about the program it represents.
 *
 * <p>Registries can reference each other recursively. Each registry has a
 * specific program context, and needs to be built up in a recursive manner.
 * This way each type registry only has access to the declared types actually
 * referenced and included in the given thrift program file.
 *
 * <p>All lookups follow the same resolution order: a reference whose program
 * name equals this registry's program context is resolved locally; a
 * reference to an included program is delegated to that program's registry;
 * any other reference resolves to {@link Optional#empty()}.
 */
public class ProgramRegistry extends WritableTypeRegistry {
    private final String             programContext;
    private       CProgram           program;
    private       ProgramDeclaration programType;

    private final Map<String, TypeReference>          typedefs;
    private final Map<String, PDeclaredDescriptor<?>> declaredTypes;
    private final Map<String, PService>               services;
    private final Map<String, PValueProvider<?>>      constants;
    private final Map<String, ProgramRegistry>        includes;

    /**
     * Create an empty registry for a single program.
     *
     * @param programContext The name of the program this registry represents.
     */
    ProgramRegistry(@Nonnull String programContext) {
        this.programContext = programContext;
        this.declaredTypes = new LinkedHashMap<>();
        this.services = new LinkedHashMap<>();
        this.includes = new LinkedHashMap<>();
        this.typedefs = new LinkedHashMap<>();
        this.constants = new LinkedHashMap<>();
    }

    /**
     * @return The name of the program this registry represents.
     */
    public String getProgramContext() {
        return programContext;
    }

    /**
     * @return The contained program, or null if {@code setProgram} has not
     *         been called yet.
     */
    public CProgram getProgram() {
        return program;
    }

    /**
     * @return The program declaration, or null if {@code setProgramType} has
     *         not been called yet.
     */
    public ProgramDeclaration getProgramType() {
        return programType;
    }

    /**
     * Get the registry for a program visible from this registry.
     *
     * @param program The program name to look up.
     * @return This registry if the name is this program's context, the
     *         registry registered as an include under that name, or empty.
     */
    @Nonnull
    public Optional<ProgramRegistry> getRegistry(@Nonnull String program) {
        if (programContext.equals(program)) {
            return Optional.of(this);
        }
        return Optional.ofNullable(includes.get(program));
    }

    // TypeRegistry

    @Nonnull
    @Override
    public Optional<PDeclaredDescriptor<?>> getDeclaredType(@Nonnull TypeReference reference) {
        // NOTE(review): finalTypeReference is inherited; presumably it resolves
        // typedefs so the lookup below hits the actual declared type — confirm.
        reference = finalTypeReference(reference);
        if (programContext.equals(reference.programName)) {
            return Optional.ofNullable(declaredTypes.get(reference.typeName));
        }
        ProgramRegistry included = includes.get(reference.programName);
        if (included != null) {
            return included.getDeclaredType(reference);
        }
        return Optional.empty();
    }

    @Nonnull
    @Override
    public Optional<PService> getService(@Nonnull TypeReference reference) {
        // Resolve locally only when the reference targets this program, so a
        // reference to an unknown program never matches a local service.
        if (programContext.equals(reference.programName)) {
            return Optional.ofNullable(services.get(reference.typeName));
        }
        ProgramRegistry included = includes.get(reference.programName);
        if (included != null) {
            return included.getService(reference);
        }
        return Optional.empty();
    }

    @Nonnull
    @Override
    public <T> Optional<T> getConstantValue(@Nonnull TypeReference reference) {
        // Resolve locally only when the reference targets this program, so a
        // reference to an unknown program never matches a local constant.
        if (programContext.equals(reference.programName)) {
            // Providers are stored as PValueProvider<?>; the caller chooses T.
            @SuppressWarnings("unchecked")
            PValueProvider<T> provider = (PValueProvider<T>) constants.get(reference.typeName);
            return Optional.ofNullable(provider).map(PValueProvider::get);
        }
        ProgramRegistry included = includes.get(reference.programName);
        if (included != null) {
            return included.getConstantValue(reference);
        }
        return Optional.empty();
    }

    @Nonnull
    @Override
    public Optional<TypeReference> getTypedef(@Nonnull TypeReference reference) {
        if (programContext.equals(reference.programName)) {
            return Optional.ofNullable(typedefs.get(reference.typeName));
        }
        ProgramRegistry included = includes.get(reference.programName);
        if (included != null) {
            return included.getTypedef(reference);
        }
        return Optional.empty();
    }

    /**
     * @return A copy of the types declared in this program (not including
     *         types from included programs), in registration order.
     */
    @Override
    public List<PDeclaredDescriptor<?>> getDeclaredTypes() {
        return UnmodifiableList.copyOf(declaredTypes.values());
    }

    /**
     * @param program The program name to check.
     * @return True if the name is this program's context or a registered include.
     */
    @Override
    public boolean isKnownProgram(@Nonnull String program) {
        return programContext.equals(program) || includes.containsKey(program);
    }

    // WritableTypeRegistry

    /**
     * Register a typedef in this program. An existing typedef with the same
     * name is replaced.
     *
     * @param reference The typedef name; must belong to this program.
     * @param target    The type the typedef refers to.
     * @throws IllegalArgumentException If the reference belongs to another program.
     */
    @Override
    public void registerTypedef(@Nonnull TypeReference reference, @Nonnull TypeReference target) {
        if (programContext.equals(reference.programName)) {
            typedefs.put(reference.typeName, target);
        } else {
            throw new IllegalArgumentException("Unable to register typedef " + reference + " = " + target);
        }
    }

    /**
     * Register a declared type in this program. An existing type with the
     * same name is replaced.
     *
     * @param declaredType The type descriptor; must belong to this program.
     * @param <T>          The described type.
     * @throws IllegalArgumentException If the type belongs to another program.
     */
    @Override
    public <T> void registerType(PDeclaredDescriptor<T> declaredType) {
        if (programContext.equals(declaredType.getProgramName())) {
            declaredTypes.put(declaredType.getName(), declaredType);
        } else {
            throw new IllegalArgumentException("Unable to register " + declaredType.getType() + " " + declaredType.getQualifiedName());
        }
    }

    /**
     * Register a service in this program, together with the request type and
     * (when present) the response type of each of its methods.
     *
     * @param service The service; must belong to this program.
     * @throws IllegalArgumentException If the service, or one of its method
     *                                  request/response types, belongs to
     *                                  another program.
     */
    @Override
    public void registerService(@Nonnull PService service) {
        if (programContext.equals(service.getProgramName())) {
            services.put(service.getName(), service);
            for (PServiceMethod method : service.getMethods()) {
                // The response type may be null (no response type for the method).
                PUnionDescriptor<?> returnType = method.getResponseType();
                if (returnType != null) {
                    registerType(returnType);
                }
                PMessageDescriptor<?> requestType = method.getRequestType();
                registerType(requestType);
            }
        } else {
            throw new IllegalArgumentException("Unable to register service " + service.getQualifiedName());
        }
    }

    /**
     * Register a constant in this program. An existing constant with the same
     * name is replaced.
     *
     * @param reference The constant name; must belong to this program.
     * @param value     Provider of the constant value.
     * @throws IllegalArgumentException If the reference belongs to another program.
     */
    @Override
    public void registerConstant(@Nonnull TypeReference reference,
                                 @Nonnull PValueProvider<?> value) {
        if (programContext.equals(reference.programName)) {
            constants.put(reference.typeName, value);
        } else {
            throw new IllegalArgumentException("Unable to register constant " + reference);
        }
    }

    // --- Internal ---

    /** Set the contained program this registry describes. */
    void setProgram(CProgram program) {
        this.program = program;
    }

    /** Set the program declaration this registry describes. */
    void setProgramType(ProgramDeclaration type) {
        this.programType = type;
    }

    /**
     * Make another program's registry visible from this one.
     *
     * @param programName The name the included program is referenced by.
     * @param included    The registry of the included program.
     */
    void addInclude(String programName, ProgramRegistry included) {
        includes.put(programName, included);
    }
}