TProtocolSerializer.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.thrift;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PEnumBuilder;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PPrimitive;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.providence.serializer.binary.BinaryType;
import net.morimekta.util.Binary;
import net.morimekta.util.io.CountingOutputStream;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TList;
import org.apache.thrift.protocol.TMap;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.protocol.TProtocolUtil;
import org.apache.thrift.protocol.TSet;
import org.apache.thrift.protocol.TStruct;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;

import javax.annotation.Nonnull;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static net.morimekta.providence.serializer.binary.BinaryType.asString;
import static net.morimekta.providence.serializer.binary.BinaryType.forType;

/**
 * Base {@link Serializer} that reads and writes providence messages and
 * service calls through an Apache Thrift {@link TProtocol}. The concrete wire
 * format is decided by the {@link TProtocolFactory} given to the constructor:
 * every call wraps the caller's stream in a {@link TIOStreamTransport} and
 * obtains a fresh protocol instance from that factory. Streams passed in are
 * never closed by this class.
 *
 * @author Stein Eldar Johnsen
 * @since 23.09.15
 */
abstract class TProtocolSerializer extends Serializer {
    private final TProtocolFactory protocolFactory;
    private final boolean          readStrict;
    private final boolean          binary;
    private final String           mediaType;

    /**
     * Create a serializer backed by a Thrift protocol factory.
     *
     * @param readStrict      If true, each deserialized message is checked with
     *                        its builder's {@code validate()} before being built,
     *                        and an {@link IllegalStateException} from that check
     *                        is rethrown as a {@link SerializerException}.
     * @param protocolFactory Factory for the protocol used on every read / write.
     * @param binary          Returned as-is from {@link #binaryProtocol()}.
     * @param mediaType       Returned as-is from {@link #mediaType()}.
     */
    public TProtocolSerializer(boolean readStrict, TProtocolFactory protocolFactory,
                               boolean binary, String mediaType) {
        this.readStrict = readStrict;
        this.protocolFactory = protocolFactory;
        this.binary = binary;
        this.mediaType = mediaType;
    }

    @Override
    public boolean binaryProtocol() {
        return binary;
    }

    @Nonnull
    @Override
    public String mediaType() {
        return mediaType;
    }

    /**
     * Write a single message (without any Thrift message envelope) to the
     * output stream. The stream is flushed, but not closed.
     *
     * @param output  Stream to write to.
     * @param message Message or builder to write.
     * @return Number of bytes written to {@code output}.
     * @throws IOException If writing fails; protocol errors are wrapped in a
     *                     {@link SerializerException}.
     */
    @Override
    public <Message extends PMessage<Message>>
    int serialize(@Nonnull OutputStream output, @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        CountingOutputStream wrapper = new CountingOutputStream(output);
        TTransport transport = new TIOStreamTransport(wrapper);
        try {
            TProtocol protocol = protocolFactory.getProtocol(transport);
            writeMessage(message, protocol);
            transport.flush();
            wrapper.flush();
            return wrapper.getByteCount();
        } catch (TException e) {
            throw new SerializerException(e, e.getMessage());
        }
    }

    /**
     * Write a service call: a Thrift message envelope (method name, call type
     * id and sequence number) around the call's message. The stream is
     * flushed, but not closed.
     *
     * @param output Stream to write to.
     * @param call   The service call to write.
     * @return Number of bytes written to {@code output}.
     * @throws IOException If writing fails; protocol errors are wrapped in a
     *                     {@link SerializerException}.
     */
    @Override
    public <Message extends PMessage<Message>>
    int serialize(@Nonnull OutputStream output, @Nonnull PServiceCall<Message> call)
            throws IOException {
        CountingOutputStream wrapper = new CountingOutputStream(output);
        TTransport transport = new TIOStreamTransport(wrapper);
        try {
            TProtocol protocol = protocolFactory.getProtocol(transport);
            TMessage tm = new TMessage(call.getMethod(), (byte) call.getType().asInteger(), call.getSequence());

            protocol.writeMessageBegin(tm);
            writeMessage(call.getMessage(), protocol);
            protocol.writeMessageEnd();

            transport.flush();
            wrapper.flush();
            return wrapper.getByteCount();
        } catch (TException e) {
            throw new SerializerException(e, e.getMessage());
        }
    }

    /**
     * Read a single message (without any Thrift message envelope).
     *
     * @param input      Stream to read from.
     * @param descriptor Descriptor of the expected message type.
     * @return The read message.
     * @throws IOException If reading fails; protocol errors are wrapped in a
     *                     {@link SerializerException}.
     */
    @Nonnull
    @Override
    public <Message extends PMessage<Message>>
    Message deserialize(@Nonnull InputStream input, @Nonnull PMessageDescriptor<Message> descriptor) throws IOException {
        try {
            TTransport transport = new TIOStreamTransport(input);
            TProtocol protocol = protocolFactory.getProtocol(transport);

            return readMessage(protocol, descriptor);
        } catch (TTransportException e) {
            throw new SerializerException(e, "Unable to deserialize from transport protocol");
        } catch (TException e) {
            throw new SerializerException(e, "Transport exception in protocol");
        }
    }

    /**
     * Read a service call: the Thrift message envelope, then its body.
     * For {@code EXCEPTION} calls the body is read as a
     * {@link PApplicationException}. Otherwise the method is looked up by name
     * on {@code service}, and the body is read as the method's request type for
     * request call types, or as its response type for the rest.
     * <p>
     * Every failure surfaces as a {@link SerializerException} decorated with
     * whatever call type, method name and sequence number had been decoded
     * before the failure.
     *
     * @param input   Stream to read from.
     * @param service Service whose methods define the call bodies.
     * @return The read service call.
     * @throws SerializerException On any read, protocol or lookup failure.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>>
    PServiceCall<Message> deserialize(@Nonnull InputStream input, @Nonnull PService service)
            throws SerializerException {
        PServiceCallType type = null;
        TMessage tm = null;
        try {
            TTransport transport = new TIOStreamTransport(input);
            TProtocol protocol = protocolFactory.getProtocol(transport);

            tm = protocol.readMessageBegin();

            type = PServiceCallType.findById((int) tm.type);
            if (type == null) {
                throw new SerializerException("Unknown call type for id " + tm.type)
                        .setExceptionType(PApplicationExceptionType.INVALID_MESSAGE_TYPE);
            } else if (type == PServiceCallType.EXCEPTION) {
                PApplicationException exception = readMessage(protocol, PApplicationException.kDescriptor);
                // Consume the envelope end as on the normal path, so protocols with a
                // non-empty message trailer are left at a consistent stream position.
                protocol.readMessageEnd();
                // Raw construction: the body is a PApplicationException, not a Message.
                return new PServiceCall(tm.name, type, tm.seqid, exception);
            }

            PServiceMethod method = service.getMethod(tm.name);
            if (method == null) {
                throw new SerializerException("No such method " + tm.name + " on " + service.getQualifiedName())
                        .setExceptionType(PApplicationExceptionType.UNKNOWN_METHOD);
            }

            @SuppressWarnings("unchecked")
            PMessageDescriptor<Message> descriptor = isRequestCallType(type) ? method.getRequestType() : method.getResponseType();

            Message message = readMessage(protocol, descriptor);

            protocol.readMessageEnd();

            return new PServiceCall<>(tm.name, type, tm.seqid, message);
        } catch (TTransportException e) {
            // NOTE(review): this maps the TTransportException type code onto a
            // PApplicationExceptionType id by number; the two enums are separate
            // code spaces, so confirm this mapping is intended.
            throw new SerializerException(e, e.getMessage())
                    .setExceptionType(PApplicationExceptionType.findById(e.getType()))
                    .setCallType(type)
                    .setSequenceNo(tm != null ? tm.seqid : 0)
                    .setMethodName(tm != null ? tm.name : null);
        } catch (TException e) {
            throw new SerializerException(e, e.getMessage())
                    .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR)
                    .setCallType(type)
                    .setSequenceNo(tm != null ? tm.seqid : 0)
                    .setMethodName(tm != null ? tm.name : null);
        } catch (SerializerException e) {
            // Guard against a missing envelope, as the other handlers do.
            e.setCallType(type);
            if (tm != null) {
                e.setMethodName(tm.name);
                e.setSequenceNo(tm.seqid);
            }
            throw e;
        }
    }

    /**
     * Write a message as a Thrift struct: every field that is set, in
     * descriptor field order, followed by a field stop.
     *
     * @param message  Message or builder to write.
     * @param protocol Protocol to write to.
     */
    private void writeMessage(PMessageOrBuilder<?> message, TProtocol protocol) throws TException, SerializerException {
        PMessageDescriptor<?> type = message.descriptor();

        protocol.writeStructBegin(new TStruct(type.getQualifiedName()));

        for (PField field : type.getFields()) {
            if (!message.has(field.getId())) {
                continue;
            }

            protocol.writeFieldBegin(new TField(field.getName(),
                                                forType(field.getDescriptor().getType()),
                                                (short) field.getId()));

            writeTypedValue(message.get(field.getId()), field.getDescriptor(), protocol);

            protocol.writeFieldEnd();
        }

        protocol.writeFieldStop();
        protocol.writeStructEnd();
    }

    /**
     * Read a Thrift struct into a message of the given type. Fields whose id
     * is not known to the descriptor are skipped; known fields whose wire type
     * does not match the declared type cause an exception. When the serializer
     * is in strict mode the built message is validated first.
     *
     * @param protocol   Protocol to read from.
     * @param descriptor Descriptor of the message to build.
     * @return The built message.
     * @throws SerializerException On type mismatch or (strict) validation failure.
     */
    private <Message extends PMessage<Message>>
    Message readMessage(TProtocol protocol, PMessageDescriptor<Message> descriptor)
            throws SerializerException, TException {
        TField f;

        PMessageBuilder<Message> builder = descriptor.builder();
        protocol.readStructBegin();  // ignored.
        while ((f = protocol.readFieldBegin()) != null) {
            if (f.type == BinaryType.STOP) {
                break;
            }

            PField field;
            // f.name is never filled out, rely on f.id being correct.
            field = descriptor.findFieldById(f.id);
            if (field != null) {
                if (f.type != forType(field.getDescriptor().getType())) {
                    throw new SerializerException("Incompatible serialized type " + asString(f.type) +
                                                  " for field " + field.getName() +
                                                  ", expected " + asString(forType(field.getDescriptor().getType())));
                }

                // allowNull = true: an unknown enum id at field level does not throw;
                // a null result simply leaves the field unset.
                Object value = readTypedValue(f.type, field.getDescriptor(), protocol, true);
                if (value != null) {
                    builder.set(field.getId(), value);
                }
            } else {
                TProtocolUtil.skip(protocol, f.type);
            }

            protocol.readFieldEnd();
        }
        protocol.readStructEnd();

        if (readStrict) {
            try {
                builder.validate();
            } catch (IllegalStateException e) {
                throw new SerializerException(e, e.getMessage());
            }
        }

        return builder.build();
    }

    /**
     * Read one value of the given declared type.
     *
     * @param tType     Wire type byte read from the protocol; must match the
     *                  wire type of {@code type}.
     * @param type      Declared providence type of the value.
     * @param protocol  Protocol to read from.
     * @param allowNull If false, an enum id unknown to the enum descriptor
     *                  throws. If true, the enum builder's {@code build()}
     *                  result is returned instead (presumably null — confirm).
     * @return The read value.
     * @throws SerializerException On wire / declared type mismatch, invalid
     *                             enum value, or unsupported wire type.
     */
    @SuppressWarnings("unchecked")
    private Object readTypedValue(byte tType, PDescriptor type, TProtocol protocol, boolean allowNull)
            throws TException, SerializerException {
        if (tType != forType(type.getType())) {
            throw new SerializerException("Expected type " +
                                          asString(forType(type.getType())) +
                                          " but found " +
                                          asString(tType));
        }
        switch (tType) {
            case BinaryType.BOOL:
                return protocol.readBool();
            case BinaryType.BYTE:
                return protocol.readByte();
            case BinaryType.I16:
                return protocol.readI16();
            case BinaryType.I32:
                if (PType.ENUM == type.getType()) {
                    PEnumDescriptor<?> et = (PEnumDescriptor<?>) type;
                    PEnumBuilder<?> eb = et.builder();
                    int value = protocol.readI32();
                    eb.setById(value);
                    if (!eb.valid() && !allowNull) {
                        throw new SerializerException("Invalid enum value " + value + " for " +
                                                      et.getQualifiedName());
                    }
                    return eb.build();
                } else {
                    return protocol.readI32();
                }
            case BinaryType.I64:
                return protocol.readI64();
            case BinaryType.DOUBLE:
                return protocol.readDouble();
            case BinaryType.STRING:
                if (type == PPrimitive.BINARY) {
                    ByteBuffer buffer = protocol.readBinary();
                    // Copy only the readable window: the returned buffer may be a view
                    // into a larger backing array (non-zero position / shorter limit),
                    // and array() would then include unrelated bytes (or throw for
                    // read-only / direct buffers).
                    byte[] bytes = new byte[buffer.remaining()];
                    buffer.get(bytes);
                    return Binary.wrap(bytes);
                }
                return protocol.readString();
            case BinaryType.STRUCT:
                return readMessage(protocol, (PMessageDescriptor<?>) type);
            case BinaryType.LIST:
                TList listInfo = protocol.readListBegin();
                PList<Object> lDesc = (PList<Object>) type;
                PDescriptor liDesc = lDesc.itemDescriptor();

                PList.Builder<Object> list = lDesc.builder(listInfo.size);
                for (int i = 0; i < listInfo.size; ++i) {
                    list.add(readTypedValue(listInfo.elemType, liDesc, protocol, false));
                }

                protocol.readListEnd();
                return list.build();
            case BinaryType.SET:
                TSet setInfo = protocol.readSetBegin();
                PSet<Object> sDesc = (PSet<Object>) type;
                PDescriptor siDesc = sDesc.itemDescriptor();

                PSet.Builder<Object> set = sDesc.builder(setInfo.size);
                for (int i = 0; i < setInfo.size; ++i) {
                    set.add(readTypedValue(setInfo.elemType, siDesc, protocol, false));
                }

                protocol.readSetEnd();
                return set.build();
            case BinaryType.MAP:
                TMap mapInfo = protocol.readMapBegin();
                PMap<Object, Object> mDesc = (PMap<Object, Object>) type;
                PDescriptor mkDesc = mDesc.keyDescriptor();
                PDescriptor miDesc = mDesc.itemDescriptor();

                PMap.Builder<Object, Object> map = mDesc.builder(mapInfo.size);
                for (int i = 0; i < mapInfo.size; ++i) {
                    Object key = readTypedValue(mapInfo.keyType, mkDesc, protocol, false);
                    Object val = readTypedValue(mapInfo.valueType, miDesc, protocol, false);
                    map.put(key, val);
                }

                protocol.readMapEnd();
                return map.build();
            default:
                throw new SerializerException("Unsupported protocol field type: " + tType);
        }
    }

    /**
     * Write one value of the given declared type. Enums are written as their
     * i32 id; containers are written with a header carrying the element wire
     * type(s) and size, followed by each element.
     *
     * @param item     The value to write; must be of the Java type matching {@code type}.
     * @param type     Declared providence type of the value.
     * @param protocol Protocol to write to.
     */
    private void writeTypedValue(Object item, PDescriptor type, TProtocol protocol)
            throws TException, SerializerException {
        switch (type.getType()) {
            case BOOL:
                protocol.writeBool((Boolean) item);
                break;
            case BYTE:
                protocol.writeByte((Byte) item);
                break;
            case I16:
                protocol.writeI16((Short) item);
                break;
            case I32:
                protocol.writeI32((Integer) item);
                break;
            case I64:
                protocol.writeI64((Long) item);
                break;
            case DOUBLE:
                protocol.writeDouble((Double) item);
                break;
            case STRING:
                protocol.writeString((String) item);
                break;
            case BINARY:
                protocol.writeBinary(((Binary) item).getByteBuffer());
                break;
            case ENUM:
                PEnumValue<?> value = (PEnumValue<?>) item;
                protocol.writeI32(value.asInteger());
                break;
            case MESSAGE:
                writeMessage((PMessage<?>) item, protocol);
                break;
            case LIST:
                PList<?> lType = (PList<?>) type;
                List<?> list = (List<?>) item;
                TList listInfo = new TList(forType(lType.itemDescriptor().getType()), list.size());
                protocol.writeListBegin(listInfo);
                for (Object i : list) {
                    writeTypedValue(i, lType.itemDescriptor(), protocol);
                }
                protocol.writeListEnd();
                break;
            case SET:
                PSet<?> sType = (PSet<?>) type;
                Set<?> set = (Set<?>) item;
                TSet setInfo = new TSet(forType(sType.itemDescriptor().getType()), set.size());
                protocol.writeSetBegin(setInfo);
                for (Object i : set) {
                    writeTypedValue(i, sType.itemDescriptor(), protocol);
                }
                protocol.writeSetEnd();
                break;
            case MAP:
                PMap<?, ?> mType = (PMap<?, ?>) type;
                Map<?, ?> map = (Map<?, ?>) item;
                protocol.writeMapBegin(new TMap(forType(mType.keyDescriptor().getType()),
                                                forType(mType.itemDescriptor().getType()),
                                                map.size()));

                for (Map.Entry<?, ?> entry : map.entrySet()) {
                    writeTypedValue(entry.getKey(), mType.keyDescriptor(), protocol);
                    writeTypedValue(entry.getValue(), mType.itemDescriptor(), protocol);
                }

                protocol.writeMapEnd();
                break;
            default:
                // Any other type (e.g. VOID) writes no value payload.
                break;
        }
    }
}