MessageValidation.java

/*
 * Copyright 2019 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.util;

import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PContainer;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.util.Strings;
import net.morimekta.util.collect.UnmodifiableList;
import net.morimekta.util.collect.UnmodifiableSortedSet;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static net.morimekta.providence.util.MessageUtil.keyPathAppend;
import static net.morimekta.util.collect.UnmodifiableList.copyOf;
import static net.morimekta.util.collect.UnmodifiableList.listOf;
import static net.morimekta.util.collect.Unmodifiables.asList;

/**
 * Class that handles validation of the structure or content of a message
 * type. This this can do much more fine grained validation than just assigning
 * required fields.
 *
 * @param <M> The message type to be validated.
 * @param <E> The exception to be thrown on validation failure.
 */
public class MessageValidation<
        M extends PMessage<M>,
        E extends Exception> {
    /**
     * Interface for testing some expectation.
     *
     * @param <V> The value type that is tested.
     */
    @FunctionalInterface
    public interface SimpleExpectation<V> extends Expectation<V> {
        /**
         * Test the expectation.
         *
         * @param value The value to be tested.
         * @throws Exception On any unmet expectation or other errors.
         */
        void test(V value) throws Exception;

        /**
         * Test the expectation at a specific path in the structure.
         *
         * @param path  The path to the current location.
         * @param value The value to be tested.
         * @throws Exception On validation failure.
         */
        default void test(String path, V value) throws Exception {
            test(value);
        }
    }

    /**
     * Interface for testing some expectation.
     *
     * @param <V> The value type that is tested.
     */
    @FunctionalInterface
    public interface Expectation<V> {
        /**
         * Test the expectation at a specific path in the structure.
         *
         * @param path  The path to the current location.
         * @param value The value to be tested.
         * @throws Exception On validation failure.
         */
        void test(String path, V value) throws Exception;
    }

    /**
     * Test a field that itself is using another message validation. The
     * field value is <b>only</b> tested if present, so null value allowance check
     * is required in addition to this validation.
     *
     * @param <BaseMessage>  The base or container message type.
     * @param <E>            The exception type thrown.
     */
    public interface ValidationExpectation<
            BaseMessage extends PMessage<BaseMessage>,
            E extends Exception> extends Expectation<BaseMessage> {
        /**
         * Check for validity, and collect as many validation errors as possible.
         *
         * @param path    The path to the base message.
         * @param message The message to test.
         * @return True if valid, otherwise false.
         */
        List<E> validationErrors(String path, BaseMessage message);
    }

    /**
     * Test a field that itself is a message using another message validation. The
     * field value is <b>only</b> tested if present, so null value allowance check
     * is required in addition to this validation.
     *
     * @param <BaseMessage>  The base or container message type.
     * @param <FieldMessage> The field or contained message type.
     * @param <E>            The exception type thrown.
     */
    private static final class MessageValidationExpectation<
            BaseMessage extends PMessage<BaseMessage>,
            FieldMessage extends PMessage<FieldMessage>,
            E extends Exception> implements ValidationExpectation<BaseMessage, E> {
        private final MessageValidation<FieldMessage, E> validation;
        private final PField<BaseMessage>                field;

        public MessageValidationExpectation(PField<BaseMessage> field,
                                            MessageValidation<FieldMessage, E> validation) {
            if (!field.getDescriptor().equals(validation.descriptor)) {
                throw new IllegalArgumentException("Field value type not same as validation");
            }
            this.validation = validation;
            this.field = field;
        }

        @Override
        public void test(String path, BaseMessage message) throws Exception {
            if (message.has(field)) {
                validation.validateInternal(keyPathAppend(path, field), message.get(field));
            }
        }

        @Override
        public List<E> validationErrors(String path, BaseMessage message) {
            if (message.has(field)) {
                return validation.validationErrors(keyPathAppend(path, field), message.get(field));
            }
            return listOf();
        }
    }

    /**
     * Test a field that itself is a map of message values using another message validation.
     * The field value is <b>only</b> tested if present, so null value allowance check
     * is required in addition to this validation.
     *
     * @param <BaseMessage>  The base or container message type.
     * @param <FieldMessage> The field or contained message type.
     * @param <E>            The exception type thrown.
     */
    public static final class MessageMapValidationExpectation<
            BaseMessage extends PMessage<BaseMessage>,
            FieldMessage extends PMessage<FieldMessage>,
            E extends Exception> implements ValidationExpectation<BaseMessage, E> {
        private final MessageValidation<FieldMessage, E> validation;
        private final PField<BaseMessage>                field;

        public MessageMapValidationExpectation(
                PField<BaseMessage> field,
                MessageValidation<FieldMessage, E> validation) {
            if (field.getDescriptor().getType() != PType.MAP) {
                throw new IllegalArgumentException("Field type not a map.");
            }
            PMap<?,?> mapType = (PMap<?, ?>) field.getDescriptor();
            if (!mapType.itemDescriptor().equals(validation.descriptor)) {
                throw new IllegalArgumentException("Field map value type not same as validation.");
            }
            this.validation = validation;
            this.field = field;
        }

        @Override
        public void test(String path, BaseMessage message) throws Exception {
            if (message.has(field)) {
                Map<?, FieldMessage> map = message.get(field);
                String prefix = keyPathAppend(path, field) + "[";
                for (Map.Entry<?, FieldMessage> entry : map.entrySet()) {
                    validation.validateInternal(prefix + entry.getKey().toString() + "]", entry.getValue());
                }
            }
        }

        @Override
        public List<E> validationErrors(String path, BaseMessage message) {
            if (message.has(field)) {
                List<E> exceptions = new ArrayList<>();
                Map<?, FieldMessage> map = message.get(field);
                String prefix = keyPathAppend(path, field) + "[";
                for (Map.Entry<?, FieldMessage> entry : map.entrySet()) {
                    exceptions.addAll(validation.validationErrors(prefix + entry.getKey().toString() + "]", entry.getValue()));
                }
                copyOf(exceptions);
            }
            return listOf();
        }
    }

    /**
     * Test a field that itself is a map of message values using another message validation.
     * The field value is <b>only</b> tested if present, so null value allowance check
     * is required in addition to this validation.
     *
     * @param <BaseMessage>  The base or container message type.
     * @param <FieldMessage> The field or contained message type.
     * @param <E>            The exception type thrown.
     */
    public static final class MessageCollectionValidationExpectation<
            BaseMessage extends PMessage<BaseMessage>,
            FieldMessage extends PMessage<FieldMessage>,
            E extends Exception> implements ValidationExpectation<BaseMessage, E> {
        private final MessageValidation<FieldMessage, E> validation;
        private final PField<BaseMessage>                field;

        public MessageCollectionValidationExpectation(
                PField<BaseMessage> field,
                MessageValidation<FieldMessage, E> validation) {
            if (field.getDescriptor().getType() != PType.LIST &&
                field.getDescriptor().getType() != PType.SET) {
                throw new IllegalArgumentException("Field type not a collection.");
            }
            PContainer<?> containerType = (PMap<?, ?>) field.getDescriptor();
            if (!containerType.itemDescriptor().equals(validation.descriptor)) {
                throw new IllegalArgumentException("Field item value type not same as validation.");
            }
            this.validation = validation;
            this.field = field;
        }

        @Override
        public void test(String path, BaseMessage message) throws Exception {
            if (message.has(field)) {
                Collection<FieldMessage> collection = message.get(field);
                String                   prefix     = keyPathAppend(path, field) + "[";
                int i = 0;
                for (FieldMessage item : collection) {
                    validation.validateInternal(prefix + (i++) + "]", item);
                }
            }
        }

        @Override
        public List<E> validationErrors(String path, BaseMessage message) {
            if (message.has(field)) {
                List<E> exceptions = new ArrayList<>();
                Collection<FieldMessage> collection = message.get(field);
                String                   prefix     = keyPathAppend(path, field) + "[";
                int i = 0;
                for (FieldMessage item : collection) {
                    exceptions.addAll(validation.validationErrors(prefix + (i++) + "]", item));
                }
                copyOf(exceptions);
            }
            return listOf();
        }
    }

    /**
     * Check some value based on a simple predicate.
     *
     * @param <Value> The value type being validated.
     */
    public static final class PredicateExpectation<Value> implements SimpleExpectation<Value> {
        private final Predicate<Value> predicate;
        private final String           failureMessage;

        /**
         * Create a predicate based expectation. The expectation is met if the predicate
         * evaluates to 'true'. And will throw {@link IllegalStateException} if not.
         *
         * @param predicate      The value predicate.
         * @param failureMessage The message of the thrown exception.
         */
        public PredicateExpectation(Predicate<Value> predicate, String failureMessage) {
            this.predicate = predicate;
            this.failureMessage = failureMessage;
        }

        /**
         * Test the value.
         *
         * @param value The value to test.
         * @throws IllegalStateException if the value does not meet the predicate expectation.
         */
        @Override
        public void test(Value value) {
            if (!predicate.test(value)) {
                throw new IllegalStateException(failureMessage);
            }
        }
    }

    /**
     * Validate a message using the built expectations.
     *
     * @param message The message to be validated.
     * @param <ME>    Message or builder type to be validated.
     * @return Message after validation.
     * @throws E On not valid message.
     */
    public <ME extends PMessageOrBuilder<M>> ME validate(ME message) throws E {
        return validate("", message);
    }

    /**
     * Just see if the message is valid or not. Does not throw any exception.
     *
     * @param message The message to be validated.
     * @return True if the message is valid, false otherwise.
     */
    public boolean isValid(PMessageOrBuilder<M> message) {
        if (message == null) {
            return allowNull;
        }
        if (!descriptor.equals(message.descriptor())) {
            return false;
        }
        M toTest = message.toMessage();
        if (expectedMissingFields.stream().anyMatch(field -> toTest.has(field.getId()))) {
            return false;
        }

        if (expectedPresentFields.stream().anyMatch(field -> !toTest.has(field.getId()))) {
            return false;
        }
        for (Expectation<M> predicate : expectations) {
            try {
                predicate.test("", toTest);
            } catch (Exception e) {
                return false;
            }
        }
        return true;
    }

    /**
     * See if the message is valid or not. Collects exceptions to consumers and
     * will try to continue on each failure.
     *
     * @param message The message to be validated.
     * @return True if the message is valid, false otherwise.
     */
    public List<E> validationErrors(PMessageOrBuilder<M> message) {
        return validationErrors("", message);
    }

    /**
     * Create a message validator that throws specific exception on failure.
     *
     * @param descriptor The message type descriptor to be validated.
     * @param onMismatch Function producer for thrown exceptions.
     * @param <M>        Message type.
     * @param <E>        Exception type.
     * @return The message validator builder.
     */
    public static <M extends PMessage<M>, E extends Exception>
    MessageValidation.Builder<M, E> builder(
            @Nonnull PMessageDescriptor<M> descriptor,
            @Nonnull Function<Exception, E> onMismatch) {
        return new Builder<>(descriptor, onMismatch);
    }

    /**
     * Validate a message using the built expectations.
     *
     * @param path    Logical path in structure that is being validated. Used in
     *                failure exceptions.
     * @param message The message or builder to be validated.
     * @param <ME>    Message or builder type to be validated.
     * @return Message after validation.
     * @throws E On not valid message.
     */
    public <ME extends PMessageOrBuilder<M>> ME validate(@Nonnull String path, @Nullable ME message) throws E {
        try {
            return validateInternal(path, message);
        } catch (Exception e) {
            throw onMismatch.apply(e);
        }
    }

    /**
     * See if the message is valid or not. Collects exceptions to consumers and
     * will try to continue on each failure.
     *
     * @param path    Logical path in structure that is being validated. Used in
     *                failure exceptions.
     * @param message The message to be validated.
     * @return True if the message is valid, false otherwise.
     */
    @SuppressWarnings("unchecked")
    public List<E> validationErrors(String path, PMessageOrBuilder<M> message) {
        if (message == null) {
            if (!allowNull) {
                return listOf(onMismatch.apply(new IllegalStateException(
                        pathPrefix(path) + "Null " + descriptor.getQualifiedName() + " value.")));
            }
            return listOf();
        }
        message = message.toMessage();
        if (!descriptor.equals(message.descriptor())) {
            return listOf(onMismatch.apply(new IllegalStateException(
                    pathPrefix(path) + "Validating message of type " + message.descriptor().getQualifiedName() +
                    ", required " + descriptor.getQualifiedName() + ".")));
        }
        List<E> errors = new ArrayList<>();
        M       toTest = message.toMessage();
        if (!expectedMissingFields.isEmpty()) {
            List<String> present = expectedMissingFields.stream()
                                                        .filter(field -> toTest.has(field.getId()))
                                                        .map(PField::getName)
                                                        .collect(Collectors.toList());
            if (!present.isEmpty()) {
                errors.add(onMismatch.apply(new IllegalStateException(
                        pathPrefix(path) + Strings.join(", ", present) + " present on " +
                        descriptor.getQualifiedName())));
            }
        }

        if (!expectedPresentFields.isEmpty()) {
            List<String> missing = expectedPresentFields.stream()
                                                        .filter(field -> !toTest.has(field.getId()))
                                                        .map(PField::getName)
                                                        .collect(Collectors.toList());
            if (!missing.isEmpty()) {
                errors.add(onMismatch.apply(new IllegalStateException(
                        pathPrefix(path) + Strings.join(", ", missing) + " not present on " +
                        descriptor.getQualifiedName())));
            }
        }

        for (Expectation<M> predicate : expectations) {
            if (predicate instanceof ValidationExpectation) {
                ValidationExpectation<M,E> valEx = (ValidationExpectation<M,E>) predicate;
                errors.addAll(valEx.validationErrors(path, (M) message));
            } else {
                try {
                    predicate.test(path, (M) message);
                } catch (Exception e) {
                    errors.add(onMismatch.apply(e));
                }
            }
        }
        return errors;
    }

    /**
     * Make a builder out of the current validation. The builder can build uppon the
     * validation, but cannot remove expectations from it. It will not modify the
     * original validation, only make a new one.
     *
     * @return A validation builder to extend the current.
     */
    public Builder<M, E> toBuilder() {
        return new Builder<>(this);
    }

    /**
     * Builder vlass for message validators.
     *
     * @param <M> Message type.
     * @param <E> Exception type.
     */
    public static class Builder<
            M extends PMessage<M>,
            E extends Exception> {
        /**
         * Build the validator.
         *
         * @return The validator instance.
         */
        @Nonnull
        public MessageValidation<M, E> build() {
            return new MessageValidation<>(this);
        }

        /**
         * Expect the message to be non-null value.
         *
         * @return The builder instance.
         */
        @Nonnull
        public Builder<M, E> expectNotNull() {
            this.allowNull = false;
            return this;
        }

        /**
         * Expect field to be present on message.
         *
         * @param fields The fields to be present.
         * @return The builder instance.
         */
        @Nonnull
        @SafeVarargs
        public final Builder<M, E> expectPresent(@Nonnull PField<M>... fields) {
            expectedPresentFields.addAll(asList(fields));
            expectedMissingFields.removeAll(asList(fields));
            return this;
        }

        /**
         * Expect field to be present on message.
         *
         * @param fields The fields to be present.
         * @return The builder instance.
         */
        @Nonnull
        @SafeVarargs
        public final Builder<M, E> expectMissing(@Nonnull PField<M>... fields) {
            expectedMissingFields.addAll(asList(fields));
            expectedPresentFields.removeAll(asList(fields));
            return this;
        }

        /**
         * Make a specific expectation for the message.
         *
         * @param expectation Expectation predicate.
         * @return The builder instance.
         */
        @Nonnull
        public Builder<M, E> expect(@Nonnull SimpleExpectation<M> expectation) {
            this.expectations.add(expectation);
            return this;
        }

        /**
         * Make a specific expectation for the message.
         *
         * @param expectation Expectation predicate.
         * @return The builder instance.
         */
        @Nonnull
        public Builder<M, E> expect(@Nonnull Expectation<M> expectation) {
            this.expectations.add(expectation);
            return this;
        }

        /**
         * Given the field and type descriptor (which must match the field type),
         * build an inner validator to check the value of the field.
         *
         * @param field            The field to check.
         * @param valueExpectation Expectation of field value.
         * @param <V>              The inner message type.
         * @return The builder instance.
         */
        @Nonnull
        public <V>
        Builder<M, E> expectIfPresent(@Nonnull PField<M> field,
                                      @Nonnull SimpleExpectation<V> valueExpectation) {
            return expectIfPresent(field, (Expectation<V>) valueExpectation);
        }

        /**
         * Given the field and type descriptor (which must match the field type),
         * build an inner validator to check the value of the field.
         *
         * @param field            The field to check.
         * @param valueExpectation Expectation of field value.
         * @param <V>              The inner message type.
         * @return The builder instance.
         */
        @Nonnull
        public <V>
        Builder<M, E> expectIfPresent(@Nonnull PField<M> field,
                                      @Nonnull Expectation<V> valueExpectation) {
            Expectation<M> expectation = (path, message) -> {
                if (message.has(field)) {
                    valueExpectation.test(keyPathAppend(path, field), message.get(field));
                }
            };
            this.expectations.add(expectation);
            return this;
        }

        /**
         * Given the field and type descriptor (which must match the field type),
         * build an inner validator to check the value of the field.
         *
         * @param field           The field to check.
         * @param descriptor      The message descriptor matching the field.
         * @param builderConsumer Consumer to configure the inner validator.
         * @param <M2>            The inner message type.
         * @return The builder instance.
         */
        @Nonnull
        public <M2 extends PMessage<M2>>
        Builder<M, E> expectIfPresent(@Nonnull PField<M> field,
                                      @Nonnull PMessageDescriptor<M2> descriptor,
                                      @Nonnull Consumer<Builder<M2, E>> builderConsumer) {
            if (!field.onMessageType().equals(this.descriptor)) {
                throw new IllegalArgumentException(
                        "Field not part of, " + this.descriptor.getQualifiedName());
            }
            Builder<M2, E> builder = builder(descriptor, onMismatch);
            builderConsumer.accept(builder);
            MessageValidation<M2, E> validator = builder.build();

            if (field.getType() == PType.MESSAGE) {
                this.expectations.add(new MessageValidationExpectation<>(field, validator));
            } else if (field.getType() == PType.MAP) {
                this.expectations.add(new MessageMapValidationExpectation<>(field, validator));
            } else if (field.getType() == PType.LIST ||
                       field.getType() == PType.SET) {
                this.expectations.add(new MessageCollectionValidationExpectation<>(field, validator));
            } else {
                throw new IllegalArgumentException(
                        "Field type mismatch, '" + field + "' is not usable for " + descriptor.getQualifiedName() + " validation.");
            }

            if (!validator.allowNull) {
                expectPresent(field);
            }
            return this;
        }

        private Builder(PMessageDescriptor<M> descriptor, @Nonnull Function<Exception, E> onMismatch) {
            this.descriptor = descriptor;
            this.onMismatch = onMismatch;
            this.expectations = new ArrayList<>();
            this.allowNull = true;
            this.expectedPresentFields = new TreeSet<>(Comparator.comparing(PField::getName));
            this.expectedMissingFields = new TreeSet<>(Comparator.comparing(PField::getName));
        }

        private Builder(MessageValidation<M, E> validation) {
            this.descriptor = validation.descriptor;
            this.onMismatch = validation.onMismatch;
            this.expectations = new ArrayList<>(validation.expectations);
            this.allowNull = validation.allowNull;
            this.expectedPresentFields = new TreeSet<>(validation.expectedPresentFields);
            this.expectedMissingFields = new TreeSet<>(validation.expectedMissingFields);
        }

        private       boolean                allowNull;
        private final SortedSet<PField<M>>   expectedPresentFields;
        private final SortedSet<PField<M>>   expectedMissingFields;
        private final PMessageDescriptor<M>  descriptor;
        private final Function<Exception, E> onMismatch;
        private final List<Expectation<M>>   expectations;
    }

    private MessageValidation(Builder<M, E> builder) {
        this.expectedMissingFields = UnmodifiableSortedSet.copyOf(builder.expectedMissingFields);
        this.expectedPresentFields = UnmodifiableSortedSet.copyOf(builder.expectedPresentFields);
        this.onMismatch = builder.onMismatch;
        this.allowNull = builder.allowNull;
        this.descriptor = builder.descriptor;
        this.expectations = UnmodifiableList.copyOf(builder.expectations);
    }

    /**
     * Validate a message using the built expectations.
     *
     * @param path    Logical path in structure that is being validated. Used in
     *                failure exceptions.
     * @param message The message or builder to be validated.
     * @param <ME>    Message or builder type to be validated.
     * @return Message after validation.
     * @throws E On not valid message.
     */
    <ME extends PMessageOrBuilder<M>> ME validateInternal(@Nonnull String path,
                                                          @Nullable ME message) throws Exception {
        if (message == null) {
            if (allowNull) return null;
            throw new IllegalStateException(
                    pathPrefix(path) + "Null " + descriptor.getQualifiedName() + " value" + pathPrefix(path) + ".");
        }
        if (!descriptor.equals(message.descriptor())) {
            throw new IllegalStateException(
                    pathPrefix(path) + "Validating message of type " + message.descriptor().getQualifiedName() +
                    ", required " + descriptor.getQualifiedName() + ".");
        }

        // TODO: Test the un-built message whenever possible.
        M toTest = message.toMessage();

        if (!expectedMissingFields.isEmpty()) {
            List<String> present = expectedMissingFields.stream()
                                                        .filter(field -> toTest.has(field.getId()))
                                                        .map(PField::getName)
                                                        .collect(Collectors.toList());
            if (!present.isEmpty()) {
                throw new IllegalStateException(pathPrefix(path) +
                                                Strings.join(", ", present) + " present on " + descriptor.getQualifiedName());
            }
        }
        if (!expectedPresentFields.isEmpty()) {
            List<String> missing = expectedPresentFields.stream()
                                                        .filter(field -> !toTest.has(field.getId()))
                                                        .map(PField::getName)
                                                        .collect(Collectors.toList());
            if (!missing.isEmpty()) {
                throw new IllegalStateException(pathPrefix(path) +
                                                Strings.join(", ", missing) + " not present on " + descriptor.getQualifiedName());
            }
        }

        for (Expectation<M> expectation : expectations) {
            expectation.test(path, toTest);
        }
        return message;
    }

    public static String pathPrefix(String path) {
        if (path.isEmpty()) return path;
        return path + ": ";
    }

    public static String atPathSuffix(String path) {
        if (path.isEmpty()) return path;
        return " at " + path;
    }

    private final SortedSet<PField<M>>   expectedPresentFields;
    private final SortedSet<PField<M>>   expectedMissingFields;
    private final PMessageDescriptor<M>  descriptor;
    private final boolean                allowNull;
    private final Function<Exception, E> onMismatch;
    private final List<Expectation<M>>   expectations;
}