TTupleProtocolSerializer.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.thrift;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PEnumBuilder;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PMessageVariant;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PUnion;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PRequirement;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.util.Binary;
import net.morimekta.util.io.CountingOutputStream;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TTupleProtocol;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;

import javax.annotation.Nonnull;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * TProtocol serializer specialized for Tuple protocol, just because thrift
 * decided that this protocol should be written in a different way than other
 * protocols.
 * <p>
 * Wire layout (intended to match thrift's generated {@code TupleScheme} code):
 * <ul>
 *     <li>Non-union messages: the values of all required fields, then (only
 *         when the message declares any non-required fields) a bit set with
 *         one bit per non-required field, then the values of the non-required
 *         fields whose bit is set. Fields are visited in ascending field-ID
 *         order on both the write and the read side.</li>
 *     <li>Unions: the I16 ID of the set field, followed by its value.</li>
 *     <li>Lists, sets and maps: an I32 element count followed by the elements
 *         (key and value alternating for maps), without element type headers.</li>
 * </ul>
 * Service calls are additionally wrapped in the protocol's message begin / end
 * markers.
 */
public class TTupleProtocolSerializer extends Serializer {
    /** Media type identifying the thrift tuple protocol. */
    public static final String MEDIA_TYPE = "application/vnd.apache.thrift.tuple";

    private final boolean          strict;
    private final TProtocolFactory protocolFactory;

    /** Creates a tuple serializer using the base serializer's default strictness. */
    public TTupleProtocolSerializer() {
        this(DEFAULT_STRICT);
    }

    /**
     * Creates a tuple serializer.
     *
     * @param strict When true, deserialized messages are validated with
     *               {@code builder.validate()} and unknown enum values are
     *               rejected.
     */
    public TTupleProtocolSerializer(boolean strict) {
        this.strict = strict;
        this.protocolFactory = new TTupleProtocol.Factory();
    }

    /**
     * Serializes a single message using the tuple protocol.
     *
     * @param output  The stream to write to. It is not closed by this method.
     * @param message The message (or builder) to serialize.
     * @return The number of bytes written.
     * @throws IOException If writing fails or the message cannot be serialized.
     */
    @Override
    public <Message extends PMessage<Message>> int
    serialize(@Nonnull OutputStream output, @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        CountingOutputStream wrapper = new CountingOutputStream(output);
        TTransport transport = new TIOStreamTransport(wrapper);
        try {
            TTupleProtocol protocol = (TTupleProtocol) protocolFactory.getProtocol(transport);
            writeMessage(message, protocol);
            transport.flush();
            wrapper.flush();
            return wrapper.getByteCount();
        } catch (TException e) {
            throw new SerializerException(e, e.getMessage());
        }
    }

    /**
     * Serializes a service call: message header (method name, call type and
     * sequence number), then the call's message, then the message end marker.
     *
     * @param output The stream to write to. It is not closed by this method.
     * @param call   The service call to serialize.
     * @return The number of bytes written.
     * @throws IOException If writing fails or the call cannot be serialized.
     */
    @Override
    public <Message extends PMessage<Message>>
    int serialize(@Nonnull OutputStream output, @Nonnull PServiceCall<Message> call)
            throws IOException {
        CountingOutputStream wrapper = new CountingOutputStream(output);
        TTransport transport = new TIOStreamTransport(wrapper);
        try {
            TTupleProtocol protocol = (TTupleProtocol) protocolFactory.getProtocol(transport);
            TMessage tm = new TMessage(call.getMethod(), (byte) call.getType().asInteger(), call.getSequence());

            protocol.writeMessageBegin(tm);
            writeMessage(call.getMessage(), protocol);
            protocol.writeMessageEnd();

            transport.flush();
            wrapper.flush();
            return wrapper.getByteCount();
        } catch (TException e) {
            throw new SerializerException(e, e.getMessage());
        }
    }

    /**
     * Deserializes a single message of the given type.
     *
     * @param input      The stream to read from.
     * @param descriptor Descriptor of the expected message type.
     * @return The deserialized message.
     * @throws IOException If reading fails or the content is not a valid
     *                     message of the given type.
     */
    @Nonnull
    @Override
    public <Message extends PMessage<Message>>
    Message deserialize(@Nonnull InputStream input, @Nonnull PMessageDescriptor<Message> descriptor) throws IOException {
        try {
            TTransport transport = new TIOStreamTransport(input);
            TTupleProtocol protocol = (TTupleProtocol) protocolFactory.getProtocol(transport);

            return readMessage(protocol, descriptor);
        } catch (TTransportException e) {
            throw new SerializerException(e, "Unable to serialize into transport protocol");
        } catch (TException e) {
            throw new SerializerException(e, "Transport exception in protocol");
        }
    }

    /**
     * Deserializes a service call for the given service. Calls of type
     * EXCEPTION carry a {@link PApplicationException} instead of a method
     * request or response message.
     *
     * @param input   The stream to read from.
     * @param service The service the call is expected to belong to.
     * @return The deserialized service call.
     * @throws SerializerException If the call type or method is unknown, or
     *                             the content cannot be read.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>>
    PServiceCall<Message> deserialize(@Nonnull InputStream input, @Nonnull PService service)
            throws SerializerException {
        TMessage tm = null;
        PServiceCallType type = null;
        try {
            TTransport transport = new TIOStreamTransport(input);
            TTupleProtocol protocol = (TTupleProtocol) protocolFactory.getProtocol(transport);

            tm = protocol.readMessageBegin();

            type = PServiceCallType.findById((int) tm.type);
            if (type == null) {
                throw new SerializerException("Unknown call type for id " + tm.type);
            } else if (type == PServiceCallType.EXCEPTION) {
                PApplicationException exception = readMessage(protocol, PApplicationException.kDescriptor);
                // Close the message like the non-exception path does.
                protocol.readMessageEnd();
                // Raw type on purpose: the payload is an application exception,
                // not the method's Message type (hence the unchecked warning).
                return new PServiceCall(tm.name, type, tm.seqid, exception);
            }

            PServiceMethod method = service.getMethod(tm.name);
            if (method == null) {
                throw new SerializerException("No such method " + tm.name + " on " + service.getQualifiedName());
            }

            // Null value here should not be possible unless method return
            // types have been tampered with...
            PMessageDescriptor<Message> descriptor = Objects.requireNonNull(
                    isRequestCallType(type) ? method.getRequestType() : method.getResponseType());
            Message message = readMessage(protocol, descriptor);

            protocol.readMessageEnd();

            return new PServiceCall<>(tm.name, type, tm.seqid, message);
        } catch (TTransportException e) {
            // NOTE(review): this maps the transport exception's numeric type
            // onto PApplicationExceptionType by id; the two enumerations are
            // distinct, so the resulting type may be unrelated (or null) —
            // confirm this mapping is intended.
            throw new SerializerException(e, "Unable to serialize into transport protocol")
                    .setExceptionType(PApplicationExceptionType.findById(e.getType()))
                    .setCallType(type)
                    .setMethodName(tm != null ? tm.name : "")
                    .setSequenceNo(tm != null ? tm.seqid : 0);
        } catch (TException e) {
            throw new SerializerException(e, "Transport exception in protocol")
                    .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR)
                    .setCallType(type)
                    .setMethodName(tm != null ? tm.name : "")
                    .setSequenceNo(tm != null ? tm.seqid : 0);
        }
    }

    /** The tuple protocol is a binary protocol. */
    @Override
    public boolean binaryProtocol() {
        return true;
    }

    /**
     * Verifies that the input has no more content, and closes it in all cases.
     *
     * @param input The stream to check.
     * @throws IOException If reading fails, or if another byte is available
     *                     (reported as a PROTOCOL_ERROR).
     */
    @Override
    public void verifyEndOfContent(@Nonnull InputStream input) throws IOException {
        try {
            int in = input.read();
            if (in >= 0) {
                throw new SerializerException("More content after end: 0x%02x", in)
                        .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR);
            }
        } finally {
            input.close();
        }
    }

    /** @return {@link #MEDIA_TYPE}. */
    @Nonnull
    @Override
    public String mediaType() {
        return MEDIA_TYPE;
    }

    /**
     * Writes a message (union or non-union) without any message header.
     *
     * @param message  The message or builder to write.
     * @param protocol The tuple protocol to write to.
     * @throws SerializerException If a union has no field set, or a required
     *                             field has no value.
     */
    private void writeMessage(PMessageOrBuilder<?> message, TTupleProtocol protocol) throws TException, SerializerException {
        PMessageDescriptor<?> descriptor = message.descriptor();
        if (descriptor.getVariant() == PMessageVariant.UNION) {
            if (((PUnion<?>) message).unionFieldIsSet()) {
                PField fld = ((PUnion<?>) message).unionField();
                protocol.writeI16((short) fld.getId());
                writeTypedValue(message.get(fld.getId()), fld.getDescriptor(), protocol);
            } else {
                throw new SerializerException("Unable to write " + descriptor.getQualifiedName() + " without set union field.");
            }
            return;
        }

        PField[] fields = sortedFields(descriptor);

        // 1. All required fields, unconditionally and before the bit set.
        for (PField fld : fields) {
            if (fld.getRequirement() == PRequirement.REQUIRED) {
                Object value = message.get(fld.getId());
                if (value == null) {
                    // Fail with a clear message instead of an NPE inside the writer.
                    throw new SerializerException("Required field " + fld.getName() + " not set in " +
                                                  descriptor.getQualifiedName());
                }
                writeTypedValue(value, fld.getDescriptor(), protocol);
            }
        }

        // 2. Presence bit set, only if there are non-required fields at all.
        int numOptionals = countOptionals(fields);
        if (numOptionals == 0) {
            return;
        }
        BitSet optionals = new BitSet(numOptionals);
        int optionalPos = 0;
        for (PField fld : fields) {
            if (fld.getRequirement() != PRequirement.REQUIRED) {
                if (message.has(fld.getId())) {
                    optionals.set(optionalPos);
                }
                ++optionalPos;
            }
        }
        protocol.writeBitSet(optionals, numOptionals);

        // 3. Values of the present non-required fields, in the same order.
        optionalPos = 0;
        for (PField fld : fields) {
            if (fld.getRequirement() != PRequirement.REQUIRED) {
                if (optionals.get(optionalPos)) {
                    writeTypedValue(message.get(fld.getId()), fld.getDescriptor(), protocol);
                }
                ++optionalPos;
            }
        }
    }

    /**
     * Returns the descriptor's fields sorted by ascending field ID. Sorts a
     * copy so the array returned by the descriptor is never reordered, and so
     * the read and write paths agree on field order.
     * NOTE(review): thrift's own generator may iterate fields in declaration
     * order; the two only differ when the IDL declares fields out of ID order.
     */
    private static PField[] sortedFields(PMessageDescriptor<?> descriptor) {
        PField[] fields = descriptor.getFields().clone();
        Arrays.sort(fields, Comparator.comparingInt(PField::getId));
        return fields;
    }

    /** Counts fields that are not REQUIRED, i.e. the bits in the presence bit set. */
    private int countOptionals(PField[] fields) {
        int numOptionals = 0;
        for (PField fld : fields) {
            if (fld.getRequirement() != PRequirement.REQUIRED) {
                ++numOptionals;
            }
        }
        return numOptionals;
    }

    /**
     * Reads a message (union or non-union) with the layout written by
     * {@link #writeMessage(PMessageOrBuilder, TTupleProtocol)}.
     *
     * @param protocol   The tuple protocol to read from.
     * @param descriptor Descriptor of the message to read.
     * @return The built message.
     * @throws SerializerException On an unknown union field ID, or (in strict
     *                             mode) when validation fails.
     */
    private <Message extends PMessage<Message>>
    Message readMessage(TTupleProtocol protocol, PMessageDescriptor<Message> descriptor)
            throws SerializerException, TException {
        PMessageBuilder<Message> builder = descriptor.builder();

        if (descriptor.getVariant() == PMessageVariant.UNION) {
            int fieldId = protocol.readI16();
            PField fld = descriptor.findFieldById(fieldId);
            if (fld == null) {
                throw new SerializerException("Unable to read unknown union field " + fieldId + " in " + descriptor.getQualifiedName());
            }
            builder.set(fld.getId(), readTypedValue(fld.getDescriptor(), protocol));
        } else {
            // Same field order as the write side.
            PField[] fields = sortedFields(descriptor);

            // 1. All required fields come first.
            for (PField fld : fields) {
                if (fld.getRequirement() == PRequirement.REQUIRED) {
                    builder.set(fld.getId(), readTypedValue(fld.getDescriptor(), protocol));
                }
            }

            // 2. Presence bit set, then 3. the present non-required fields.
            int numOptionals = countOptionals(fields);
            if (numOptionals > 0) {
                BitSet optionals = protocol.readBitSet(numOptionals);
                int optionalPos = 0;
                for (PField fld : fields) {
                    if (fld.getRequirement() != PRequirement.REQUIRED) {
                        if (optionals.get(optionalPos)) {
                            builder.set(fld.getId(), readTypedValue(fld.getDescriptor(), protocol));
                        }
                        ++optionalPos;
                    }
                }
            }
        }

        if (strict) {
            try {
                builder.validate();
            } catch (IllegalStateException e) {
                throw new SerializerException(e, e.getMessage());
            }
        }

        return builder.build();
    }

    /**
     * Reads an I32 container element count and rejects negative values, which
     * can only come from corrupt or malicious input.
     */
    private static int readCollectionSize(TTupleProtocol protocol) throws TException, SerializerException {
        int size = protocol.readI32();
        if (size < 0) {
            throw new SerializerException("Negative collection size: " + size);
        }
        return size;
    }

    /**
     * Reads a single value of the given type.
     *
     * @param type     Descriptor of the value type.
     * @param protocol The tuple protocol to read from.
     * @return The value read.
     * @throws SerializerException On unsupported types, negative container
     *                             sizes, or (in strict mode) unknown enum values.
     */
    private Object readTypedValue(PDescriptor type, TTupleProtocol protocol)
            throws TException, SerializerException {
        switch (type.getType()) {
            case BOOL:
                return protocol.readBool();
            case BYTE:
                return protocol.readByte();
            case I16:
                return protocol.readI16();
            case I32:
                return protocol.readI32();
            case I64:
                return protocol.readI64();
            case DOUBLE:
                return protocol.readDouble();
            case BINARY: {
                ByteBuffer buffer = protocol.readBinary();
                // Copy only the readable window: array() returns the entire
                // backing array regardless of position/limit, which may hold
                // unrelated bytes if the buffer is a view.
                byte[] bytes = new byte[buffer.remaining()];
                buffer.get(bytes);
                return Binary.wrap(bytes);
            }
            case STRING:
                return protocol.readString();
            case ENUM: {
                PEnumDescriptor<?> et = (PEnumDescriptor<?>) type;
                PEnumBuilder<?> eb = et.builder();
                final int value = protocol.readI32();
                eb.setById(value);
                if (strict && !eb.valid()) {
                    throw new SerializerException("Invalid enum value " + value + " for " +
                                                  et.getQualifiedName());
                }
                return eb.build();
            }
            case MESSAGE:
                return readMessage(protocol, (PMessageDescriptor<?>) type);
            case LIST: {
                int lSize = readCollectionSize(protocol);
                @SuppressWarnings("unchecked")
                PList<Object> lDesc = (PList<Object>) type;
                PDescriptor liDesc = lDesc.itemDescriptor();

                PList.Builder<Object> list = lDesc.builder(lSize);
                for (int i = 0; i < lSize; ++i) {
                    list.add(readTypedValue(liDesc, protocol));
                }

                return list.build();
            }
            case SET: {
                int sSize = readCollectionSize(protocol);
                @SuppressWarnings("unchecked")
                PSet<Object> sDesc = (PSet<Object>) type;
                PDescriptor siDesc = sDesc.itemDescriptor();

                PSet.Builder<Object> set = sDesc.builder(sSize);
                for (int i = 0; i < sSize; ++i) {
                    set.add(readTypedValue(siDesc, protocol));
                }

                return set.build();
            }
            case MAP: {
                int mSize = readCollectionSize(protocol);
                @SuppressWarnings("unchecked")
                PMap<Object, Object> mDesc = (PMap<Object, Object>) type;
                PDescriptor mkDesc = mDesc.keyDescriptor();
                PDescriptor miDesc = mDesc.itemDescriptor();

                PMap.Builder<Object, Object> map = mDesc.builder(mSize);
                for (int i = 0; i < mSize; ++i) {
                    Object key = readTypedValue(mkDesc, protocol);
                    Object val = readTypedValue(miDesc, protocol);
                    map.put(key, val);
                }

                protocol.readMapEnd();
                return map.build();
            }
            default:
                throw new SerializerException("Unsupported protocol field type: " + type.getType());
        }
    }

    /**
     * Writes a single value of the given type.
     *
     * @param item     The value to write; its runtime type must match {@code type}.
     * @param type     Descriptor of the value type.
     * @param protocol The tuple protocol to write to.
     */
    private void writeTypedValue(Object item, PDescriptor type, TTupleProtocol protocol)
            throws TException, SerializerException {
        switch (type.getType()) {
            case BOOL:
                protocol.writeBool((Boolean) item);
                break;
            case BYTE:
                protocol.writeByte((Byte) item);
                break;
            case I16:
                protocol.writeI16((Short) item);
                break;
            case I32:
                protocol.writeI32((Integer) item);
                break;
            case I64:
                protocol.writeI64((Long) item);
                break;
            case DOUBLE:
                protocol.writeDouble((Double) item);
                break;
            case STRING:
                protocol.writeString((String) item);
                break;
            case BINARY:
                protocol.writeBinary(((Binary) item).getByteBuffer());
                break;
            case ENUM: {
                PEnumValue<?> value = (PEnumValue<?>) item;
                protocol.writeI32(value.asInteger());
                break;
            }
            case MESSAGE:
                writeMessage((PMessage<?>) item, protocol);
                break;
            case LIST: {
                PList<?> lType = (PList<?>) type;
                List<?> list = (List<?>) item;
                protocol.writeI32(list.size());
                for (Object i : list) {
                    writeTypedValue(i, lType.itemDescriptor(), protocol);
                }
                break;
            }
            case SET: {
                PSet<?> sType = (PSet<?>) type;
                Set<?> set = (Set<?>) item;
                protocol.writeI32(set.size());
                for (Object i : set) {
                    writeTypedValue(i, sType.itemDescriptor(), protocol);
                }
                break;
            }
            case MAP: {
                PMap<?, ?> mType = (PMap<?, ?>) type;
                Map<?, ?> map = (Map<?, ?>) item;
                protocol.writeI32(map.size());

                for (Map.Entry<?, ?> entry : map.entrySet()) {
                    writeTypedValue(entry.getKey(), mType.keyDescriptor(), protocol);
                    writeTypedValue(entry.getValue(), mType.itemDescriptor(), protocol);
                }

                protocol.writeMapEnd();
                break;
            }
            default:
                // NOTE(review): nothing is written for types not handled above
                // (presumably only VOID), whereas readTypedValue rejects them —
                // confirm that void values are meant to round-trip this way.
                break;
        }
    }
}