ProtoMessageDeserializer.java

/*
 * Copyright 2017 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.jackson.adapter;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.databind.DeserializationConfig;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Duration;
import com.google.protobuf.Internal;
import com.google.protobuf.Message;
import com.google.protobuf.Timestamp;
import net.morimekta.collect.MapBuilder;
import net.morimekta.collect.UnmodifiableList;
import net.morimekta.collect.UnmodifiableMap;
import net.morimekta.proto.MorimektaOptions;
import net.morimekta.proto.ProtoEnum;
import net.morimekta.proto.ProtoMessageBuilder;
import net.morimekta.proto.jackson.ProtoFeature;
import net.morimekta.proto.jackson.ProtoStringFeature;
import net.morimekta.proto.jackson.ProtoTypeRegistryOption;
import net.morimekta.proto.utils.FieldUtil;
import net.morimekta.proto.utils.GoogleTypesUtil;
import net.morimekta.proto.utils.JsonNameUtil;
import net.morimekta.proto.utils.ProtoTypeRegistry;

import java.io.IOException;
import java.time.Instant;
import java.util.Base64;
import java.util.List;

import static com.fasterxml.jackson.core.JsonToken.END_ARRAY;
import static com.fasterxml.jackson.core.JsonToken.START_ARRAY;
import static com.fasterxml.jackson.core.JsonToken.START_OBJECT;
import static com.fasterxml.jackson.core.JsonToken.VALUE_FALSE;
import static com.fasterxml.jackson.core.JsonToken.VALUE_NULL;
import static com.fasterxml.jackson.core.JsonToken.VALUE_NUMBER_FLOAT;
import static com.fasterxml.jackson.core.JsonToken.VALUE_NUMBER_INT;
import static com.fasterxml.jackson.core.JsonToken.VALUE_STRING;
import static com.fasterxml.jackson.core.JsonToken.VALUE_TRUE;
import static net.morimekta.proto.jackson.ProtoFeature.FAIL_ON_NULL_VALUE;
import static net.morimekta.proto.jackson.ProtoFeature.FAIL_ON_UNKNOWN_FIELD;
import static net.morimekta.proto.jackson.ProtoFeature.IGNORE_UNKNOWN_ANY_TYPE;
import static net.morimekta.proto.utils.JsonNameUtil.NUMERIC_FIELD_ID;
import static net.morimekta.strings.EscapeUtil.javaEscape;

/**
 * Jackson deserializer that reads protobuf messages from JSON.
 * <p>
 * A message may be written as a JSON object (fields keyed by proto name,
 * JSON name or numeric field id), as a JSON array when the message type
 * has the {@code compact} option set, or - for {@link Timestamp} and
 * {@link Duration} - as a plain number or string.
 *
 * @param <M> The proto message type.
 */
public class ProtoMessageDeserializer<M extends Message>
        extends JsonDeserializer<M> {
    /** Number of microseconds per second, used when splitting fractional seconds. */
    private static final int MICROS_PER_SECOND = 1_000_000;

    private final Descriptors.Descriptor descriptor;
    private final DeserializationConfig  config;

    /**
     * Instantiate deserializer.
     *
     * @param descriptor The message type descriptor.
     * @param config     The deserialization config.
     */
    public ProtoMessageDeserializer(Descriptors.Descriptor descriptor, DeserializationConfig config) {
        this.descriptor = descriptor;
        this.config = config;
    }

    @Override
    @SuppressWarnings("unchecked")
    public M deserialize(JsonParser jsonParser, DeserializationContext deserializationContext)
            throws IOException {
        return (M) parseMessage(jsonParser, jsonParser.currentToken(), descriptor, deserializationContext);
    }

    /**
     * Parse a single message value, dispatching on the token that starts it.
     *
     * @param jsonParser             The parser, positioned on {@code next}.
     * @param next                   The token the value starts with.
     * @param descriptor             The type of message to parse.
     * @param deserializationContext The current deserialization context.
     * @return The parsed message.
     * @throws IOException If the JSON could not be read or does not match the message type.
     */
    private Message parseMessage(
            JsonParser jsonParser,
            JsonToken next,
            Descriptors.Descriptor descriptor,
            DeserializationContext deserializationContext) throws IOException {
        if (isTokenType(next, START_OBJECT)) {
            return parseObjectMessage(jsonParser, descriptor, deserializationContext);
        }
        if (isTokenType(next, START_ARRAY)) {
            return parseCompactMessage(jsonParser, descriptor, deserializationContext);
        }
        if (Timestamp.getDescriptor().equals(descriptor)) {
            return parseTimestampValue(jsonParser, next, descriptor);
        }
        if (Duration.getDescriptor().equals(descriptor)) {
            return parseDurationValue(jsonParser, next, descriptor);
        }
        throw JsonMappingException.from(
                jsonParser, "Unknown start of object for " + descriptor.getFullName() + ": " + next);
    }

    /**
     * Parse a message written as a JSON object. The parser must be positioned on
     * the START_OBJECT token. If the message is an {@link Any} whose first field is
     * a type field, the remaining fields are parsed as the referenced type and
     * packed into the resulting {@link Any}.
     */
    private Message parseObjectMessage(
            JsonParser jsonParser,
            Descriptors.Descriptor descriptor,
            DeserializationContext deserializationContext) throws IOException {
        var first = jsonParser.nextToken();
        if (descriptor.equals(Any.getDescriptor()) && first == JsonToken.FIELD_NAME) {
            var name = jsonParser.currentName();
            if (JsonNameUtil.ANY_TYPE_FIELDS.contains(name) ||
                name.equals(ProtoStringFeature.ANY_TYPE_FIELD_NAME.get(config))) {
                validateTokenType(jsonParser, jsonParser.nextValue(), VALUE_STRING);
                var typeUrl = jsonParser.getValueAsString();
                var tr = ProtoTypeRegistryOption.getRegistry(deserializationContext.getConfig());
                if (tr != null) {
                    var anyType = tr.messageTypeByTypeUrl(typeUrl);
                    if (anyType != null) {
                        var unpacked = parseMessageFields(jsonParser,
                                                          jsonParser.nextToken(),
                                                          anyType,
                                                          deserializationContext);
                        return Any.pack(unpacked);
                    }
                }
                if (!IGNORE_UNKNOWN_ANY_TYPE.isEnabled(config)) {
                    throw JsonMappingException.from(
                            jsonParser, "Unknown type for unpacked any: " + typeUrl);
                }
                // The parser is on the type URL string here, so skipChildren() alone
                // would do nothing. Consume the rest of the object so the parser ends
                // on its END_OBJECT, just like a normally parsed message.
                skipToStructEnd(jsonParser);
                return Any.newBuilder()
                          .setTypeUrl(typeUrl)
                          // with no value.
                          .build();
            }
        }
        return parseMessageFields(jsonParser, first, descriptor, deserializationContext);
    }

    /**
     * Parse a message written in compact (array) notation, where the values are
     * listed in the order of the fields of the descriptor. The parser must be
     * positioned on the START_ARRAY token, and is left on the matching END_ARRAY.
     */
    private Message parseCompactMessage(
            JsonParser jsonParser,
            Descriptors.Descriptor descriptor,
            DeserializationContext deserializationContext) throws IOException {
        if (!descriptor.getOptions().getExtension(MorimektaOptions.compact)) {
            throw JsonMappingException.from(
                    jsonParser, "Array notation not allowed for " + descriptor.getFullName());
        }
        var builder = new ProtoMessageBuilder(descriptor);
        List<Descriptors.FieldDescriptor> fields = descriptor.getFields();
        int idx = 0;
        for (JsonToken next = jsonParser.nextToken();
             !isTokenType(next, END_ARRAY);
             next = jsonParser.nextToken(), ++idx) {
            if (idx >= fields.size()) {
                // Surplus values have no declared field, so they are treated the
                // same way as unknown fields in object notation.
                if (FAIL_ON_UNKNOWN_FIELD.isEnabled(config)) {
                    throw JsonMappingException.from(
                            jsonParser, "Values after last field for " + descriptor.getFullName());
                }
                // Consume the current value and everything after it, so the
                // parser is left on the END_ARRAY of this message.
                jsonParser.skipChildren();
                skipToStructEnd(jsonParser);
                break;
            }
            if (isTokenType(next, VALUE_NULL)) {
                // A null in compact notation just means the field is not set.
                continue;
            }
            Descriptors.FieldDescriptor field = fields.get(idx);
            Object value = parseValue(jsonParser, next, field, deserializationContext);
            builder.set(field, value);
        }
        return builder.getMessage().build();
    }

    /**
     * Parse a timestamp written as integer seconds since epoch, fractional
     * seconds since epoch, or a date-time string.
     */
    private static Timestamp parseTimestampValue(
            JsonParser jsonParser,
            JsonToken next,
            Descriptors.Descriptor descriptor) throws IOException {
        var builder = Timestamp.newBuilder();
        if (isTokenType(next, VALUE_NUMBER_INT)) {
            builder.setSeconds(jsonParser.getLongValue());
        } else if (isTokenType(next, VALUE_NUMBER_FLOAT)) {
            // timestamp as fractional seconds since epoch.
            var timestampDbl = jsonParser.getDoubleValue();
            // Use floor (not truncation) so the fraction is never negative: timestamp
            // nanos must count forward in time, also for instants before epoch.
            var timestampS = (long) Math.floor(timestampDbl);
            var micros = (int) Math.round((timestampDbl - timestampS) * MICROS_PER_SECOND);
            if (micros >= MICROS_PER_SECOND) {
                // The fraction rounded up to a whole second, carry it over.
                timestampS += 1;
                micros -= MICROS_PER_SECOND;
            }
            builder.setSeconds(timestampS);
            builder.setNanos(micros * 1000);
        } else if (isTokenType(next, VALUE_STRING)) {
            var instant = parseIsoDateTime(jsonParser, jsonParser.getValueAsString());
            builder.setSeconds(instant.getEpochSecond());
            builder.setNanos(instant.getNano());
        } else {
            throw JsonMappingException.from(
                    jsonParser, "Unknown value for " + descriptor.getFullName() + ": " + next);
        }
        return builder.build();
    }

    /**
     * Parse a duration written as integer seconds, fractional seconds, or a
     * duration string.
     */
    private static Duration parseDurationValue(
            JsonParser jsonParser,
            JsonToken next,
            Descriptors.Descriptor descriptor) throws IOException {
        var builder = Duration.newBuilder();
        if (isTokenType(next, VALUE_NUMBER_INT)) {
            builder.setSeconds(jsonParser.getLongValue());
        } else if (isTokenType(next, VALUE_NUMBER_FLOAT)) {
            // duration as fractional seconds.
            var durationDbl = jsonParser.getDoubleValue();
            // Truncate towards zero, so seconds and the fraction keep the same sign.
            var durationS = (long) durationDbl;
            var micros = (int) Math.round((durationDbl - durationS) * MICROS_PER_SECOND);
            if (micros >= MICROS_PER_SECOND) {
                // The fraction rounded up to a whole second, carry it over.
                durationS += 1;
                micros -= MICROS_PER_SECOND;
            } else if (micros <= -MICROS_PER_SECOND) {
                durationS -= 1;
                micros += MICROS_PER_SECOND;
            }
            builder.setSeconds(durationS);
            builder.setNanos(micros * 1000);
        } else if (isTokenType(next, VALUE_STRING)) {
            var duration = parseDuration(jsonParser, jsonParser.getValueAsString());
            builder.setSeconds(duration.getSeconds());
            builder.setNanos(duration.getNano());
        } else {
            throw JsonMappingException.from(
                    jsonParser, "Unknown value for " + descriptor.getFullName() + ": " + next);
        }
        return builder.build();
    }

    /**
     * Parse the fields of a message in object notation, until the end of the object.
     * Fields are looked up by numeric field id, proto name, JSON name and finally as
     * an extension in the type registry (if one is configured).
     *
     * @param jsonParser             The parser, positioned on {@code next}.
     * @param next                   The first token after the start of the object, or
     *                               after the type field of an unpacked any.
     * @param descriptor             The type of message to parse.
     * @param deserializationContext The current deserialization context.
     * @return The parsed message.
     * @throws IOException If the JSON could not be read or does not match the message type.
     */
    private Message parseMessageFields(JsonParser jsonParser,
                                       JsonToken next,
                                       Descriptors.Descriptor descriptor,
                                       DeserializationContext deserializationContext) throws IOException {
        var builder = new ProtoMessageBuilder(descriptor);
        var jsonFields = JsonNameUtil.getJsonFieldMap(descriptor);
        while (!next.isStructEnd()) {
            String fieldName = jsonParser.currentName();
            Descriptors.FieldDescriptor field = null;
            if (NUMERIC_FIELD_ID.matcher(fieldName).matches()) {
                Integer number = parseFieldNumberOrNull(fieldName);
                // A number that does not fit an int can not be a field number,
                // so it is handled as an unknown field below.
                if (number != null) {
                    field = descriptor.findFieldByNumber(number);
                    if (field == null) {
                        ProtoTypeRegistry er = ProtoTypeRegistryOption.getRegistry(deserializationContext.getConfig());
                        if (er != null) {
                            field = er.extensionByScopeAndNumber(descriptor, number);
                        }
                    }
                }
            } else {
                field = descriptor.findFieldByName(fieldName);
                if (field == null) {
                    field = jsonFields.get(fieldName);
                }
                if (field == null) {
                    ProtoTypeRegistry er = ProtoTypeRegistryOption.getRegistry(deserializationContext.getConfig());
                    if (er != null) {
                        field = er.extensionByScopeAndName(descriptor, fieldName);
                    }
                }
            }
            if (field == null) {
                if (FAIL_ON_UNKNOWN_FIELD.isEnabled(config)) {
                    throw JsonMappingException.from(
                            jsonParser, "Unknown field " + fieldName + " for " + descriptor.getFullName());
                }
                // Skip the whole value of the unknown field.
                jsonParser.nextValue();
                jsonParser.skipChildren();
                next = jsonParser.nextToken();
                continue;
            }
            next = jsonParser.nextValue();
            Object value = parseValue(jsonParser, next, field, deserializationContext);
            builder.set(field, value);
            next = jsonParser.nextToken();
        }
        return builder.getMessage().build();
    }

    /**
     * Parse the value of a field: a map for map fields, a list for other
     * repeated fields, and a single value otherwise.
     *
     * @param jsonParser             The parser, positioned on {@code next}.
     * @param next                   The token the value starts with.
     * @param descriptor             The field to parse the value for.
     * @param deserializationContext The current deserialization context.
     * @return The parsed value, or null if the JSON value was null and null values are allowed.
     * @throws IOException If the JSON could not be read or does not match the field type.
     */
    private Object parseValue(JsonParser jsonParser,
                              JsonToken next,
                              Descriptors.FieldDescriptor descriptor,
                              DeserializationContext deserializationContext)
            throws IOException {
        if (isTokenType(next, VALUE_NULL)) {
            if (FAIL_ON_NULL_VALUE.isEnabled(config)) {
                throw JsonMappingException.from(
                        jsonParser, "Null value for field " + descriptor.getFullName());
            }
            return null;
        }
        if (!descriptor.isRepeated()) {
            return parseSingleValue(jsonParser, next, descriptor, deserializationContext);
        }
        if (descriptor.isMapField()) {
            validateTokenType(jsonParser, next, START_OBJECT);
            var keyType = FieldUtil.getMapKeyDescriptor(descriptor);
            var valueType = FieldUtil.getMapValueDescriptor(descriptor);
            MapBuilder<Object, Object> builder = UnmodifiableMap.newBuilder(2);
            String keyName = jsonParser.nextFieldName();
            while (keyName != null) {
                Object key;
                try {
                    key = parseMapKey(keyName, keyType);
                } catch (NumberFormatException e) {
                    // Report bad numeric keys the same way as bad numeric values.
                    throw JsonMappingException.from(
                            jsonParser,
                            "Invalid map key \"" + javaEscape(keyName) + "\" for field " +
                            descriptor.getFullName(),
                            e);
                }
                Object value = parseSingleValue(jsonParser,
                                                jsonParser.nextValue(),
                                                valueType,
                                                deserializationContext);
                // Entries whose value could not be resolved are dropped.
                if (key != null && value != null) {
                    builder.put(key, enumToInt(value));
                }
                keyName = jsonParser.nextFieldName();
            }
            return builder.build();
        }
        validateTokenType(jsonParser, next, START_ARRAY);
        var builder = UnmodifiableList.newBuilder();
        next = jsonParser.nextToken();
        while (!isTokenType(next, END_ARRAY)) {
            Object value = parseSingleValue(jsonParser, next, descriptor, deserializationContext);
            // Items that could not be resolved are dropped.
            if (value != null) {
                builder.add(value);
            }
            next = jsonParser.nextToken();
        }
        return builder.build();
    }

    /**
     * Replace an enum value with its numeric value, and pass on anything else.
     */
    private Object enumToInt(Object o) {
        if (o instanceof Internal.EnumLite) {
            return ((Internal.EnumLite) o).getNumber();
        }
        return o;
    }

    /**
     * Parse a single (non-repeated) value of the type of the given field.
     *
     * @param jsonParser             The parser, positioned on {@code next}.
     * @param next                   The token the value starts with.
     * @param descriptor             The field giving the type of the value.
     * @param deserializationContext The current deserialization context.
     * @return The parsed value, or null for an unknown enum value when
     *         {@code FAIL_ON_UNKNOWN_ENUM} is not enabled.
     * @throws IOException If the JSON could not be read or does not match the field type.
     */
    private Object parseSingleValue(JsonParser jsonParser,
                                    JsonToken next,
                                    Descriptors.FieldDescriptor descriptor,
                                    DeserializationContext deserializationContext) throws IOException {
        switch (descriptor.getType().getJavaType()) {
            case BOOLEAN:
                validateTokenType(jsonParser, next, VALUE_TRUE, VALUE_FALSE);
                return jsonParser.getValueAsBoolean();
            case INT:
                validateTokenType(jsonParser, next, VALUE_NUMBER_INT, VALUE_STRING);
                if (next == VALUE_NUMBER_INT) {
                    return jsonParser.getValueAsInt();
                } else {
                    try {
                        return Integer.parseInt(jsonParser.getValueAsString());
                    } catch (NumberFormatException e) {
                        throw JsonMappingException.from(
                                jsonParser,
                                "Invalid integer \"" + javaEscape(jsonParser.getValueAsString()) + "\"",
                                e);
                    }
                }
            case LONG:
                validateTokenType(jsonParser, next, VALUE_NUMBER_INT, VALUE_STRING);
                if (next == VALUE_NUMBER_INT) {
                    return jsonParser.getValueAsLong();
                } else {
                    try {
                        return Long.parseLong(jsonParser.getValueAsString());
                    } catch (NumberFormatException e) {
                        throw JsonMappingException.from(
                                jsonParser,
                                "Invalid long integer \"" + javaEscape(jsonParser.getValueAsString()) + "\"",
                                e);
                    }
                }
            case FLOAT:
                validateTokenType(jsonParser, next, VALUE_NUMBER_FLOAT, VALUE_NUMBER_INT, VALUE_STRING);
                if (next == VALUE_STRING) {
                    try {
                        return Float.parseFloat(jsonParser.getValueAsString());
                    } catch (NumberFormatException e) {
                        throw JsonMappingException.from(
                                jsonParser,
                                "Invalid float \"" + javaEscape(jsonParser.getValueAsString()) + "\"",
                                e);
                    }
                } else {
                    return (float) jsonParser.getValueAsDouble();
                }
            case DOUBLE:
                if (next == VALUE_STRING) {
                    try {
                        return Double.parseDouble(jsonParser.getValueAsString());
                    } catch (NumberFormatException e) {
                        throw JsonMappingException.from(
                                jsonParser,
                                "Invalid double \"" + javaEscape(jsonParser.getValueAsString()) + "\"",
                                e);
                    }
                } else {
                    validateTokenType(jsonParser, next, VALUE_NUMBER_FLOAT, VALUE_NUMBER_INT, VALUE_STRING);
                }
                return jsonParser.getValueAsDouble();
            case STRING:
                validateTokenType(jsonParser, next, VALUE_STRING);
                return jsonParser.getValueAsString();
            case BYTE_STRING:
                validateTokenType(jsonParser, next, VALUE_STRING);
                try {
                    return ByteString.copyFrom(Base64.getDecoder().decode(jsonParser.getValueAsString()));
                } catch (IllegalArgumentException e) {
                    // The decoder rejects invalid base64 with an unchecked exception,
                    // report it as a mapping error like the other invalid values.
                    throw JsonMappingException.from(
                            jsonParser,
                            "Invalid base64 value for field " + descriptor.getFullName(),
                            e);
                }
            case ENUM: {
                var ed = ProtoEnum.getEnumDescriptor(descriptor.getEnumType());
                if (next == VALUE_NUMBER_INT) {
                    var value = jsonParser.getValueAsInt();
                    try {
                        return ed.valueForNumber(value);
                    } catch (IllegalArgumentException e) {
                        if (ProtoFeature.FAIL_ON_UNKNOWN_ENUM.isEnabled(config)) {
                            throw JsonMappingException.from(
                                    jsonParser, "Unknown " + descriptor.getFullName() + " value " + value);
                        } else {
                            return null;
                        }
                    }
                } else if (next == VALUE_STRING) {
                    var value = jsonParser.getValueAsString();
                    try {
                        return ed.valueForName(value);
                    } catch (IllegalArgumentException e) {
                        if (ProtoFeature.FAIL_ON_UNKNOWN_ENUM.isEnabled(config)) {
                            throw JsonMappingException.from(
                                    jsonParser,
                                    "Unknown " + descriptor.getFullName() + " value \"" + javaEscape(value) + "\"");
                        } else {
                            return null;
                        }
                    }
                } else {
                    throw JsonMappingException.from(jsonParser, "Invalid type for enum value " + next);
                }
            }
            case MESSAGE:
                return parseMessage(jsonParser, next, descriptor.getMessageType(), deserializationContext);
        }
        // Untestable code.
        throw new IllegalStateException("Unhandled value type: " + descriptor.getType());
    }

    /**
     * Parse a JSON object key into a map key of the given type.
     *
     * @param keyName The JSON object key.
     * @param keyType The field describing the map key type.
     * @return The parsed key.
     * @throws IOException           If the key type is not allowed as a JSON key.
     * @throws NumberFormatException If the key is not a valid number for a numeric key type.
     */
    private static Object parseMapKey(String keyName, Descriptors.FieldDescriptor keyType) throws IOException {
        switch (keyType.getType().getJavaType()) {
            case BOOLEAN:
                return Boolean.parseBoolean(keyName);
            case INT:
                return Integer.parseInt(keyName);
            case LONG:
                return Long.parseLong(keyName);
            case STRING:
                return keyName;
        }
        // Untestable code.
        throw new IOException("Not allowed in JSON key: " + keyType.getType());
    }

    /**
     * Parse a numeric field name into a field number.
     *
     * @param fieldName The field name, consisting of digits.
     * @return The field number, or null if it does not fit in an int.
     */
    private static Integer parseFieldNumberOrNull(String fieldName) {
        try {
            return Integer.parseInt(fieldName);
        } catch (NumberFormatException ignored) {
            // Too large to be a field number.
            return null;
        }
    }

    /**
     * Consume all tokens following the current one, up to and including the end
     * of the enclosing object or array. Nested objects and arrays are skipped whole.
     */
    private static void skipToStructEnd(JsonParser jsonParser) throws IOException {
        JsonToken token = jsonParser.nextToken();
        while (token != null && !token.isStructEnd()) {
            // Moves to the matching end token if on the start of an object or
            // array, and does nothing for any other token.
            jsonParser.skipChildren();
            token = jsonParser.nextToken();
        }
    }

    private static boolean isTokenType(JsonToken token, JsonToken type) {
        return token.id() == type.id();
    }

    private static void validateTokenType(JsonParser parser, JsonToken token, JsonToken type) throws IOException {
        if (token.id() != type.id()) {
            throw JsonMappingException.from(
                    parser, "Expected " + name(type) + " but got " + name(token));
        }
    }

    private static void validateTokenType(JsonParser parser, JsonToken token, JsonToken typeA, JsonToken typeB)
            throws IOException {
        if (token.id() != typeA.id() && token.id() != typeB.id()) {
            throw JsonMappingException.from(
                    parser,
                    "Expected " + name(typeA) + " or " + name(typeB) + " but got " + name(token));
        }
    }

    private static void validateTokenType(JsonParser parser,
                                          JsonToken token,
                                          JsonToken typeA,
                                          JsonToken typeB,
                                          JsonToken typeC)
            throws IOException {
        if (token.id() != typeA.id() && token.id() != typeB.id() && token.id() != typeC.id()) {
            throw JsonMappingException.from(
                    parser,
                    "Expected " + name(typeA) + ", " + name(typeB) + " or " + name(typeC) + " but got " + name(token));
        }
    }

    /**
     * Parse a duration string, reporting invalid strings as mapping errors.
     */
    private static java.time.Duration parseDuration(JsonParser jsonParser, String value) throws IOException {
        try {
            return GoogleTypesUtil.parseJavaDurationString(value);
        } catch (IllegalArgumentException e) {
            throw JsonMappingException.from(jsonParser, e.getMessage(), e);
        }
    }

    /**
     * Parse a timestamp string, reporting invalid strings as mapping errors.
     */
    private static Instant parseIsoDateTime(JsonParser parser, String value) throws IOException {
        try {
            return GoogleTypesUtil.parseJavaTimestampString(value);
        } catch (IllegalArgumentException e) {
            throw JsonMappingException.from(parser, e.getMessage(), e);
        }
    }

    /**
     * Make a displayable name for a token, for use in error messages.
     */
    private static String name(JsonToken token) {
        var asStr = token.asString();
        if (asStr != null) {
            return "'" + asStr + "'";
        }
        switch (token) {
            case VALUE_STRING:
                return "\"<string>\"";
            case VALUE_NUMBER_INT:
                return "<integer>";
            case VALUE_NUMBER_FLOAT:
                return "<real>";
            default:
                return token.name();
        }
    }
}