FastBinarySerializer.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.serializer;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PEnumBuilder;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PType;
import net.morimekta.providence.PUnion;
import net.morimekta.providence.descriptor.PContainer;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.util.Binary;
import net.morimekta.util.io.LittleEndianBinaryReader;
import net.morimekta.util.io.LittleEndianBinaryWriter;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Map;

import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * Compact binary serializer. This uses a pretty compact binary format
 * while being optimized for fewer operations during read and write.
 * <p>
 * Documentation: <a href="http://www.morimekta.net/providence/serializer-fast-binary.html">Fast Binary Serialization Format</a>
 * with IDL and explanation.
 * <p>
 * Layout as implemented by this class:
 * <ul>
 *   <li>Each message field starts with a varint tag {@code (fieldId << 3) | wireType},
 *       followed by a payload that depends on the wire type.</li>
 *   <li>A message is terminated by the {@code STOP} tag (a varint 0).</li>
 *   <li>A service call is a varint {@code (methodNameLength << 3) | callType}, the UTF-8
 *       method name, a varint sequence number and then the message.</li>
 * </ul>
 */
public class FastBinarySerializer extends Serializer {
    public static final String MEDIA_TYPE = "application/vnd.morimekta.providence.binary";

    /**
     * When set, messages are validated after reading and null entries in
     * containers are rejected instead of being silently dropped.
     */
    private final boolean readStrict;

    /**
     * Construct a serializer instance using the default strictness.
     */
    public FastBinarySerializer() {
        this(DEFAULT_STRICT);
    }

    /**
     * Construct a serializer instance.
     *
     * @param readStrict If serializer should fail on unknown input data.
     */
    public FastBinarySerializer(boolean readStrict) {
        this.readStrict = readStrict;
    }

    /**
     * Write a single message to the stream.
     *
     * @param os        Stream to write to.
     * @param message   The message (or builder) to write.
     * @param <Message> The message type.
     * @return Number of bytes written.
     * @throws IOException If writing to the stream failed.
     */
    @Override
    public <Message extends PMessage<Message>>
    int serialize(@Nonnull OutputStream os, @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        LittleEndianBinaryWriter out = new LittleEndianBinaryWriter(os);
        return writeMessage(out, message);
    }

    /**
     * Write a service call to the stream: call tag, method name, sequence
     * number and then the wrapped message.
     *
     * @param os        Stream to write to.
     * @param call      The service call to write.
     * @param <Message> The type of the wrapped message.
     * @return Number of bytes written.
     * @throws IOException If writing to the stream failed.
     */
    @Override
    public <Message extends PMessage<Message>>
    int serialize(@Nonnull OutputStream os, @Nonnull PServiceCall<Message> call)
            throws IOException {
        LittleEndianBinaryWriter out = new LittleEndianBinaryWriter(os);
        byte[] method = call.getMethod().getBytes(UTF_8);
        // The name length and the call type share one varint: type in the low 3 bits.
        int len = out.writeVarint(method.length << 3 | call.getType().asInteger());
        len += method.length;
        out.write(method);
        len += out.writeVarint(call.getSequence());
        len += writeMessage(out, call.getMessage());
        return len;
    }

    /**
     * Read a single message from the stream.
     *
     * @param is         Stream to read from.
     * @param descriptor Descriptor of the message to read.
     * @param <Message>  The message type.
     * @return The parsed message.
     * @throws IOException If reading failed or the content could not be parsed.
     */
    @Nonnull
    @Override
    public <Message extends PMessage<Message>>
    Message deserialize(@Nonnull InputStream is,
                        @Nonnull PMessageDescriptor<Message> descriptor)
            throws IOException {
        LittleEndianBinaryReader in = new LittleEndianBinaryReader(is);
        return readMessage(in, descriptor);
    }

    /**
     * Read a service call from the stream.
     * <p>
     * Calls of type {@code EXCEPTION} are parsed as {@link PApplicationException}
     * without looking up the method on the service. For all other call types the
     * method must exist on the service, and the message is parsed using the
     * method's request or response descriptor depending on the call type.
     *
     * @param is        Stream to read from.
     * @param service   Service the call is expected to belong to.
     * @param <Message> The type of the wrapped message.
     * @return The parsed service call.
     * @throws SerializerException On any read or parse failure. The exception
     *         carries whatever call type, method name and sequence number had
     *         been parsed before the failure.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>>
    PServiceCall<Message> deserialize(@Nonnull InputStream is, @Nonnull PService service)
            throws SerializerException {
        String methodName = null;
        int sequence = 0;
        PServiceCallType type = null;
        try {
            LittleEndianBinaryReader in = new LittleEndianBinaryReader(is);
            // The original comment here stated "Max method name length: 255 chars".
            // NOTE(review): no such limit is enforced; the length is taken
            // directly from the tag.
            int tag = in.readIntVarint();
            int len = tag >>> 3;
            int typeKey = tag & 0x07;

            methodName = new String(in.expectBytes(len), UTF_8);
            sequence = in.readIntVarint();
            type = PServiceCallType.findById(typeKey);

            if (type == null) {
                throw new SerializerException("Invalid call type " + typeKey)
                        .setExceptionType(PApplicationExceptionType.INVALID_MESSAGE_TYPE);
            } else if (type == PServiceCallType.EXCEPTION) {
                // Application exceptions have a fixed descriptor and do not need the method.
                PApplicationException ex = readMessage(in, PApplicationException.kDescriptor);
                return (PServiceCall<Message>) new PServiceCall<>(methodName, type, sequence, ex);
            }

            PServiceMethod method = service.getMethod(methodName);
            if (method == null) {
                throw new SerializerException("No such method %s on %s",
                                              methodName,
                                              service.getQualifiedName())
                        .setExceptionType(PApplicationExceptionType.UNKNOWN_METHOD);
            }

            @SuppressWarnings("unchecked")
            PMessageDescriptor<Message> descriptor = (PMessageDescriptor<Message>) (
                    isRequestCallType(type) ? method.getRequestType() : method.getResponseType());
            if (descriptor == null) {
                throw new SerializerException("No such %s descriptor for %s",
                                              isRequestCallType(type) ? "request" : "response",
                                              service.getQualifiedName())
                        .setExceptionType(PApplicationExceptionType.UNKNOWN_METHOD);
            }

            Message message = readMessage(in, descriptor);
            return new PServiceCall<>(methodName, type, sequence, message);
        } catch (SerializerException e) {
            // Re-wrap so the call context parsed so far is attached to the failure.
            throw new SerializerException(e)
                    .setCallType(type)
                    .setMethodName(methodName)
                    .setSequenceNo(sequence);
        } catch (IOException e) {
            throw new SerializerException(e, e.getMessage())
                    .setCallType(type)
                    .setMethodName(methodName)
                    .setSequenceNo(sequence);
        }
    }

    /**
     * @return Always true, this is a binary format.
     */
    @Override
    public boolean binaryProtocol() {
        return true;
    }

    /**
     * Verify that the stream has no more content by attempting to read one
     * more byte. The stream is always closed, whether the check passes or not.
     *
     * @param input Stream to check and close.
     * @throws IOException If there was more content, or reading / closing failed.
     */
    @Override
    public void verifyEndOfContent(@Nonnull InputStream input) throws IOException {
        try {
            int in = input.read();
            if (in >= 0) {
                throw new SerializerException("More content after end: 0x%02x", in)
                        .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR);
            }
        } finally {
            input.close();
        }
    }

    /**
     * @return The media type of the fast binary format.
     */
    @Nonnull
    @Override
    public String mediaType() {
        return MEDIA_TYPE;
    }

    // --- MESSAGE ---

    /**
     * Write all present fields of the message followed by the STOP tag.
     * For unions only the currently set field (if any) is written.
     *
     * @param out       Writer to write to.
     * @param message   The message to write.
     * @param <Message> The message type.
     * @return Number of bytes written, including the STOP tag.
     * @throws IOException If writing failed.
     */
    private <Message extends PMessage<Message>>
    int writeMessage(LittleEndianBinaryWriter out, PMessageOrBuilder<Message> message)
            throws IOException {
        int len = 0;
        if (message instanceof PUnion) {
            if (((PUnion) message).unionFieldIsSet()) {
                PField field = ((PUnion) message).unionField();
                len += writeFieldValue(out, field.getId(), field.getDescriptor(), message.get(field.getId()));
            }
        } else {
            for (PField field : message.descriptor()
                                          .getFields()) {
                if (message.has(field.getId())) {
                    len += writeFieldValue(out, field.getId(), field.getDescriptor(), message.get(field.getId()));
                }
            }
        }
        // write STOP field.
        return len + out.writeVarint(STOP);
    }

    /**
     * Read and discard a message of unknown type, up to and including its
     * STOP tag.
     *
     * @param in Reader to consume from.
     * @throws IOException If reading failed or the content is malformed.
     */
    private void consumeMessage(LittleEndianBinaryReader in) throws IOException {
        int tag;
        while ((tag = in.readIntVarint()) != STOP) {
            int type = tag & 0x07;
            readFieldValue(in, type, null);
        }
    }

    /**
     * Read a message up to and including its STOP tag. Fields with IDs not
     * known to the descriptor are consumed and discarded. In strict mode the
     * builder is validated before the message is built.
     *
     * @param in         Reader to read from.
     * @param descriptor Descriptor of the message to read.
     * @param <Message>  The message type.
     * @return The built message.
     * @throws IOException If reading failed, the content is malformed, or
     *         (in strict mode) validation of the message failed.
     */
    @Nonnull
    private <Message extends PMessage<Message>>
    Message readMessage(@Nonnull LittleEndianBinaryReader in,
                        @Nonnull PMessageDescriptor<Message> descriptor)
            throws IOException {
        int tag;
        PMessageBuilder<Message> builder = descriptor.builder();
        while ((tag = in.readIntVarint()) != STOP) {
            int id = tag >>> 3;
            int type = tag & 0x07;
            PField field = descriptor.findFieldById(id);
            if (field != null) {
                Object value = readFieldValue(in, type, field.getDescriptor());
                builder.set(field.getId(), value);
            } else {
                // Unknown field: still consume the payload to stay in sync with the stream.
                readFieldValue(in, type, null);
            }
        }

        if (readStrict) {
            try {
                builder.validate();
            } catch (IllegalStateException e) {
                throw new SerializerException(e, e.getMessage());
            }
        }

        return builder.build();
    }

    // --- FIELD VALUE ---

    /**
     * Write a single message field: its tag followed by the payload.
     * Void and boolean fields carry their value in the tag itself (TRUE or
     * NONE wire type) and have no payload.
     *
     * @param out        Writer to write to.
     * @param key        The field ID.
     * @param descriptor Descriptor of the field type.
     * @param value      The field value.
     * @return Number of bytes written.
     * @throws IOException If writing failed.
     */
    @SuppressWarnings("unchecked")
    private int writeFieldValue(LittleEndianBinaryWriter out, int key, PDescriptor descriptor, Object value)
            throws IOException {
        switch (descriptor.getType()) {
            case VOID: {
                return out.writeVarint(key << 3 | TRUE);
            }
            case BOOL: {
                return out.writeVarint(key << 3 | ((Boolean) value ? TRUE : NONE));
            }
            case BYTE: {
                int len = out.writeVarint(key << 3 | VARINT);
                return len + out.writeZigzag((byte) value);
            }
            case I16: {
                int len = out.writeVarint(key << 3 | VARINT);
                return len + out.writeZigzag((short) value);
            }
            case I32: {
                int len = out.writeVarint(key << 3 | VARINT);
                return len + out.writeZigzag((int) value);
            }
            case I64: {
                int len = out.writeVarint(key << 3 | VARINT);
                return len + out.writeZigzag((long) value);
            }
            case DOUBLE: {
                int len = out.writeVarint(key << 3 | FIXED_64);
                return len + out.writeDouble((Double) value);
            }
            case STRING: {
                byte[] bytes = ((String) value).getBytes(StandardCharsets.UTF_8);
                int len = out.writeVarint(key << 3 | BINARY);
                len += out.writeVarint(bytes.length);
                out.write(bytes);
                return len + bytes.length;
            }
            case BINARY: {
                Binary bytes = (Binary) value;
                int len = out.writeVarint(key << 3 | BINARY);
                len += out.writeVarint(bytes.length());
                bytes.write(out);
                return len + bytes.length();
            }
            case ENUM: {
                int len = out.writeVarint(key << 3 | VARINT);
                return len + out.writeZigzag(((PEnumValue) value).asInteger());
            }
            case MESSAGE: {
                int len = out.writeVarint(key << 3 | MESSAGE);
                return len + writeMessage(out, (PMessage) value);
            }
            case MAP:
            case SET:
            case LIST: {
                int len = out.writeVarint(key << 3 | COLLECTION);
                return len + writeContainerEntry(out, COLLECTION, descriptor, value);
            }
            default:
                throw new Error("Unreachable code reached");
        }
    }


    /**
     * Write a value without a field tag, as used for container entries and
     * for the container payload itself.
     * <p>
     * Containers are written as a varint length, one varint type tag shared
     * by all entries, and then the entries. For maps the length is twice the
     * number of entries (keys and values are counted separately) and the type
     * tag is {@code (keyType << 3) | valueType}. For lists and sets the
     * length is the number of items and the type tag is the item type.
     *
     * @param out        Writer to write to.
     * @param typeid     The wire type to write the value as.
     * @param descriptor Descriptor of the value type.
     * @param value      The value to write.
     * @return Number of bytes written.
     * @throws IOException If writing failed, or the value does not match the wire type.
     */
    @SuppressWarnings("unchecked")
    private int writeContainerEntry(LittleEndianBinaryWriter out, int typeid, PDescriptor descriptor, Object value)
            throws IOException {
        switch (typeid) {
            case VARINT: {
                if (value instanceof Boolean) {
                    // Booleans are plain varints here, not zigzag; matches the BOOL read path.
                    return out.writeVarint(((Boolean) value ? 1 : 0));
                } else if (value instanceof Number) {
                    return out.writeZigzag(((Number) value).longValue());
                } else if (value instanceof PEnumValue) {
                    return out.writeZigzag(((PEnumValue) value).asInteger());
                } else {
                    throw new SerializerException("Impossible");
                }
            }
            case FIXED_64: {
                return out.writeDouble((Double) value);
            }
            case BINARY: {
                if (value instanceof CharSequence) {
                    // toString() rather than a (String) cast: the guard accepts any
                    // CharSequence, and a cast would throw ClassCastException for
                    // non-String implementations. For String this is a no-op.
                    byte[] bytes = value.toString().getBytes(StandardCharsets.UTF_8);
                    int len = out.writeVarint(bytes.length);
                    out.write(bytes);
                    return len + bytes.length;
                } else if (value instanceof Binary) {
                    Binary bytes = (Binary) value;
                    int len = out.writeVarint(bytes.length());
                    bytes.write(out);
                    return len + bytes.length();
                } else {
                    throw new SerializerException("Impossible");
                }
            }
            case MESSAGE: {
                return writeMessage(out, (PMessage) value);
            }
            case COLLECTION: {
                if (value instanceof Map) {
                    Map<Object, Object> map = (Map<Object, Object>) value;
                    PMap<?, ?> desc = (PMap<?, ?>) descriptor;

                    int ktype = itemType(desc.keyDescriptor());
                    int vtype = itemType(desc.itemDescriptor());

                    // Keys and values are counted as separate entries.
                    int len = out.writeVarint(map.size() * 2);
                    len += out.writeVarint(ktype << 3 | vtype);
                    for (Map.Entry<Object, Object> entry : map.entrySet()) {
                        len += writeContainerEntry(out, ktype, desc.keyDescriptor(), entry.getKey());
                        len += writeContainerEntry(out, vtype, desc.itemDescriptor(), entry.getValue());
                    }
                    return len;
                } else if (value instanceof Collection){
                    Collection<Object> coll = (Collection<Object>) value;
                    PContainer<?> desc = (PContainer<?>) descriptor;
                    int vtype = itemType(desc.itemDescriptor());

                    int len = out.writeVarint(coll.size());
                    len    += out.writeVarint(vtype);
                    for (Object item : coll) {
                        len += writeContainerEntry(out, vtype, desc.itemDescriptor(), item);
                    }
                    return len;
                } else {
                    throw new SerializerException("Impossible");
                }
            }
            default:
                throw new SerializerException("Impossible");
        }
    }

    /**
     * Read the payload for a value of the given wire type.
     * <p>
     * When the descriptor is null the payload is consumed and discarded, and
     * null is returned (except for the NONE and TRUE wire types, which have
     * no payload and always return the boolean they represent).
     *
     * @param in         Reader to read from.
     * @param type       The wire type, as taken from the low 3 bits of the tag.
     * @param descriptor Descriptor of the expected type, or null to skip the value.
     * @return The value read. May be null if the value was skipped.
     * @throws IOException If reading failed, the wire type is unknown, or the
     *         wire type is not compatible with the descriptor.
     */
    @SuppressWarnings("unchecked")
    private Object readFieldValue(@Nonnull LittleEndianBinaryReader in,
                                  int type,
                                  @Nullable PDescriptor descriptor)
            throws IOException {
        switch (type) {
            case NONE:
                return Boolean.FALSE;
            case TRUE:
                return Boolean.TRUE;
            case VARINT: {
                if (descriptor == null) {
                    // Read as long so that any varint width is consumed.
                    in.readLongVarint();
                    return null;
                }
                switch (descriptor.getType()) {
                    case BOOL:
                        return in.readIntVarint() != 0;
                    case BYTE:
                        return (byte) in.readIntZigzag();
                    case I16:
                        return (short) in.readIntZigzag();
                    case I32:
                        return in.readIntZigzag();
                    case I64:
                        return in.readLongZigzag();
                    case ENUM: {
                        PEnumBuilder<?> builder = ((PEnumDescriptor<?>) descriptor).builder();
                        builder.setById(in.readIntZigzag());
                        return builder.build();
                    }
                    default: {
                        throw new SerializerException("Impossible");
                    }
                }
            }
            case FIXED_64:
                return in.expectDouble();
            case BINARY: {
                int len = in.readIntVarint();
                byte[] data = in.expectBytes(len);
                if (descriptor != null) {
                    switch (descriptor.getType()) {
                        case STRING:
                            return new String(data, StandardCharsets.UTF_8);
                        case BINARY:
                            return Binary.wrap(data);
                        default:
                            throw new SerializerException("Impossible");
                    }
                } else {
                    return null;
                }
            }
            case MESSAGE:
                if (descriptor == null) {
                    consumeMessage(in);
                    return null;
                }
                return readMessage(in, (PMessageDescriptor<?>) descriptor);
            case COLLECTION:
                if (descriptor == null) {
                    // Skip a container of unknown type. The length counts every
                    // entry, so for maps keys and values alternate. For lists and
                    // sets the tag only holds the item type (tag <= 0x07), so the
                    // "key" type falls back to the item type.
                    final int len = in.readIntVarint();
                    final int tag = in.readIntVarint();
                    final int vtype = tag & 0x07;
                    final int ktype = tag > 0x07 ? tag >>> 3 : vtype;
                    for (int i = 0; i < len; ++i) {
                        if (i % 2 == 0) {
                            readFieldValue(in, ktype, null);
                        } else {
                            readFieldValue(in, vtype, null);
                        }
                    }
                    return null;
                } else if (descriptor.getType() == PType.MAP) {
                    PMap<Object, Object> ct = (PMap<Object, Object>) descriptor;
                    PDescriptor kt = ct.keyDescriptor();
                    PDescriptor vt = ct.itemDescriptor();

                    final int len = in.readIntVarint();
                    final int tag = in.readIntVarint();
                    final int vtype = tag & 0x07;
                    final int ktype = tag > 0x07 ? tag >>> 3 : vtype;
                    PMap.Builder<Object, Object> out = ct.builder(len / 2);
                    // Step by two: the length counts keys and values separately.
                    for (int i = 0; i < len; i += 2) {
                        Object key = readFieldValue(in, ktype, kt);
                        Object value = readFieldValue(in, vtype, vt);
                        if (key != null && value != null) {
                            out.put(key, value);
                        } else if (readStrict) {
                            if (key == null) {
                                throw new SerializerException("Unknown enum key in map");
                            }
                            throw new SerializerException("Null value in map");
                        }
                    }
                    return out.build();
                } else if (descriptor.getType() == PType.LIST) {
                    PList<Object> ct = (PList<Object>) descriptor;
                    PDescriptor it = ct.itemDescriptor();
                    final int len = in.readIntVarint();
                    final int vtype = in.readIntVarint() & 0x07;
                    PList.Builder<Object> out = ct.builder(len);
                    for (int i = 0; i < len; ++i) {
                        Object item = readFieldValue(in, vtype, it);
                        if (item != null) {
                            out.add(item);
                        } else if (readStrict) {
                            throw new SerializerException("Null value in list");
                        }
                    }
                    return out.build();
                } else if (descriptor.getType() == PType.SET) {
                    PSet<Object> ct = (PSet<Object>) descriptor;
                    PDescriptor it = ct.itemDescriptor();
                    final int len = in.readIntVarint();
                    final int vtype = in.readIntVarint() & 0x07;
                    PSet.Builder<Object> out = ct.builder(len);
                    for (int i = 0; i < len; ++i) {
                        Object item = readFieldValue(in, vtype, it);
                        if (item != null) {
                            out.add(item);
                        } else if (readStrict) {
                            throw new SerializerException("Null value in set");
                        }
                    }
                    return out.build();
                } else {
                    throw new SerializerException("Type " + descriptor.getType() +
                                                  " not compatible with collection data");
                }
            default:
                // Reachable from malformed input: a non-STOP tag whose low 3 bits
                // are 0 (e.g. 0x08) ends up here. Report it as a protocol error
                // instead of throwing java.lang.Error.
                throw new SerializerException("Unknown wire type " + type)
                        .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR);
        }
    }

    /**
     * Get the wire type used for entries of the given type inside a container.
     * Note that booleans are written as VARINT inside containers, not as the
     * NONE / TRUE tags used for message fields.
     *
     * @param descriptor Descriptor of the entry type.
     * @return The wire type constant.
     */
    private static int itemType(PDescriptor descriptor) {
        switch (descriptor.getType()) {
            case BOOL:
            case BYTE:
            case I16:
            case I32:
            case I64:
            case ENUM:
                return VARINT;
            case DOUBLE:
                return FIXED_64;
            case BINARY:
            case STRING:
                return BINARY;
            case MESSAGE:
                return MESSAGE;
            case SET:
            case LIST:
            case MAP:
                return COLLECTION;
            default:
                throw new Error("Unreachable code reached");
        }
    }

    // Wire types, stored in the low 3 bits of a tag. STOP is only valid as a
    // whole tag (varint 0), marking the end of a message.
    private static final int STOP       = 0x00;
    private static final int NONE       = 0x01;  // 0, false, empty.
    private static final int TRUE       = 0x02;  // 1, true.
    private static final int VARINT     = 0x03;  // -> zigzag encoded base-128 number (byte, i16, i32, i64).
    private static final int FIXED_64   = 0x04;  // -> double
    private static final int BINARY     = 0x05;  // -> varint len + binary data.
    private static final int MESSAGE    = 0x06;  // -> messages, terminated with field-ID 0.
    private static final int COLLECTION = 0x07;  // -> varint len + varint entry type tag + len untagged entries.
}