EqualToMessage.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.testing;

import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;
import net.morimekta.util.collect.UnmodifiableSet;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static net.morimekta.providence.testing.util.MessageDiff.collectMismatches;
import static net.morimekta.providence.testing.util.MessageDiff.limitToString;

/**
 * Equality matcher for messages with pinpointed field diff output and
 * possibility to ignore individual fields.
 * <p>
 * Both the expected and the actual value may be a message or a message
 * builder; builders are converted with {@code toMessage()} before they are
 * compared. A {@code null} expected value matches only {@code null}.
 *
 * @param <Message> The message type to match.
 */
public class EqualToMessage<Message extends PMessage<Message>>
        extends BaseMatcher<PMessageOrBuilder<Message>> {
    /** Max number of mismatches listed individually in a mismatch description. */
    private static final int MAX_LISTED_MISMATCHES = 20;

    private final PMessageOrBuilder<Message> expected;
    private final Set<PField> ignoringFields;

    /**
     * Create a matcher checking equality against the expected message.
     *
     * @param expected  The expected message or builder, or null to match only null.
     * @param <Message> The message type to match.
     * @return The matcher.
     */
    public static <Message extends PMessage<Message>>
    EqualToMessage<Message> equalToMessage(PMessageOrBuilder<Message> expected) {
        return new EqualToMessage<>(expected);
    }

    /**
     * Create a matcher checking equality against the expected message.
     *
     * @param expected The expected message or builder, or null to match only null.
     */
    public EqualToMessage(PMessageOrBuilder<Message> expected) {
        this(expected, UnmodifiableSet.setOf());
    }

    private EqualToMessage(PMessageOrBuilder<Message> expected, Set<PField> ignoringFields) {
        // Builders are converted to messages up front. A null expectation is
        // kept as-is, as matches() treats it as "actual must be null".
        this.expected = expected == null ? null : expected.toMessage();
        this.ignoringFields = ignoringFields;
    }

    /**
     * Create a copy of this matcher that additionally ignores the given
     * fields. Fields already ignored by this matcher stay ignored, so calls
     * can be chained.
     *
     * @param fields Fields to ignore when comparing.
     * @return The new matcher.
     */
    public final EqualToMessage<Message> ignoring(PField... fields) {
        Set<PField> combined = new LinkedHashSet<>(ignoringFields);
        combined.addAll(Arrays.asList(fields));
        return new EqualToMessage<>(expected, UnmodifiableSet.copyOf(combined));
    }

    /**
     * Check if the actual value is equal to the expected message, not
     * counting the ignored fields.
     *
     * @param actual The actual value, a message, a builder or null.
     * @return True if matching.
     * @throws AssertionError If actual is not a providence message, or is a
     *                        message of a different type than expected.
     */
    @Override
    public boolean matches(Object actual) {
        if (expected == null) {
            return actual == null;
        }
        if (actual == null) {
            return false;
        }
        return collectMismatchesWith(actual).isEmpty();
    }

    @Override
    public void describeTo(Description description) {
        description.appendText("equals(")
                   .appendText(expected == null ? "null" : limitToString(expected))
                   .appendText(")");
    }

    @Override
    public void describeMismatch(Object actual, Description mismatchDescription) {
        if (expected == null) {
            mismatchDescription.appendText("got " + toString(actual));
            return;
        }
        if (actual == null) {
            mismatchDescription.appendText("got null");
            return;
        }

        List<String> mismatches = collectMismatchesWith(actual);
        if (mismatches.size() == 1) {
            mismatchDescription.appendText(mismatches.get(0));
            return;
        }

        mismatchDescription.appendText("[");
        int listed = Math.min(mismatches.size(), MAX_LISTED_MISMATCHES);
        for (int i = 0; i < listed; ++i) {
            if (i > 0) {
                mismatchDescription.appendText(",");
            }
            mismatchDescription.appendText("\n        ")
                               .appendText(mismatches.get(i));
        }
        if (mismatches.size() > listed) {
            // Summarize the tail instead of flooding the test output.
            int remaining = mismatches.size() - listed;
            mismatchDescription.appendText(",\n        ... and " + remaining + " more");
        }
        mismatchDescription.appendText("\n     ]");
    }

    /**
     * Collect all field mismatches between the expected message and the
     * (non-null) actual value, skipping the ignored fields.
     *
     * @param actual The non-null actual value.
     * @return List of mismatch descriptions, empty if equal.
     */
    private List<String> collectMismatchesWith(Object actual) {
        List<String> mismatches = new ArrayList<>();
        collectMismatches("", expected, toComparableMessage(actual),
                          UnmodifiableSet.copyOf(ignoringFields), mismatches::add);
        return mismatches;
    }

    /**
     * Convert a non-null actual value to a message of the expected type.
     * Builders are converted with {@code toMessage()}.
     *
     * @param actual The non-null actual value.
     * @return The actual value as a message.
     * @throws AssertionError If not a providence message, or of another message type.
     */
    @SuppressWarnings("unchecked")
    private Message toComparableMessage(Object actual) {
        Object value = actual instanceof PMessageOrBuilder
                       ? ((PMessageOrBuilder<?>) actual).toMessage()
                       : actual;
        if (!(value instanceof PMessage)) {
            throw new AssertionError("Item " + actual.getClass().toString() + " not a providence message.");
        }
        PMessage<?> message = (PMessage<?>) value;
        if (!message.descriptor().equals(expected.descriptor())) {
            throw new AssertionError("Expected message type " + expected.descriptor().getQualifiedName() +
                                     ", but got " + message.descriptor().getQualifiedName());
        }
        // Same descriptor as expected, so the cast to the expected type is safe.
        return (Message) message;
    }

    /**
     * Make a short, readable string representation of a value for mismatch
     * output. Messages use {@code limitToString}, enum values are prefixed
     * with their type name, strings are quoted and escaped, large binaries
     * are truncated, and integral doubles are printed without a decimal part.
     *
     * @param o The value to stringify, may be null.
     * @return The string representation.
     */
    protected static String toString(Object o) {
        if (o == null) {
            return "null";
        } else if (o instanceof PMessage) {
            return limitToString((PMessage<?>) o);
        } else if (o instanceof PEnumValue) {
            return ((PEnumValue) o).descriptor()
                                   .getName() + "." + ((PEnumValue) o).asString();
        } else if (o instanceof Map) {
            return "{" + Strings.join(",",
                                      ((Map<?, ?>) o).entrySet()
                                                     .stream()
                                                     .map(e -> toString(e.getKey()) + ":" + toString(e.getValue()))
                                                     .collect(Collectors.toList())) + "}";
        } else if (o instanceof Collection) {
            return "[" + Strings.join(",",
                                      ((Collection<?>) o).stream()
                                                         .map(EqualToMessage::toString)
                                                         .collect(Collectors.toList())) + "]";
        } else if (o instanceof CharSequence) {
            return "\"" + Strings.escape(o.toString()) + "\"";
        } else if (o instanceof Binary) {
            int len = ((Binary) o).length();
            if (len > 110) {
                // Show the first 50 bytes (100 hex chars) and the remaining byte count.
                return String.format(Locale.US, "binary[%s...+%d]",
                                     ((Binary) o).toHexString()
                                                 .substring(0, 100),
                                     len - 50);
            } else {
                return "binary[" + ((Binary) o).toHexString() + "]";
            }
        } else if (o instanceof Double) {
            long l = ((Double) o).longValue();
            if (o.equals((double) l)) {
                return Long.toString(l);
            } else {
                return o.toString();
            }
        } else {
            return o.toString();
        }
    }
}