// MessageGenerator.java

package net.morimekta.providence.testing.generator;

import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PMessageVariant;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PRequirement;
import net.morimekta.util.collect.UnmodifiableList;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

/**
 * Message generator for generating instances of one specific message type.
 * <p>
 * For union descriptors exactly one field is given a value per generated
 * message. For all other variants the fields are visited in the configured
 * field order, and each field is either forced present, forced absent, or
 * left to the default presence rule (required fields are always generated,
 * other fields when {@code Context.nextFieldIsPresent()} says so).
 * <p>
 * All configuration methods mutate this instance and return it, so calls
 * can be chained.
 *
 * @param <Context> The generator context.
 * @param <Message> The message type to be generated.
 * @param <MessageOrBuilder> The message or builder interface.
 */
public class MessageGenerator<
        Context extends GeneratorContext<Context>,
        Message extends PMessage<Message>,
        MessageOrBuilder extends PMessageOrBuilder<Message>>
        implements Generator<Context, Message> {
    private final PMessageDescriptor<Message>                                                        descriptor;
    private final Map<PField<Message>, MessageFieldGenerator<Message, MessageOrBuilder, Context, ?>> fieldValueGenerators;
    private final Map<PField<Message>, Boolean>                                                      fieldPresenceOverrides;
    private final List<PField<Message>>                                                              fieldOrder;

    /**
     * Create a generator for the given message type with no field generators
     * or presence overrides, and with the descriptor's own field order.
     *
     * @param descriptor The descriptor of the message type to generate.
     */
    public MessageGenerator(PMessageDescriptor<Message> descriptor) {
        this.descriptor = descriptor;
        this.fieldValueGenerators = new HashMap<>();
        this.fieldPresenceOverrides = new HashMap<>();
        this.fieldOrder = new ArrayList<>(UnmodifiableList.copyOf(descriptor.getFields()));
    }

    /**
     * Copy constructor backing {@link #deepCopy()}.
     *
     * @param parent The generator to copy the configuration from.
     */
    private MessageGenerator(MessageGenerator<Context, Message, MessageOrBuilder> parent) {
        this.descriptor = parent.descriptor;
        this.fieldValueGenerators = new HashMap<>(parent.fieldValueGenerators);
        this.fieldPresenceOverrides = new HashMap<>(parent.fieldPresenceOverrides);
        this.fieldOrder = new ArrayList<>(parent.fieldOrder);
    }

    /**
     * Make an independent copy of this generator. The copy gets its own
     * field generator map, presence overrides and field order, so later
     * configuration changes on either instance do not affect the other.
     * The registered field generator instances themselves are shared, not
     * cloned.
     *
     * @return The copied message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> deepCopy() {
        return new MessageGenerator<>(this);
    }

    /**
     * Set which fields must always be generated. Note that unions only
     * allow for one single field to be always present.
     *
     * @param fields The fields that must be generated for each instance.
     * @return The message generator.
     */
    @SafeVarargs
    public final MessageGenerator<Context, Message, MessageOrBuilder> setAlwaysPresent(PField<Message>... fields) {
        return setAlwaysPresent(Arrays.asList(fields));
    }

    /**
     * Set which fields must always be generated. Note that unions only
     * allow for one single field to be always present; having more than
     * one makes {@link #generate} throw for union messages.
     *
     * @param fields The fields that must be generated for each instance.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setAlwaysPresent(@Nonnull Collection<PField<Message>> fields) {
        for (PField<Message> field : fields) {
            fieldPresenceOverrides.put(field, Boolean.TRUE);
        }
        return this;
    }

    /**
     * Set which fields must never be generated. If the message is a union
     * then these fields will not be selected when getting a random field
     * to get value for.
     *
     * @param fields The fields that should always be absent.
     * @return The message generator.
     */
    @SafeVarargs
    public final MessageGenerator<Context, Message, MessageOrBuilder> setAlwaysAbsent(PField<Message>... fields) {
        return setAlwaysAbsent(Arrays.asList(fields));
    }

    /**
     * Set which fields must never be generated. If the message is a union
     * then these fields will not be selected when getting a random field
     * to get value for.
     *
     * @param fields The fields that should always be absent.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setAlwaysAbsent(@Nonnull Collection<PField<Message>> fields) {
        for (PField<Message> field : fields) {
            fieldPresenceOverrides.put(field, Boolean.FALSE);
        }
        return this;
    }

    /**
     * Remove any always-present / always-absent override for the given
     * fields, so that their presence is again decided by the default rule
     * (required fields always, others as decided by the generator context).
     *
     * @param fields The fields that should have default presence probability.
     * @return The message generator.
     */
    @SafeVarargs
    public final MessageGenerator<Context, Message, MessageOrBuilder> setDefaultPresence(PField<Message>... fields) {
        return setDefaultPresence(Arrays.asList(fields));
    }

    /**
     * Remove any always-present / always-absent override for the given
     * fields, so that their presence is again decided by the default rule
     * (required fields always, others as decided by the generator context).
     *
     * @param fields The fields that should have default presence probability.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setDefaultPresence(@Nonnull Collection<PField<Message>> fields) {
        for (PField<Message> field : fields) {
            fieldPresenceOverrides.remove(field);
        }
        return this;
    }

    /**
     * Reset all field presence overrides, so that every field goes back to
     * the default presence rule.
     *
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> resetDefaultPresence() {
        fieldPresenceOverrides.clear();
        return this;
    }

    /**
     * Set fields (in order) that should have generated value <b>before</b> all other values.
     *
     * @param fields The fields that should be prioritized.
     * @return The message generator.
     */
    @SafeVarargs
    public final MessageGenerator<Context, Message, MessageOrBuilder> setFirstFields(@Nonnull PField<Message>... fields) {
        return setFirstFields(UnmodifiableList.copyOf(fields));
    }

    /**
     * Set fields (in order) that should have generated value <b>before</b> all other values.
     * A field that is repeated in the argument is only placed once, at the
     * position of its first occurrence.
     *
     * @param fields The fields that should be prioritized.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setFirstFields(@Nonnull Collection<PField<Message>> fields) {
        // De-duplicate first: a repeated field would otherwise end up in the
        // order list several times and be generated more than once.
        List<PField<Message>> prioritized = distinctInOrder(fields);
        fieldOrder.removeAll(prioritized);
        fieldOrder.addAll(0, prioritized);
        return this;
    }

    /**
     * Set fields (in order) that should have generated value <b>after</b> all other values.
     *
     * @param fields The fields that should be generated last.
     * @return The message generator.
     */
    @SafeVarargs
    public final MessageGenerator<Context, Message, MessageOrBuilder> setLastFields(@Nonnull PField<Message>... fields) {
        return setLastFields(UnmodifiableList.copyOf(fields));
    }

    /**
     * Set fields (in order) that should have generated value <b>after</b> all other values.
     * A field that is repeated in the argument is only placed once, at the
     * position of its first occurrence.
     *
     * @param fields The fields that should be generated last.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setLastFields(@Nonnull Collection<PField<Message>> fields) {
        // De-duplicate first: a repeated field would otherwise end up in the
        // order list several times and be generated more than once.
        List<PField<Message>> postponed = distinctInOrder(fields);
        fieldOrder.removeAll(postponed);
        fieldOrder.addAll(postponed);
        return this;
    }

    /**
     * Set the generator that produces the value for the given field,
     * replacing any generator previously registered for that field. The
     * generator is wrapped in a {@code MessageFieldGenerator.Wrapper}.
     *
     * @param field The field to generate values for.
     * @param generator The value generator to use for the field.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setValueGenerator(@Nonnull PField<Message> field,
                                                                                  @Nonnull Generator<Context, ?> generator) {
        return setFieldGenerator(field, new MessageFieldGenerator.Wrapper<>(generator));
    }

    /**
     * Set the field generator that produces the value for the given field,
     * replacing any generator previously registered for that field. The
     * field generator is handed the builder of the message being generated
     * together with the context.
     *
     * @param field The field to generate values for.
     * @param valueGenerator The field generator to use for the field.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setFieldGenerator(@Nonnull PField<Message> field,
                                                                                  @Nonnull MessageFieldGenerator<Message, MessageOrBuilder, Context, ?> valueGenerator) {
        fieldValueGenerators.put(field, valueGenerator);
        return this;
    }

    /**
     * Set a predicate-guarded value generator for the given field, replacing
     * any generator previously registered for that field. The generator is
     * wrapped in a {@code MessageFieldGenerator.Wrapper} and then combined
     * with the predicate in a {@code MessageFieldGenerator.Conditional}.
     *
     * @param field The field to generate values for.
     * @param predicate The predicate handed to the conditional generator;
     *                  presumably tested against the message built so far.
     * @param generator The value generator to use for the field.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setValueGenerator(@Nonnull PField<Message> field,
                                                                                  @Nonnull Predicate<MessageOrBuilder> predicate,
                                                                                  @Nonnull Generator<Context, ?> generator) {
        return setFieldGenerator(field, predicate, new MessageFieldGenerator.Wrapper<>(generator));
    }

    /**
     * Set a predicate-guarded field generator for the given field, replacing
     * any generator previously registered for that field. The predicate and
     * field generator are combined in a
     * {@code MessageFieldGenerator.Conditional}.
     *
     * @param field The field to generate values for.
     * @param predicate The predicate handed to the conditional generator;
     *                  presumably tested against the message built so far.
     * @param valueGenerator The field generator to use for the field.
     * @return The message generator.
     */
    public MessageGenerator<Context, Message, MessageOrBuilder> setFieldGenerator(@Nonnull PField<Message> field,
                                                                                  @Nonnull Predicate<MessageOrBuilder> predicate,
                                                                                  @Nonnull MessageFieldGenerator<Message, MessageOrBuilder, Context, ?> valueGenerator) {
        return setFieldGenerator(field, new MessageFieldGenerator.Conditional<>(predicate, valueGenerator));
    }

    /**
     * Generate a message instance and register it with the context.
     * <p>
     * Unions get a value for exactly one field. Other variants get values
     * for the fields in the configured order, honoring presence overrides
     * before falling back to the default presence rule.
     *
     * @param ctx The context to generate the message with.
     * @return The generated message.
     * @throws IllegalStateException If the message is a union and more than
     *         one field is set as always present, or all of its fields are
     *         set as always absent.
     */
    @Override
    public Message generate(Context ctx) {
        PMessageBuilder<Message> builder = descriptor.builder();
        if (descriptor.getVariant() == PMessageVariant.UNION) {
            generateInto(builder, selectUnionField(ctx), ctx);
        } else {
            for (PField<Message> field : fieldOrder) {
                Boolean override = fieldPresenceOverrides.get(field);
                boolean present;
                if (override != null) {
                    present = override;
                } else {
                    // Default presence calculation. The context is only
                    // consulted for fields that are not required.
                    present = field.getRequirement() == PRequirement.REQUIRED || ctx.nextFieldIsPresent();
                }
                if (present) {
                    generateInto(builder, field, ctx);
                }
            }
        }
        Message message = builder.build();
        ctx.addGeneratedMessage(message);
        return message;
    }

    /**
     * Pick the single field to set on a union: the one always-present field
     * if there is one, otherwise a random field among those not set as
     * always absent.
     *
     * @param ctx The context providing the random source.
     * @return The field to generate a value for.
     * @throws IllegalStateException If more than one field is always
     *         present, or no field remains after removing the absent ones.
     */
    private PField<Message> selectUnionField(Context ctx) {
        PField<Message> required = null;
        Set<PField<Message>> blocked = new HashSet<>();
        for (Map.Entry<PField<Message>, Boolean> entry : fieldPresenceOverrides.entrySet()) {
            if (entry.getValue()) {
                if (required != null) {
                    throw new IllegalStateException("More than one required union field");
                }
                required = entry.getKey();
            } else {
                blocked.add(entry.getKey());
            }
        }
        if (required != null) {
            return required;
        }

        List<PField<Message>> allowed = new ArrayList<>(Arrays.asList(descriptor.getFields()));
        allowed.removeAll(blocked);
        if (allowed.isEmpty()) {
            throw new IllegalStateException("No remaining fields allowed after " + blocked.size() + " was blocked");
        }
        return allowed.get(ctx.getRandom().nextInt(allowed.size()));
    }

    /**
     * Generate a value for the field and set it on the builder.
     *
     * @param builder The builder of the message being generated.
     * @param field The field to generate and set.
     * @param ctx The context to generate the value with.
     */
    @SuppressWarnings("unchecked")
    private void generateInto(PMessageBuilder<Message> builder, PField<Message> field, Context ctx) {
        // The builder doubles as the "message or builder" view that field
        // generators receive; the cast cannot be checked by the compiler.
        builder.set(field, makeFieldValue((MessageOrBuilder) builder, field, ctx));
    }

    /**
     * When the field is decided to be present, this will generate the
     * actual value: from the generator registered for the field if there is
     * one, otherwise from the context's generator for the field's type.
     *
     * @param instance The message or builder handed to the field generator.
     * @param field The field to generate for.
     * @param context The context to build the field around.
     * @return The value to be set for the field.
     */
    private Object makeFieldValue(MessageOrBuilder instance, PField<Message> field,
                                  Context context) {
        // This will try to make a field value for the given field regardless of access.
        MessageFieldGenerator<Message, MessageOrBuilder, Context, ?> valueGenerator = fieldValueGenerators.get(field);
        if (valueGenerator == null) {
            return context.generatorForDescriptor(field.getDescriptor()).generate(context);
        }
        return valueGenerator.generate(instance, context);
    }

    /**
     * Copy the given elements into a new list, keeping encounter order and
     * dropping every repeated element after its first occurrence.
     *
     * @param elements The elements to copy.
     * @param <E> The element type.
     * @return A new mutable list of the distinct elements.
     */
    private static <E> List<E> distinctInOrder(Collection<E> elements) {
        Set<E> seen = new HashSet<>();
        List<E> distinct = new ArrayList<>(elements.size());
        for (E element : elements) {
            if (seen.add(element)) {
                distinct.add(element);
            }
        }
        return distinct;
    }
}