EqualMessageIgnoring.java

package net.morimekta.proto.testing.matchers;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import net.morimekta.collect.util.SetOperations;
import net.morimekta.proto.ProtoMessage;
import net.morimekta.strings.Stringable;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.Objects.requireNonNull;
import static net.morimekta.collect.UnmodifiableSet.asSet;
import static net.morimekta.proto.utils.ByteStringUtil.toBase64;
import static net.morimekta.proto.utils.FieldUtil.fieldPathToFields;
import static net.morimekta.proto.utils.FieldUtil.filterFields;
import static net.morimekta.proto.utils.ValueUtil.toDebugString;
import static net.morimekta.proto.utils.ValueUtil.toJavaValue;

/**
 * Matcher for checking equality between messages that can ignore fields and extensions.
 *
 * @param <M> Message type.
 */
public class EqualMessageIgnoring<M extends Message> extends BaseMatcher<M> {
    protected final M           expected;
    private final   Set<String> ignoringFields;
    private final   boolean     ignoreExtensions;

    /**
     * @param expected         Expected message.
     * @param ignoreExtensions If all extensions should be ignored.
     * @param ignoringFields   Ignored fields. This can be any field descriptor, including extensions.
     */
    public EqualMessageIgnoring(M expected, boolean ignoreExtensions, Set<String> ignoringFields) {
        ignoringFields.forEach(path -> fieldPathToFields(expected.getDescriptorForType(), path));
        this.expected = requireNonNull(expected, "expected == null");
        this.ignoringFields = asSet(ignoringFields);
        this.ignoreExtensions = ignoreExtensions;
    }

    @Override
    public boolean matches(Object o) {
        if (!(o instanceof Message)) {
            return false;
        }
        if (!expected.getClass().equals(o.getClass())) {
            return false;
        }
        if (ignoringFields.isEmpty() && !ignoreExtensions) {
            return expected.equals(o);
        } else {
            return equalsIgnoring(expected, o, ignoreExtensions, ignoringFields);
        }
    }

    @Override
    public void describeTo(Description description) {
        if (ignoreExtensions || !ignoringFields.isEmpty()) {
            description.appendText(expected.getDescriptorForType().getFullName());
            appendValueIgnoring("", expected, description, ignoreExtensions, ignoringFields);
        } else {
            description.appendText(toDebugString(expected));
        }
    }

    @Override
    public void describeMismatch(Object item, Description description) {
        if (item instanceof Message) {
            Message actual = (Message) item;
            description.appendText("was ");
            if (ignoreExtensions || !ignoringFields.isEmpty()) {
                description.appendText(actual.getDescriptorForType().getFullName());
                appendValueIgnoring("", actual, description, ignoreExtensions, ignoringFields);
            } else {
                description.appendText(toDebugString(actual));
            }
        } else {
            super.describeMismatch(item, description);
        }
    }

    private static void appendValueIgnoring(String prefix,
                                            Object value,
                                            Description description,
                                            boolean ignoreExtensions,
                                            Set<String> ignoringFields) {
        if (value instanceof Message) {
            var msg = (Message) value;
            var allFields = msg.getAllFields()
                               .entrySet()
                               .stream()
                               .filter(field -> {
                                   if (field.getKey().isExtension() && ignoreExtensions) {
                                       return false;
                                   }
                                   return !ignoringFields.contains(field.getKey().getName());
                               })
                               .collect(Collectors.toList());
            if (allFields.isEmpty()) {
                description.appendText("{}");
                return;
            }
            description.appendText("{\n");
            for (var fieldEntry : allFields) {
                var val = toJavaValue(fieldEntry.getKey(), fieldEntry.getValue());
                description.appendText(prefix)
                           .appendText("  ");
                if (fieldEntry.getKey().isExtension()) {
                    description.appendText("(")
                               .appendText(fieldEntry.getKey().getFullName())
                               .appendText(")");
                } else {
                    description.appendText(fieldEntry.getKey().getName());
                }
                description.appendText(" = ");
                appendValueIgnoring(prefix + "  ",
                                    val,
                                    description,
                                    ignoreExtensions,
                                    filterFields(ignoringFields, fieldEntry.getKey()));
                description.appendText("\n");
            }
            description.appendText(prefix).appendText("}");
        } else if (value instanceof Map) {
            description.appendText("{");
            boolean first = true;
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                if (first) {
                    first = false;
                } else {
                    description.appendText(",");
                }
                description.appendText("\n")
                           .appendText(prefix)
                           .appendText("  ")
                           .appendText(Stringable.asString(entry.getKey()))
                           .appendText(": ");
                appendValueIgnoring(prefix + "  ", entry.getValue(), description, ignoreExtensions, ignoringFields);
            }
            description.appendText("\n")
                       .appendText(prefix).appendText("}");
        } else if (value instanceof List) {
            List<?> list = (List<?>) value;
            if (list.isEmpty()) {
                // just a safe-guard. This should never be possible, so cannot even be tested.
                description.appendText("[]");
            } else {
                var first = list.get(0);
                if (list.size() > 5 || first instanceof Message || first instanceof ByteString || first instanceof String) {
                    // one line per entry.
                    description.appendText("[\n");
                    for (var item : list) {
                        if (first != null) {
                            first = null;
                        } else {
                            description.appendText(",\n");
                        }
                        description.appendText(prefix + "  ");
                        appendValueIgnoring(prefix + "  ", item, description, ignoreExtensions, ignoringFields);
                    }
                    description.appendText("\n")
                               .appendText(prefix).appendText("]");
                } else {
                    description.appendText("[");
                    for (var item : list) {
                        if (first != null) {
                            first = null;
                        } else {
                            description.appendText(", ");
                        }
                        description.appendText(Stringable.asString(item));
                    }
                    description.appendText("]");
                }
            }
        } else if (value instanceof ByteString) {
            description.appendText("b64(").appendText(toBase64((ByteString) value)).appendText(")");
        } else {
            description.appendText(Stringable.asString(value));
        }
    }

    private static boolean equalsIgnoring(Object expected,
                                          Object actual,
                                          boolean ignoreExtensions,
                                          Set<String> ignoringFields) {
        if (Objects.equals(expected, actual)) {
            return true;
        }
        if (expected instanceof Message && actual instanceof Message) {
            if (ignoringFields.isEmpty() && !ignoreExtensions) {
                return false;
            }

            var mExp = new ProtoMessage((MessageOrBuilder) expected);
            var mAct = new ProtoMessage((MessageOrBuilder) actual);
            // all fields from either message, minus
            var allFields = SetOperations
                    .union(mExp.getMessage().getAllFields().keySet(),
                           mAct.getMessage().getAllFields().keySet())
                    .stream()
                    .filter(field -> {
                        if (field.isExtension() && ignoreExtensions) {
                            return false;
                        }
                        return !ignoringFields.contains(field.getName());
                    })
                    .collect(Collectors.toList());
            for (Descriptors.FieldDescriptor field : allFields) {
                if (field.isExtension() && ignoreExtensions) {
                    continue;
                }
                if (!equalsIgnoring(
                        mExp.optional(field).orElse(null),
                        mAct.optional(field).orElse(null),
                        ignoreExtensions,
                        filterFields(ignoringFields, field))) {
                    return false;
                }
            }
            return true;
        }
        if (expected instanceof Map && actual instanceof Map) {
            @SuppressWarnings("unchecked")
            var expMap = (Map<Object, Object>) expected;
            @SuppressWarnings("unchecked")
            var actMap = (Map<Object, Object>) actual;
            if (!expMap.keySet().equals(actMap.keySet())) {
                return false;
            }
            for (var expEntry : expMap.entrySet()) {
                if (!equalsIgnoring(expEntry.getValue(),
                                    actMap.get(expEntry.getKey()),
                                    ignoreExtensions,
                                    ignoringFields)) {
                    return false;
                }
            }
            return true;
        }
        if (expected instanceof List && actual instanceof List) {
            @SuppressWarnings("unchecked")
            var expList = (List<Object>) expected;
            @SuppressWarnings("unchecked")
            var actList = (List<Object>) actual;
            if (expList.size() != actList.size()) {
                return false;
            }
            for (int i = 0; i < expList.size(); ++i) {
                if (!equalsIgnoring(expList.get(i), actList.get(i), ignoreExtensions, ignoringFields)) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }
}