FastBinarySerializer.java
/*
* Copyright 2016 Providence Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.providence.serializer;
import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PEnumBuilder;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PType;
import net.morimekta.providence.PUnion;
import net.morimekta.providence.descriptor.PContainer;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.util.Binary;
import net.morimekta.util.io.LittleEndianBinaryReader;
import net.morimekta.util.io.LittleEndianBinaryWriter;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Map;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
* Compact binary serializer. This uses a pretty compact binary format
* while being optimized for fewer operations during read and write.
* <p>
* Documentation: <a href="http://www.morimekta.net/providence/serializer-fast-binary.html">Fast Binary Serialization Format</a>
* with IDL and explanation.
*/
public class FastBinarySerializer extends Serializer {
/** Media type string reported by {@link #mediaType()} for this serializer. */
public static final String MEDIA_TYPE = "application/vnd.morimekta.providence.binary";
/**
 * Strictness flag given at construction. When set, a message that was read
 * is validated before it is built, and a {@code null} key, value or item
 * read inside a map, list or set raises a {@link SerializerException}
 * instead of being dropped.
 */
private final boolean readStrict;
/**
 * Construct a serializer instance, using {@code DEFAULT_STRICT} as the
 * read strictness.
 */
public FastBinarySerializer() {
this(DEFAULT_STRICT);
}
/**
 * Construct a serializer instance with explicit read strictness.
 *
 * @param readStrict If serializer should fail on unknown input data. Within
 *                   this class the flag enables validation of every message
 *                   read, and makes {@code null} container entries an error
 *                   instead of being skipped.
 */
public FastBinarySerializer(boolean readStrict) {
this.readStrict = readStrict;
}
/**
 * Write a single message to the stream in the fast binary format.
 *
 * @param os      Stream to write the encoded message to.
 * @param message The message (or builder) to encode.
 * @return Number of bytes written.
 * @throws IOException If writing to the stream fails.
 */
@Override
public <Message extends PMessage<Message>>
int serialize(@Nonnull OutputStream os, @Nonnull PMessageOrBuilder<Message> message) throws IOException {
    // The little-endian writer is only a view on the caller's stream; it is not closed here.
    return writeMessage(new LittleEndianBinaryWriter(os), message);
}
/**
 * Write a service call envelope followed by its message.
 * <p>
 * Layout: a varint holding {@code (method name length << 3) | call type},
 * the UTF-8 bytes of the method name, the sequence number as a varint and
 * finally the message itself.
 *
 * @param os   Stream to write to.
 * @param call The service call to encode.
 * @return Number of bytes written.
 * @throws IOException If writing to the stream fails.
 */
@Override
public <Message extends PMessage<Message>>
int serialize(@Nonnull OutputStream os, @Nonnull PServiceCall<Message> call)
        throws IOException {
    LittleEndianBinaryWriter writer = new LittleEndianBinaryWriter(os);
    byte[] nameBytes = call.getMethod().getBytes(UTF_8);

    // Header: name length in the upper bits, call type in the lower three.
    int written = writer.writeVarint(nameBytes.length << 3 | call.getType().asInteger());
    writer.write(nameBytes);
    written += nameBytes.length;

    written += writer.writeVarint(call.getSequence());
    written += writeMessage(writer, call.getMessage());
    return written;
}
/**
 * Read a single message of the given type from the stream.
 *
 * @param is         Stream to read the encoded message from.
 * @param descriptor Descriptor of the message type to build.
 * @return The message that was read.
 * @throws IOException If reading fails or the content cannot be parsed.
 */
@Nonnull
@Override
public <Message extends PMessage<Message>>
Message deserialize(@Nonnull InputStream is,
                    @Nonnull PMessageDescriptor<Message> descriptor)
        throws IOException {
    return readMessage(new LittleEndianBinaryReader(is), descriptor);
}
/**
 * Read a service call envelope and its message from the stream.
 * <p>
 * The envelope is a varint header holding the method name length (upper
 * bits) and the call type (lower three bits), then the UTF-8 method name,
 * then the sequence number as a varint. A call of type {@code EXCEPTION}
 * carries a {@link PApplicationException}; otherwise the method is looked
 * up on the service and its request or response type is read.
 *
 * @param is      Stream to read from.
 * @param service Service whose methods the call is resolved against.
 * @return The service call that was read.
 * @throws SerializerException On unknown call type, unknown method, missing
 *         descriptor, or any read failure. The exception carries whatever
 *         call type, method name and sequence number were parsed before the
 *         failure.
 */
@Nonnull
@Override
@SuppressWarnings("unchecked")
public <Message extends PMessage<Message>>
PServiceCall<Message> deserialize(@Nonnull InputStream is, @Nonnull PService service)
        throws SerializerException {
    // Declared outside the try so the catch clauses can report what was parsed so far.
    String callName = null;
    int seqNo = 0;
    PServiceCallType callType = null;
    try {
        LittleEndianBinaryReader reader = new LittleEndianBinaryReader(is);

        int header = reader.readIntVarint();
        int nameLength = header >>> 3;
        int callTypeId = header & 0x07;

        callName = new String(reader.expectBytes(nameLength), UTF_8);
        seqNo = reader.readIntVarint();
        callType = PServiceCallType.findById(callTypeId);

        if (callType == null) {
            throw new SerializerException("Invalid call type " + callTypeId)
                    .setExceptionType(PApplicationExceptionType.INVALID_MESSAGE_TYPE);
        }
        if (callType == PServiceCallType.EXCEPTION) {
            // Application exceptions are not tied to a method, so no lookup is needed.
            PApplicationException failure = readMessage(reader, PApplicationException.kDescriptor);
            return (PServiceCall<Message>) new PServiceCall<>(callName, callType, seqNo, failure);
        }

        PServiceMethod serviceMethod = service.getMethod(callName);
        if (serviceMethod == null) {
            throw new SerializerException("No such method %s on %s",
                                          callName,
                                          service.getQualifiedName())
                    .setExceptionType(PApplicationExceptionType.UNKNOWN_METHOD);
        }

        PMessageDescriptor<Message> messageType = (PMessageDescriptor<Message>) (
                isRequestCallType(callType) ? serviceMethod.getRequestType() : serviceMethod.getResponseType());
        if (messageType == null) {
            throw new SerializerException("No such %s descriptor for %s",
                                          isRequestCallType(callType) ? "request" : "response",
                                          service.getQualifiedName())
                    .setExceptionType(PApplicationExceptionType.UNKNOWN_METHOD);
        }

        Message payload = readMessage(reader, messageType);
        return new PServiceCall<>(callName, callType, seqNo, payload);
    } catch (SerializerException e) {
        throw new SerializerException(e)
                .setCallType(callType)
                .setMethodName(callName)
                .setSequenceNo(seqNo);
    } catch (IOException e) {
        throw new SerializerException(e, e.getMessage())
                .setCallType(callType)
                .setMethodName(callName)
                .setSequenceNo(seqNo);
    }
}
/**
 * @return Always {@code true} for this serializer.
 */
@Override
public boolean binaryProtocol() {
return true;
}
/**
 * Verify that the stream has no more content, and close it.
 * <p>
 * Attempts to read one more byte; if one is available the content is
 * considered malformed. The stream is closed whether or not the check
 * passes.
 *
 * @param input Stream positioned after the last expected byte.
 * @throws IOException If reading or closing fails, or (as a
 *         {@link SerializerException}) if another byte was available.
 */
@Override
public void verifyEndOfContent(@Nonnull InputStream input) throws IOException {
    try {
        int next = input.read();
        if (next < 0) {
            // End of stream, as expected.
            return;
        }
        throw new SerializerException("More content after end: 0x%02x", next)
                .setExceptionType(PApplicationExceptionType.PROTOCOL_ERROR);
    } finally {
        input.close();
    }
}
/**
 * @return The {@link #MEDIA_TYPE} constant.
 */
@Nonnull
@Override
public String mediaType() {
return MEDIA_TYPE;
}
// --- MESSAGE ---
/**
 * Write all present fields of a message, followed by the STOP marker.
 * <p>
 * For unions only the currently set union field (if any) is written. For
 * other messages every field the message reports as present is written,
 * in the order given by the descriptor.
 *
 * @param out     Writer to write to.
 * @param message The message to write.
 * @return Number of bytes written, including the STOP marker.
 * @throws IOException If writing fails.
 */
private <Message extends PMessage<Message>>
int writeMessage(LittleEndianBinaryWriter out, PMessageOrBuilder<Message> message)
        throws IOException {
    int written = 0;
    if (message instanceof PUnion) {
        PUnion union = (PUnion) message;
        if (union.unionFieldIsSet()) {
            PField active = union.unionField();
            written += writeFieldValue(out,
                                       active.getId(),
                                       active.getDescriptor(),
                                       message.get(active.getId()));
        }
    } else {
        for (PField candidate : message.descriptor().getFields()) {
            if (!message.has(candidate.getId())) {
                continue;
            }
            written += writeFieldValue(out,
                                       candidate.getId(),
                                       candidate.getDescriptor(),
                                       message.get(candidate.getId()));
        }
    }
    // Terminate the field list with the STOP field.
    written += out.writeVarint(STOP);
    return written;
}
/**
 * Skip over a complete message without building anything.
 * <p>
 * Reads field tags until the STOP marker, discarding each field value.
 *
 * @param in Reader positioned at the first field tag of the message.
 * @throws IOException If reading fails.
 */
private void consumeMessage(LittleEndianBinaryReader in) throws IOException {
    for (int tag = in.readIntVarint(); tag != STOP; tag = in.readIntVarint()) {
        // Only the wire type (low three bits) matters when skipping.
        readFieldValue(in, tag & 0x07, null);
    }
}
/**
 * Read fields until the STOP marker and build a message from them.
 * <p>
 * Each tag holds the field ID in the upper bits and the wire type in the
 * lower three. Values for field IDs not found on the descriptor are read
 * and discarded. In strict mode the builder is validated before building.
 *
 * @param in         Reader positioned at the first field tag.
 * @param descriptor Descriptor of the message type to build.
 * @return The built message.
 * @throws IOException If reading fails, or (as a {@link SerializerException})
 *         if strict validation of the builder fails.
 */
@Nonnull
private <Message extends PMessage<Message>>
Message readMessage(@Nonnull LittleEndianBinaryReader in,
                    @Nonnull PMessageDescriptor<Message> descriptor)
        throws IOException {
    PMessageBuilder<Message> builder = descriptor.builder();

    for (int tag = in.readIntVarint(); tag != STOP; tag = in.readIntVarint()) {
        int fieldId = tag >>> 3;
        int wireType = tag & 0x07;

        PField known = descriptor.findFieldById(fieldId);
        if (known == null) {
            // Unknown field: consume its bytes so the stream stays aligned.
            readFieldValue(in, wireType, null);
            continue;
        }
        Object parsed = readFieldValue(in, wireType, known.getDescriptor());
        builder.set(known.getId(), parsed);
    }

    if (readStrict) {
        try {
            builder.validate();
        } catch (IllegalStateException e) {
            throw new SerializerException(e, e.getMessage());
        }
    }
    return builder.build();
}
// --- FIELD VALUE ---
/**
 * Write one message field: its tag followed by its value, if any.
 * <p>
 * The tag is a varint of {@code (key << 3) | wireType}. Void and boolean
 * fields carry their value in the wire type itself and have no payload.
 *
 * @param out        Writer to write to.
 * @param key        The field ID.
 * @param descriptor Type descriptor of the field.
 * @param value      The field value; its runtime type must match the descriptor.
 * @return Number of bytes written.
 * @throws IOException If writing fails.
 */
@SuppressWarnings("unchecked")
private int writeFieldValue(LittleEndianBinaryWriter out, int key, PDescriptor descriptor, Object value)
        throws IOException {
    // Field ID shifted into place; the wire type is OR'ed into the low three bits.
    final int idBits = key << 3;
    switch (descriptor.getType()) {
        case VOID:
            return out.writeVarint(idBits | TRUE);
        case BOOL:
            return out.writeVarint(idBits | ((Boolean) value ? TRUE : NONE));
        case BYTE:
            // Operands are evaluated left to right, so the tag is written before the value.
            return out.writeVarint(idBits | VARINT) + out.writeZigzag((byte) value);
        case I16:
            return out.writeVarint(idBits | VARINT) + out.writeZigzag((short) value);
        case I32:
            return out.writeVarint(idBits | VARINT) + out.writeZigzag((int) value);
        case I64:
            return out.writeVarint(idBits | VARINT) + out.writeZigzag((long) value);
        case DOUBLE:
            return out.writeVarint(idBits | FIXED_64) + out.writeDouble((Double) value);
        case STRING: {
            byte[] utf8 = ((String) value).getBytes(StandardCharsets.UTF_8);
            int written = out.writeVarint(idBits | BINARY);
            written += out.writeVarint(utf8.length);
            out.write(utf8);
            return written + utf8.length;
        }
        case BINARY: {
            Binary data = (Binary) value;
            int written = out.writeVarint(idBits | BINARY);
            written += out.writeVarint(data.length());
            data.write(out);
            return written + data.length();
        }
        case ENUM:
            return out.writeVarint(idBits | VARINT) + out.writeZigzag(((PEnumValue) value).asInteger());
        case MESSAGE:
            return out.writeVarint(idBits | MESSAGE) + writeMessage(out, (PMessage) value);
        case MAP:
        case SET:
        case LIST:
            return out.writeVarint(idBits | COLLECTION) + writeContainerEntry(out, COLLECTION, descriptor, value);
        default:
            throw new Error("Unreachable code reached");
    }
}
/**
 * Write a single value without a field tag: either a container itself, or
 * one key / value / item inside a container.
 * <p>
 * Encoding by wire type:
 * <ul>
 *   <li>{@code VARINT}: booleans as a plain varint (1 / 0), numbers and
 *       enum values as zigzag varints.</li>
 *   <li>{@code FIXED_64}: a double.</li>
 *   <li>{@code BINARY}: varint byte length followed by the bytes; character
 *       sequences are encoded as UTF-8.</li>
 *   <li>{@code MESSAGE}: the message fields followed by STOP.</li>
 *   <li>{@code COLLECTION}: varint entry count (twice the size for maps),
 *       varint type tag ({@code keyType << 3 | valueType} for maps, the
 *       item type otherwise), then every entry written recursively.</li>
 * </ul>
 *
 * @param out        Writer to write to.
 * @param typeid     Wire type to encode the value as.
 * @param descriptor Type descriptor of the value.
 * @param value      The value to write.
 * @return Number of bytes written.
 * @throws IOException If writing fails, or (as a {@link SerializerException})
 *         if the value does not fit the requested wire type.
 */
@SuppressWarnings("unchecked")
private int writeContainerEntry(LittleEndianBinaryWriter out, int typeid, PDescriptor descriptor, Object value)
        throws IOException {
    switch (typeid) {
        case VARINT: {
            if (value instanceof Boolean) {
                return out.writeVarint(((Boolean) value ? 1 : 0));
            } else if (value instanceof Number) {
                return out.writeZigzag(((Number) value).longValue());
            } else if (value instanceof PEnumValue) {
                return out.writeZigzag(((PEnumValue) value).asInteger());
            } else {
                throw new SerializerException("Impossible");
            }
        }
        case FIXED_64: {
            return out.writeDouble((Double) value);
        }
        case BINARY: {
            if (value instanceof CharSequence) {
                // toString() instead of a (String) cast: the guard accepts any
                // CharSequence, and a cast would throw ClassCastException for
                // implementations other than String.
                byte[] bytes = value.toString().getBytes(StandardCharsets.UTF_8);
                int len = out.writeVarint(bytes.length);
                out.write(bytes);
                return len + bytes.length;
            } else if (value instanceof Binary) {
                Binary bytes = (Binary) value;
                int len = out.writeVarint(bytes.length());
                bytes.write(out);
                return len + bytes.length();
            } else {
                throw new SerializerException("Impossible");
            }
        }
        case MESSAGE: {
            return writeMessage(out, (PMessage) value);
        }
        case COLLECTION: {
            if (value instanceof Map) {
                Map<Object, Object> map = (Map<Object, Object>) value;
                PMap<?, ?> desc = (PMap<?, ?>) descriptor;
                int ktype = itemType(desc.keyDescriptor());
                int vtype = itemType(desc.itemDescriptor());
                // Maps count keys and values separately, hence size * 2.
                int len = out.writeVarint(map.size() * 2);
                len += out.writeVarint(ktype << 3 | vtype);
                for (Map.Entry<Object, Object> entry : map.entrySet()) {
                    len += writeContainerEntry(out, ktype, desc.keyDescriptor(), entry.getKey());
                    len += writeContainerEntry(out, vtype, desc.itemDescriptor(), entry.getValue());
                }
                return len;
            } else if (value instanceof Collection) {
                Collection<Object> coll = (Collection<Object>) value;
                PContainer<?> desc = (PContainer<?>) descriptor;
                int vtype = itemType(desc.itemDescriptor());
                int len = out.writeVarint(coll.size());
                len += out.writeVarint(vtype);
                for (Object item : coll) {
                    len += writeContainerEntry(out, vtype, desc.itemDescriptor(), item);
                }
                return len;
            } else {
                throw new SerializerException("Impossible");
            }
        }
        default:
            throw new SerializerException("Impossible");
    }
}
/**
 * Read one value of the given wire type, or skip it when no descriptor is
 * given.
 * <p>
 * {@code NONE} and {@code TRUE} yield {@code Boolean.FALSE} /
 * {@code Boolean.TRUE}, and {@code FIXED_64} yields a double, regardless of
 * the descriptor. For {@code VARINT}, {@code BINARY}, {@code MESSAGE} and
 * {@code COLLECTION} a {@code null} descriptor means the bytes are consumed
 * and {@code null} is returned. Inside containers, {@code null} keys, values
 * or items are dropped unless the serializer is strict, in which case they
 * raise a {@link SerializerException}.
 *
 * @param in         Reader positioned at the start of the value.
 * @param type       Wire type of the value, as read from the stream.
 * @param descriptor Expected type of the value, or {@code null} to skip it.
 * @return The value read, or {@code null} if it was skipped.
 * @throws IOException If reading fails, or (as a {@link SerializerException})
 *         if the wire type is unknown or incompatible with the descriptor.
 */
@SuppressWarnings("unchecked")
private Object readFieldValue(@Nonnull LittleEndianBinaryReader in,
                              int type,
                              @Nullable PDescriptor descriptor)
        throws IOException {
    switch (type) {
        case NONE:
            return Boolean.FALSE;
        case TRUE:
            return Boolean.TRUE;
        case VARINT: {
            if (descriptor == null) {
                // Skipping: consume the varint and discard it.
                in.readLongVarint();
                return null;
            }
            switch (descriptor.getType()) {
                case BOOL:
                    return in.readIntVarint() != 0;
                case BYTE:
                    return (byte) in.readIntZigzag();
                case I16:
                    return (short) in.readIntZigzag();
                case I32:
                    return in.readIntZigzag();
                case I64:
                    return in.readLongZigzag();
                case ENUM: {
                    PEnumBuilder<?> builder = ((PEnumDescriptor<?>) descriptor).builder();
                    builder.setById(in.readIntZigzag());
                    return builder.build();
                }
                default: {
                    throw new SerializerException("Impossible");
                }
            }
        }
        case FIXED_64:
            return in.expectDouble();
        case BINARY: {
            int len = in.readIntVarint();
            byte[] data = in.expectBytes(len);
            if (descriptor != null) {
                switch (descriptor.getType()) {
                    case STRING:
                        return new String(data, StandardCharsets.UTF_8);
                    case BINARY:
                        return Binary.wrap(data);
                    default:
                        throw new SerializerException("Impossible");
                }
            } else {
                return null;
            }
        }
        case MESSAGE:
            if (descriptor == null) {
                consumeMessage(in);
                return null;
            }
            return readMessage(in, (PMessageDescriptor<?>) descriptor);
        case COLLECTION:
            if (descriptor == null) {
                // Skipping: the tag tells us the entry wire types. For maps the
                // key type sits above the low three bits and entries alternate
                // key / value; otherwise every entry has the same type.
                final int len = in.readIntVarint();
                final int tag = in.readIntVarint();
                final int vtype = tag & 0x07;
                final int ktype = tag > 0x07 ? tag >>> 3 : vtype;
                for (int i = 0; i < len; ++i) {
                    if (i % 2 == 0) {
                        readFieldValue(in, ktype, null);
                    } else {
                        readFieldValue(in, vtype, null);
                    }
                }
                return null;
            } else if (descriptor.getType() == PType.MAP) {
                PMap<Object, Object> ct = (PMap<Object, Object>) descriptor;
                PDescriptor kt = ct.keyDescriptor();
                PDescriptor vt = ct.itemDescriptor();
                final int len = in.readIntVarint();
                final int tag = in.readIntVarint();
                final int vtype = tag & 0x07;
                final int ktype = tag > 0x07 ? tag >>> 3 : vtype;
                PMap.Builder<Object, Object> out = ct.builder(len / 2);
                // len counts keys and values separately, so step two per pair.
                for (int i = 0; i < len; ++i, ++i) {
                    Object key = readFieldValue(in, ktype, kt);
                    Object value = readFieldValue(in, vtype, vt);
                    if (key != null && value != null) {
                        out.put(key, value);
                    } else if (readStrict) {
                        if (key == null) {
                            throw new SerializerException("Unknown enum key in map");
                        }
                        throw new SerializerException("Null value in map");
                    }
                }
                return out.build();
            } else if (descriptor.getType() == PType.LIST) {
                PList<Object> ct = (PList<Object>) descriptor;
                PDescriptor it = ct.itemDescriptor();
                final int len = in.readIntVarint();
                final int vtype = in.readIntVarint() & 0x07;
                PList.Builder<Object> out = ct.builder(len);
                for (int i = 0; i < len; ++i) {
                    Object item = readFieldValue(in, vtype, it);
                    if (item != null) {
                        out.add(item);
                    } else if (readStrict) {
                        throw new SerializerException("Null value in list");
                    }
                }
                return out.build();
            } else if (descriptor.getType() == PType.SET) {
                PSet<Object> ct = (PSet<Object>) descriptor;
                PDescriptor it = ct.itemDescriptor();
                final int len = in.readIntVarint();
                final int vtype = in.readIntVarint() & 0x07;
                PSet.Builder<Object> out = ct.builder(len);
                for (int i = 0; i < len; ++i) {
                    Object item = readFieldValue(in, vtype, it);
                    if (item != null) {
                        out.add(item);
                    } else if (readStrict) {
                        throw new SerializerException("Null value in set");
                    }
                }
                return out.build();
            } else {
                throw new SerializerException("Type " + descriptor.getType() +
                                              " not compatible with collection data");
            }
        default:
            // Reachable from malformed input: a field or item tag whose low
            // three bits are 0, or a collection key type (tag >>> 3) outside
            // the known wire types. Report it as a serialization failure
            // rather than a java.lang.Error, so callers handling
            // IOException / SerializerException see it.
            throw new SerializerException("Unknown field type: " + type);
    }
}
/**
 * Map a value type to the wire type used for it inside containers.
 *
 * @param descriptor Descriptor of the key, value or item type.
 * @return One of the {@code VARINT}, {@code FIXED_64}, {@code BINARY},
 *         {@code MESSAGE} or {@code COLLECTION} wire types.
 */
private static int itemType(PDescriptor descriptor) {
    final int wireType;
    switch (descriptor.getType()) {
        case MESSAGE:
            wireType = MESSAGE;
            break;
        case MAP:
        case LIST:
        case SET:
            wireType = COLLECTION;
            break;
        case STRING:
        case BINARY:
            wireType = BINARY;
            break;
        case DOUBLE:
            wireType = FIXED_64;
            break;
        case ENUM:
        case I64:
        case I32:
        case I16:
        case BYTE:
        case BOOL:
            wireType = VARINT;
            break;
        default:
            throw new Error("Unreachable code reached");
    }
    return wireType;
}
// Wire types. In a field tag they occupy the low three bits, with the field
// ID shifted left by three above them.
private static final int STOP = 0x00; // tag value that terminates a message's field list.
private static final int NONE = 0x01; // boolean false; no payload follows the tag.
private static final int TRUE = 0x02; // boolean true (also written for void fields); no payload follows.
private static final int VARINT = 0x03; // -> zigzag encoded varint (byte, i16, i32, i64, enum); container booleans are a plain varint.
private static final int FIXED_64 = 0x04; // -> double
private static final int BINARY = 0x05; // -> varint len + binary data (strings are the UTF-8 bytes).
private static final int MESSAGE = 0x06; // -> message fields, terminated with the STOP tag.
private static final int COLLECTION = 0x07; // -> varint entry count (2 * size for maps) + varint type tag + the entries.
}