MessageDiff.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.testing.util;

import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PMessageVariant;
import net.morimekta.providence.PUnion;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Utility to map differences between two messages in a much
 * more compact way than showing the two messages. It will try
 * to pinpoint each individual difference (a message field, a map
 * entry, a list or set item) and report it through a
 * {@link MismatchHandler} as a short, human readable description.
 */
public class MessageDiff {
    private MessageDiff() { }

    /**
     * Receives the description of each mismatch found while diffing.
     */
    public interface MismatchHandler {
        /**
         * Called once per detected mismatch.
         *
         * @param description Human readable description of the mismatch.
         */
        void onMismatch(String description);
    }

    /**
     * Compare two messages field by field and report every difference
     * to the mismatch handler. Nested messages, lists, sets and maps are
     * compared recursively so each reported mismatch points at the
     * innermost differing value.
     *
     * @param xPath          Path of the compared messages. Use the empty
     *                       string for the root message; field names are
     *                       appended with a '.' separator.
     * @param expected       The expected message.
     * @param actual         The actual message.
     * @param ignoringFields Fields to skip. The set is passed on to nested
     *                       comparisons, so it applies at every level.
     * @param mismatches     Handler receiving each mismatch description.
     * @param <T>            The message type.
     */
    public static <T extends PMessage<T>>
    void collectMismatches(String xPath,
                           PMessageOrBuilder<T> expected,
                           PMessageOrBuilder<T> actual,
                           Set<PField> ignoringFields,
                           MismatchHandler mismatches) {
        // This is pretty heavy calculation, but since it's only done on
        // mismatch / test failure, it should be fine.
        if (expected.descriptor()
                    .getVariant() == PMessageVariant.UNION) {
            PUnion<?> eu = (PUnion) expected;
            PUnion<?> ac = (PUnion) actual;

            if (eu.unionFieldIsSet() != ac.unionFieldIsSet()) {
                if (eu.unionFieldIsSet()) {
                    // Expected has a union value, actual has none.
                    mismatches.onMismatch("expected value in " + xPath + " but not set");
                } else {
                    // Actual has a union value where none was expected.
                    mismatches.onMismatch("unexpected value in " + xPath);
                }
            } else if (eu.unionFieldIsSet() && !eu.unionField().equals(ac.unionField())) {
                // Only compare the selected fields when both unions have one;
                // an unset union has no union field to compare.
                mismatches.onMismatch(String.format(Locale.US, "%s to be %s, but was %s",
                                                    xPath,
                                                    eu.unionField()
                                                      .getName(),
                                                    ac.unionField()
                                                      .getName()));
            }
        }

        for (PField<T> field : expected.descriptor().getFields()) {
            if (ignoringFields.contains(field)) {
                continue;
            }

            int     key         = field.getId();
            String  fieldXPath  = xPath.isEmpty() ? field.getName() : xPath + "." + field.getName();
            boolean expectedHas = expected.has(key);
            boolean actualHas   = actual.has(key);

            if (expectedHas != actualHas) {
                if (!expectedHas) {
                    mismatches.onMismatch(String.format(Locale.US, "%s to be missing, but was %s",
                                                        fieldXPath,
                                                        toString(actual.get(key))));
                } else {
                    mismatches.onMismatch(String.format(Locale.US, "%s to be %s, but was missing",
                                                        fieldXPath,
                                                        toString(expected.get(key))));
                }
            } else if (!Objects.equals(expected.get(key), actual.get(key))) {
                switch (field.getType()) {
                    case MESSAGE: {
                        collectMismatches(fieldXPath,
                                          expected.get(key),
                                          actual.get(key),
                                          ignoringFields,
                                          mismatches);
                        break;
                    }
                    case LIST: {
                        collectListMismatches(fieldXPath,
                                              expected.get(key),
                                              actual.get(key),
                                              ignoringFields,
                                              mismatches);
                        break;
                    }
                    case SET: {
                        collectSetMismatches(fieldXPath, expected.get(key), actual.get(key), mismatches);
                        break;
                    }
                    case MAP: {
                        collectMapMismatches(fieldXPath,
                                             expected.get(key),
                                             actual.get(key),
                                             ignoringFields,
                                             mismatches);
                        break;
                    }
                    default: {
                        mismatches.onMismatch(String.format(Locale.US, "%s was %s, expected %s",
                                                            fieldXPath,
                                                            toString(actual.get(key)),
                                                            toString(expected.get(key))));
                        break;
                    }
                }
            }
        }
    }

    /**
     * Report differences between two maps: unexpected keys, missing keys,
     * and differing values for shared keys. Differing message values are
     * diffed recursively under the path {@code xPath[key]}.
     *
     * @param xPath          Path of the map field.
     * @param expected       The expected map.
     * @param actual         The actual map.
     * @param ignoringFields Fields to skip in nested messages.
     * @param mismatches     Handler receiving each mismatch description.
     */
    @SuppressWarnings("unchecked")
    private static <K, V> void collectMapMismatches(String xPath,
                                                    Map<K, V> expected,
                                                    Map<K, V> actual,
                                                    Set<PField> ignoringFields,
                                                    MismatchHandler mismatches) {
        actual.keySet()
              .stream()
              .filter(key -> !expected.containsKey(key))
              .map(key -> String.format(Locale.US, "found unexpected entry (%s, %s) in %s",
                                        toString(key),
                                        toString(actual.get(key)),
                                        xPath))
              .forEach(mismatches::onMismatch);

        for (Map.Entry<K, V> entry : expected.entrySet()) {
            K key = entry.getKey();
            V exp = entry.getValue();
            if (!actual.containsKey(key)) {
                mismatches.onMismatch(String.format(Locale.US, "did not find entry (%s, %s) in %s",
                                                    toString(key),
                                                    toString(exp),
                                                    xPath));
                continue;
            }

            V act = actual.get(key);
            if (Objects.equals(exp, act)) {
                continue;
            }

            // value differs; address it by its key, not by the whole entry.
            String keyedXPath = String.format(Locale.US, "%s[%s]", xPath, toString(key));
            if (exp instanceof PMessage && act instanceof PMessage) {
                collectMismatches(keyedXPath, (PMessage) exp, (PMessage) act, ignoringFields, mismatches);
            } else {
                // Covers null values as well as plain values; always actual first,
                // then expected, to match the message wording.
                mismatches.onMismatch(String.format(Locale.US, "%s was %s, should be %s",
                                                    keyedXPath,
                                                    toString(act),
                                                    toString(exp)));
            }
        }
    }

    /**
     * Report values present in only one of the two sets.
     *
     * @param xPath      Path of the set field.
     * @param expected   The expected set.
     * @param actual     The actual set.
     * @param mismatches Handler receiving each mismatch description.
     */
    private static <T> void collectSetMismatches(String xPath,
                                                 Set<T> expected,
                                                 Set<T> actual,
                                                 MismatchHandler mismatches) {
        // order does NOT matter regardless of type. The only
        // errors are missing and unexpected values. Partial
        // matches are not checked.
        actual.stream()
              .filter(item -> !expected.contains(item))
              .map(item -> String.format(Locale.US, "found unexpected set value %s in %s",
                                         toString(item),
                                         xPath))
              .forEach(mismatches::onMismatch);

        expected.stream()
                .filter(item -> !actual.contains(item))
                .map(item -> String.format(Locale.US, "did not find value %s in %s", toString(item), xPath))
                .forEach(mismatches::onMismatch);
    }

    /**
     * Report differences between two lists. For each expected index:
     * <ul>
     *     <li>an equal item at the same index is in place;</li>
     *     <li>an item found at another index of the actual list counts as
     *         a reordering;</li>
     *     <li>an item not found at all is either replaced (the actual item
     *         at that index is not in the expected list, and is diffed
     *         against it), or missing.</li>
     * </ul>
     * Actual items not accounted for are reported as unexpected, and if
     * any reordering was seen, the per-index offsets are reported too.
     *
     * @param xPath          Path of the list field.
     * @param expected       The expected list.
     * @param actual         The actual list.
     * @param ignoringFields Fields to skip in nested messages.
     * @param mismatches     Handler receiving each mismatch description.
     */
    @SuppressWarnings("unchecked")
    private static <T> void collectListMismatches(String xPath,
                                                  List<T> expected,
                                                  List<T> actual,
                                                  Set<PField> ignoringFields,
                                                  MismatchHandler mismatches) {
        Set<T> handledItems = new HashSet<>();

        boolean      hasReorder = false;
        List<String> reordering = new ArrayList<>();
        for (int expectedIndex = 0; expectedIndex < expected.size(); ++expectedIndex) {
            String indexedXPath = String.format(Locale.US, "%s[%d]", xPath, expectedIndex);
            T      expectedItem = expected.get(expectedIndex);
            handledItems.add(expectedItem);

            T actualItem = actual.size() > expectedIndex ? actual.get(expectedIndex) : null;
            if (Objects.equals(expectedItem, actualItem)) {
                // In place. Record it so the reorder summary stays aligned
                // with the expected indices.
                reordering.add("±0");
                continue;
            }

            int actualIndex = actual.indexOf(expectedItem);
            if (actualIndex >= 0) {
                // Present, but at another position.
                reordering.add(String.format(Locale.US, "%+d", actualIndex - expectedIndex));
                hasReorder = true;
                continue;
            }

            // this item is missing from the actual list.
            reordering.add("NaN");
            // actualItem is null when the actual list is shorter than the
            // expected one; that is a plain missing item, not a replacement.
            boolean replaced = actualItem != null && expected.indexOf(actualItem) < 0;
            if (replaced) {
                handledItems.add(actualItem);
                // replaced with new item, diff them normally.
                if (expectedItem instanceof PMessage && actualItem instanceof PMessage) {
                    collectMismatches(indexedXPath,
                                      (PMessage) expectedItem,
                                      (PMessage) actualItem,
                                      ignoringFields,
                                      mismatches);
                } else {
                    mismatches.onMismatch(String.format(Locale.US, "expected %s to be %s, but was %s",
                                                        indexedXPath,
                                                        toString(expectedItem),
                                                        toString(actualItem)));
                }
            } else {
                // the other item is reordered (or absent), so this is blindly inserted.
                mismatches.onMismatch(String.format(Locale.US,
                                                    "missing item %s in %s",
                                                    toString(expectedItem),
                                                    indexedXPath));
            }
        }
        for (int actualIndex = 0; actualIndex < actual.size(); ++actualIndex) {
            T actualItem = actual.get(actualIndex);
            if (handledItems.contains(actualItem) || expected.contains(actualItem)) {
                continue;
            }
            String indexedXPath = String.format(Locale.US, "%s[%d]", xPath, actualIndex);
            mismatches.onMismatch(String.format(Locale.US,
                                                "unexpected item %s in %s",
                                                toString(actualItem),
                                                indexedXPath));
        }
        if (hasReorder) {
            mismatches.onMismatch(String.format(Locale.US,
                                                "unexpected item ordering in %s: [%s]",
                                                xPath,
                                                Strings.join(",", reordering)));
        }
    }

    /**
     * Make a compact, type-aware string representation of a value for use
     * in mismatch descriptions.
     * <ul>
     *     <li>null is {@code null};</li>
     *     <li>messages are shown via {@link #limitToString(PMessageOrBuilder)};</li>
     *     <li>enums are prefixed with their type name;</li>
     *     <li>maps and collections are shown recursively;</li>
     *     <li>strings are quoted and escaped;</li>
     *     <li>binaries longer than 110 bytes show the first 100 hex characters
     *         and {@code +N}, where N is the length minus 50;</li>
     *     <li>whole-valued doubles are shown without a fraction;</li>
     *     <li>everything else is wrapped in angle brackets.</li>
     * </ul>
     *
     * @param o The value to show.
     * @return The string representation.
     */
    public static String toString(Object o) {
        if (o == null) {
            return "null";
        } else if (o instanceof PMessage) {
            return limitToString((PMessage<?>) o);
        } else if (o instanceof PEnumValue) {
            PEnumValue v = (PEnumValue) o;
            return v.descriptor().getName() + "." + v.asString();
        } else if (o instanceof Map) {
            return "{" + Strings.join(",",
                                      ((Map<?, ?>) o).entrySet()
                                                     .stream()
                                                     .map(e -> toString(e.getKey()) + ":" + toString(e.getValue()))
                                                     .collect(Collectors.toList())) + "}";
        } else if (o instanceof Collection) {
            return "[" + Strings.join(",",
                                      ((Collection<?>) o).stream()
                                                         .map(MessageDiff::toString)
                                                         .collect(Collectors.toList())) + "]";
        } else if (o instanceof CharSequence) {
            return "\"" + Strings.escape(o.toString()) + "\"";
        } else if (o instanceof Binary) {
            int len = ((Binary) o).length();
            if (len > 110) {
                return String.format(Locale.US, "binary[%s...+%d]",
                                     ((Binary) o).toHexString()
                                                 .substring(0, 100),
                                     len - 50);
            } else {
                return "binary[" + ((Binary) o).toHexString() + "]";
            }
        } else if (o instanceof Double) {
            long l = ((Double) o).longValue();
            if (o.equals((double) l)) {
                return "<" + l + ">";
            } else {
                return "<" + o.toString() + ">";
            }
        } else {
            return "<" + o.toString() + ">";
        }
    }

    /**
     * Show a message via its {@code asString()} form. If that string is
     * longer than 120 characters, it is cut to its first 110 characters
     * followed by {@code ...}}.
     *
     * @param message The message or builder to show, may be null.
     * @param <M>     The message type.
     * @return The possibly truncated string, or {@code "null"}.
     */
    public static <M extends PMessage<M>>
    String limitToString(PMessageOrBuilder<M> message) {
        String tos = message == null ? "null" : message.toMessage().asString();
        if (tos.length() > 120) {
            tos = tos.substring(0, 110) + "...}";
        }

        return tos;
    }
}