GsonMessageReader.java

/*
 * Copyright 2020 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.proto.gson.sio;

import com.google.gson.JsonParseException;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Duration;
import com.google.protobuf.Message;
import com.google.protobuf.Timestamp;
import net.morimekta.collect.ListBuilder;
import net.morimekta.collect.MapBuilder;
import net.morimekta.collect.UnmodifiableList;
import net.morimekta.collect.UnmodifiableMap;
import net.morimekta.proto.MorimektaOptions;
import net.morimekta.proto.ProtoEnum;
import net.morimekta.proto.ProtoMessageBuilder;
import net.morimekta.proto.gson.ProtoTypeOptions;
import net.morimekta.proto.sio.MessageReader;
import net.morimekta.proto.utils.GoogleTypesUtil;
import net.morimekta.strings.io.LineBufferedReader;

import java.io.Closeable;
import java.io.IOException;
import java.io.Reader;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.Map;

import static java.lang.Boolean.parseBoolean;
import static java.lang.Integer.parseInt;
import static java.lang.Long.parseLong;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.FAIL_ON_NULL_VALUE;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.FAIL_ON_UNKNOWN_FIELD;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.IGNORE_UNKNOWN_ANY_TYPE;
import static net.morimekta.proto.gson.ProtoTypeOptions.Option.LENIENT_READER;
import static net.morimekta.proto.gson.ProtoTypeOptions.Value.ANY_TYPE_FIELD_NAME;
import static net.morimekta.proto.utils.FieldUtil.getMapKeyDescriptor;
import static net.morimekta.proto.utils.FieldUtil.getMapValueDescriptor;
import static net.morimekta.proto.utils.JsonNameUtil.ANY_TYPE_FIELDS;
import static net.morimekta.proto.utils.JsonNameUtil.NUMERIC_FIELD_ID;
import static net.morimekta.proto.utils.JsonNameUtil.getJsonFieldMap;
import static net.morimekta.proto.utils.ProtoTypeRegistry.getTypeNameFromTypeUrl;
import static net.morimekta.proto.utils.ProtoTypeRegistry.getTypeUrl;

/**
 * Reader for reading a JSON message from a stream.
 */
public class GsonMessageReader implements MessageReader, Closeable {
    private final JsonReader       reader;
    private final ProtoTypeOptions options;

    /**
     * @param in      Reader to read from.
     * @param options Proto options.
     */
    public GsonMessageReader(Reader in, ProtoTypeOptions options) {
        this(makeJsonReader(in, options.isEnabled(LENIENT_READER)), options);
    }

    /**
     * @param reader  JSON Reader to read from.
     * @param options Proto options.
     */
    public GsonMessageReader(JsonReader reader, ProtoTypeOptions options) {
        this.reader = reader;
        this.options = options;
    }

    @Override
    public void close() throws IOException {
        reader.close();
    }

    @Override
    public Message read(Descriptors.Descriptor descriptor) throws IOException {
        if (reader.peek() == JsonToken.NULL) {
            reader.nextNull();
            return null;
        }
        if (reader.peek() == JsonToken.BEGIN_ARRAY) {
            if (!descriptor.getOptions().getExtension(MorimektaOptions.compact)) {
                throw new JsonParseException("Compact format not allowed for " + descriptor.getFullName() + " at " + reader.getPath());
            }
            ProtoMessageBuilder builder = new ProtoMessageBuilder(descriptor);
            List<Descriptors.FieldDescriptor> fields = descriptor.getFields();

            reader.beginArray();
            int idx = 0;
            while (reader.peek() != JsonToken.END_ARRAY) {
                if (idx >= fields.size()) {
                    if (options.isEnabled(FAIL_ON_NULL_VALUE)) {
                        throw new JsonParseException("Data after last field at " + reader.getPath());
                    }
                    reader.skipValue();
                } else if (reader.peek() == JsonToken.NULL) {
                    reader.nextNull();
                } else {
                    Descriptors.FieldDescriptor field = fields.get(idx);
                    builder.set(field, readSingleValue(field));
                }
                idx++;
            }
            reader.endArray();
            return builder.getMessage().build();
        } else if (reader.peek() == JsonToken.BEGIN_OBJECT) {
            ProtoMessageBuilder builder = new ProtoMessageBuilder(descriptor);
            Map<String, Descriptors.FieldDescriptor> jsonFields = getJsonFieldMap(descriptor);

            reader.beginObject();
            boolean first = true;
            while (reader.peek() != JsonToken.END_OBJECT) {
                String name = reader.nextName();
                if (first) {
                    first = false;
                    if (descriptor.equals(Any.getDescriptor()) &&
                        (ANY_TYPE_FIELDS.contains(name) || options.getValue(ANY_TYPE_FIELD_NAME).equals(name))) {
                        var anyBuilder = (Any.Builder) builder.getMessage();
                        var typeUrl = reader.nextString();
                        var packedType = options.getRegistry().messageTypeByTypeUrl(typeUrl);
                        if (packedType != null) {
                            var packedJsonFields = getJsonFieldMap(packedType);
                            var packedBuilder = new ProtoMessageBuilder(packedType);
                            while (reader.peek() != JsonToken.END_OBJECT) {
                                findFieldAndReadValue(packedBuilder, reader.nextName(), packedType, packedJsonFields);
                            }
                            anyBuilder.setTypeUrl(getTypeUrl(packedType));
                            anyBuilder.setValue(packedBuilder.getMessage().build().toByteString());
                            break;
                        } else if (options.isEnabled(IGNORE_UNKNOWN_ANY_TYPE)) {
                            anyBuilder.setTypeUrl("type.googleapis.com/" + getTypeNameFromTypeUrl(typeUrl));
                            // ignore entire struct.
                            while (reader.peek() != JsonToken.END_OBJECT) {
                                reader.nextName();
                                reader.skipValue();
                            }
                        } else {
                            throw new JsonParseException("Unknown type " + typeUrl + " at " + reader.getPreviousPath());
                        }
                        break;
                    }
                }
                findFieldAndReadValue(builder, name, descriptor, jsonFields);
            }
            reader.endObject();
            return builder.getMessage().build();
        } else {
            if (Timestamp.getDescriptor().equals(descriptor)) {
                var builder = Timestamp.newBuilder();
                if (reader.peek() == JsonToken.NUMBER) {
                    // timestamp as fractional seconds since epoch.
                    var timestampDbl = reader.nextDouble();
                    var timestampS = (long) timestampDbl;
                    var micros = (int) Math.round((timestampDbl - timestampS) * 1_000_000L);
                    builder.setSeconds(timestampS);
                    builder.setNanos(micros * 1000);
                } else if (reader.peek() == JsonToken.STRING) {
                    // ISO date string.
                    var instant = parseIsoDateTime(reader, reader.nextString());
                    builder.setSeconds(instant.getEpochSecond());
                    builder.setNanos(instant.getNano());
                } else {
                    throw new JsonParseException("Expected '{' or '[' or unix timestamp or ISO date or null, but found " + reader.peek() + " at " + reader.getPath());
                }
                return builder.build();
            } else if (Duration.getDescriptor().equals(descriptor)) {
                var builder = Duration.newBuilder();
                if (reader.peek() == JsonToken.NUMBER) {
                    // duration as seconds.
                    var durationDbl = reader.nextDouble();
                    var durationS = (long) durationDbl;
                    var micros = (int) Math.round((durationDbl - durationS) * 1_000_000L);
                    builder.setSeconds(durationS);
                    builder.setNanos(micros * 1000);
                } else if (reader.peek() == JsonToken.STRING) {
                    var duration = parseDuration(reader, reader.nextString());
                    builder.setSeconds(duration.getSeconds());
                    builder.setNanos(duration.getNano());
                } else {
                    throw new JsonParseException("Expected '{' or '[', duration number or string or null, but found " + reader.peek() + " at " + reader.getPath());
                }
                return builder.build();
            }
            throw new JsonParseException("Expected '{' or '[' or null, but found: " + reader.peek() + " at " + reader.getPath());
        }
    }

    private void findFieldAndReadValue(ProtoMessageBuilder builder,
                                       String fieldName,
                                       Descriptors.Descriptor descriptor,
                                       Map<String, Descriptors.FieldDescriptor> jsonFields) throws IOException {
        Descriptors.FieldDescriptor field;
        if (NUMERIC_FIELD_ID.matcher(fieldName).matches()) {
            int number = parseInt(fieldName);
            field = descriptor.findFieldByNumber(number);
            if (field == null) {
                field = options.getRegistry().extensionByScopeAndNumber(descriptor, number);
            }
        } else {
            field = jsonFields.get(fieldName);
            if (field == null) {
                field = options.getRegistry().extensionByScopeAndName(descriptor, fieldName);
            }
        }
        if (field == null) {
            if (options.isEnabled(FAIL_ON_UNKNOWN_FIELD)) {
                throw new JsonParseException("Unknown field at " + reader.getPath());
            }
            reader.skipValue();
        } else {
            Object o = readValue(field);
            if (o != null) {
                builder.set(field, o);
            } else if (options.isEnabled(FAIL_ON_NULL_VALUE)) {
                throw new JsonParseException("Null value at " + reader.getPath());
            }
        }
    }

    private Object readValue(Descriptors.FieldDescriptor descriptor) throws IOException {
        if (reader.peek() == JsonToken.NULL) {
            reader.nextNull();
            return null;
        }
        if (descriptor.isRepeated()) {
            if (descriptor.isMapField()) {
                var keyDescriptor = getMapKeyDescriptor(descriptor);
                var valueDescriptor = getMapValueDescriptor(descriptor);

                MapBuilder<Object, Object> map = UnmodifiableMap.newBuilder();
                reader.beginObject();
                while (reader.peek() != JsonToken.END_OBJECT) {
                    var key = readMapKey(keyDescriptor);
                    var value = readSingleValue(valueDescriptor);
                    if (value != null) {
                        map.put(key, value);
                    } else if (options.isEnabled(FAIL_ON_NULL_VALUE)) {
                        throw new JsonParseException("Null value in map at " + reader.getPreviousPath());
                    }
                }
                reader.endObject();
                return map.build();
            } else {
                ListBuilder<Object> list = UnmodifiableList.newBuilder();
                reader.beginArray();
                while (reader.peek() != JsonToken.END_ARRAY) {
                    list.add(readSingleValue(descriptor));
                }
                reader.endArray();
                return list.build();
            }
        } else {
            return readSingleValue(descriptor);
        }
    }

    private Object readMapKey(Descriptors.FieldDescriptor descriptor) throws IOException {
        var name = reader.nextName();
        switch (descriptor.getType().getJavaType()) {
            case BOOLEAN: {
                return parseBoolean(name);
            }
            case INT: {
                return parseInt(name);
            }
            case LONG: {
                return parseLong(name);
            }
            case STRING: {
                return name;
            }
            default: {
                // Usually not testable.
                throw new IOException("Unhandled map key type: " + descriptor.getType());
            }
        }
    }

    private Object readSingleValue(Descriptors.FieldDescriptor descriptor) throws IOException {
        if (reader.peek() == JsonToken.NULL) {
            return null;
        }
        switch (descriptor.getType().getJavaType()) {
            case BOOLEAN: {
                return reader.nextBoolean();
            }
            case INT: {
                try {
                    return reader.nextInt();
                } catch (NumberFormatException e) {
                    throw new JsonParseException("Invalid int value \"" + reader.nextString() + "\" at " + reader.getPath());
                }
            }
            case LONG: {
                try {
                    return reader.nextLong();
                } catch (NumberFormatException e) {
                    throw new JsonParseException("Invalid long value \"" + reader.nextString() + "\" at " + reader.getPath());
                }
            }
            case FLOAT: {
                try {
                    return (float) reader.nextDouble();
                } catch (NumberFormatException e) {
                    throw new JsonParseException("Invalid float value \"" + reader.nextString() + "\" at " + reader.getPath());
                }
            }
            case DOUBLE: {
                try {
                    return reader.nextDouble();
                } catch (NumberFormatException e) {
                    throw new JsonParseException("Invalid double value \"" + reader.nextString() + "\" at " + reader.getPath());
                }
            }
            case STRING: {
                return reader.nextString();
            }
            case BYTE_STRING: {
                return parseBase64(reader.nextString());
            }
            case ENUM: {
                return readEnum(descriptor.getEnumType());
            }
            case MESSAGE: {
                return read(descriptor.getMessageType());
            }
            default: {
                // Usually not testable.
                throw new JsonParseException("Unhandled value type: " + descriptor.getType());
            }
        }
    }

    private ByteString parseBase64(String value) {
        try {
            return ByteString.copyFrom(Base64.getDecoder().decode(value));
        } catch (IllegalArgumentException e) {
            throw new JsonParseException(e.getMessage() + " at " + reader.getPreviousPath(), e);
        }
    }

    private Object readEnum(Descriptors.EnumDescriptor descriptor) throws IOException {
        ProtoEnum<?> type = ProtoEnum.getEnumDescriptor(descriptor);
        if (reader.peek() == JsonToken.NUMBER) {
            var num = reader.nextInt();
            return type.valueForNumber(num);
        } else {
            var name = reader.nextString();
            return type.valueForName(name);
        }
    }

    private static JsonReader makeJsonReader(Reader in, boolean lenient) {
        JsonReader reader = new JsonReader(new LineBufferedReader(in));
        if (lenient) {
            reader.setLenient(true);
        }
        return reader;
    }

    private static java.time.Duration parseDuration(JsonReader reader, String value) {
        try {
            return GoogleTypesUtil.parseJavaDurationString(value);
        } catch (IllegalArgumentException e) {
            throw new JsonParseException(e.getMessage() + " at " + reader.getPreviousPath(), e);
        }
    }

    private static Instant parseIsoDateTime(JsonReader reader, String value) {
        try {
            return GoogleTypesUtil.parseJavaTimestampString(value);
        } catch (IllegalArgumentException e) {
            throw new JsonParseException(e.getMessage() + " at " + reader.getPreviousPath());
        }
    }
}