UrlEncodedSerializer.java
package net.morimekta.providence.serializer;
import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PContainer;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;
import net.morimekta.util.io.CountingOutputStream;
import net.morimekta.util.io.IOUtils;
import net.morimekta.util.json.JsonException;
import net.morimekta.util.json.JsonToken;
import net.morimekta.util.json.JsonTokenizer;
import net.morimekta.util.json.JsonWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.StringReader;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.US;
import static net.morimekta.providence.PApplicationExceptionType.BAD_SEQUENCE_ID;
import static net.morimekta.providence.PApplicationExceptionType.INVALID_MESSAGE_TYPE;
import static net.morimekta.providence.PApplicationExceptionType.MISSING_RESULT;
import static net.morimekta.providence.PApplicationExceptionType.PROTOCOL_ERROR;
import static net.morimekta.providence.PApplicationExceptionType.UNKNOWN_METHOD;
/**
 * Serializer for handling URL encoded form data, also commonly used in
 * open web protocols like OAuth2. A message (or service call) is read from
 * a single line of {@code &}-separated {@code key=value} pairs, and reading
 * stops whenever a newline or end of input is encountered.
 * Content not simply serializable to url-encoded string will be first JSON
 * serialized, then URL-encoded.
 */
public class UrlEncodedSerializer extends Serializer {
    public static final String MEDIA_TYPE = "application/x-www-form-urlencoded";
    public static final String MEDIA_TYPE_MULTIPART = "multipart/form-data";

    private static final Logger LOGGER = LoggerFactory.getLogger(UrlEncodedSerializer.class);
    private static final JsonSerializer JSON = new JsonSerializer().named();
    private static final Base64.Encoder B64E = Base64.getUrlEncoder().withoutPadding();

    /**
     * Write every set field of the message as {@code name=value} pairs
     * joined by {@code &}.
     *
     * @param output  stream to write to; it is flushed but not closed.
     * @param message the message to serialize.
     * @return the number of bytes written.
     * @throws IOException if writing to the output fails.
     */
    @Override
    public <Message extends PMessage<Message>> int serialize(
            @Nonnull OutputStream output,
            @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        CountingOutputStream counting = new CountingOutputStream(output);
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(counting, UTF_8));
        boolean first = true;
        for (PField field : message.descriptor().getFields()) {
            if (!message.has(field.getId())) {
                continue;
            }
            if (first) {
                first = false;
            } else {
                writer.print('&');
            }
            writeField(writer, field, message.get(field.getId()));
        }
        flushOrThrow(writer);
        return counting.getByteCount();
    }

    /**
     * Write a service call as
     * {@code method=<name>&type=<type>&seq=<sequence>&message=<url-encoded JSON>}.
     *
     * @param output stream to write to; it is flushed but not closed.
     * @param call   the service call to serialize.
     * @return the number of bytes written.
     * @throws IOException if writing to the output fails.
     */
    @Override
    public <Message extends PMessage<Message>> int serialize(
            @Nonnull OutputStream output,
            @Nonnull PServiceCall<Message> call) throws IOException {
        CountingOutputStream counting = new CountingOutputStream(output);
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(counting, UTF_8));
        writer.print("method=");
        writer.print(call.getMethod());
        writer.print("&type=");
        // Locale-independent, mirroring the toUpperCase(US) used when parsing.
        writer.print(call.getType().asString().toLowerCase(US));
        writer.print("&seq=");
        writer.print(call.getSequence());
        writer.print("&message=");
        writeFieldValue(writer, call.getMessage().descriptor(), call.getMessage());
        flushOrThrow(writer);
        return counting.getByteCount();
    }

    /**
     * Serialize the message to a url-encoded string.
     *
     * @param message the message to serialize.
     * @return the url-encoded form of the message.
     * @throws IOException if serialization fails.
     */
    public <Message extends PMessage<Message>> String serialize(
            @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        serialize(out, message);
        return out.toString(UTF_8.name());
    }

    /**
     * Flush the writer and surface any write failure. {@link PrintWriter}
     * never throws; it only records errors internally, so without this check
     * a failure on the underlying stream would be silently swallowed.
     */
    private static void flushOrThrow(@Nonnull PrintWriter writer) throws IOException {
        // checkError() flushes the writer before reporting its error state.
        if (writer.checkError()) {
            throw new IOException("Failed to write url-encoded content");
        }
    }

    /**
     * Write a single field. List and set fields are written as one
     * {@code name[]=item} pair per item. An empty list or set is written as
     * {@code name=[]}. Other fields are written as {@code name=value}.
     */
    private void writeField(@Nonnull PrintWriter output,
                            @Nonnull PField field,
                            @Nonnull Object value) throws IOException {
        if (field.getType() != PType.LIST && field.getType() != PType.SET) {
            output.print(field.getName());
            output.print('=');
            writeFieldValue(output, field.getDescriptor(), value);
            return;
        }

        Collection<?> items = (Collection<?>) value;
        PContainer pc = (PContainer) field.getDescriptor();
        if (items.isEmpty()) {
            // No "name[]=" entries exist for an empty container. "name=[]" is
            // parsed back as a JSON empty array by the plain-field branch.
            output.print(field.getName());
            output.print("=[]");
            return;
        }
        boolean first = true;
        for (Object item : items) {
            if (first) {
                first = false;
            } else {
                output.print('&');
            }
            output.print(field.getName());
            output.print("[]=");
            writeFieldValue(output, pc.itemDescriptor(), item);
        }
    }

    /**
     * Write a single value.
     * <ul>
     *   <li>Primitives are written as-is, and void is written as {@code true}.</li>
     *   <li>Strings are URL-encoded.</li>
     *   <li>Binary values use URL-safe base64 without padding.</li>
     *   <li>Enums use their string name.</li>
     *   <li>Containers and messages are JSON-serialized, then URL-encoded.</li>
     * </ul>
     */
    private void writeFieldValue(@Nonnull PrintWriter output,
                                 @Nonnull PDescriptor descriptor,
                                 @Nonnull Object value) throws IOException {
        switch (descriptor.getType()) {
            case VOID:
                output.print(true);
                break;
            case BOOL:
            case BYTE:
            case I16:
            case I32:
            case I64:
            case DOUBLE:
                output.print(value);
                break;
            case STRING:
                output.print(URLEncoder.encode(value.toString(), UTF_8.name()));
                break;
            case BINARY:
                output.print(B64E.encodeToString(((Binary) value).get()));
                break;
            case ENUM:
                output.print(((PEnumValue) value).asString());
                break;
            case LIST:
            case SET:
            case MAP:
            case MESSAGE: {
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                JsonWriter writer = new JsonWriter(out);
                JSON.appendTypedValue(writer, descriptor, value);
                writer.flush();
                output.print(URLEncoder.encode(out.toString(UTF_8.name()), UTF_8.name()));
                break;
            }
        }
    }

    /**
     * Read a message from a single line of url-encoded {@code key=value}
     * pairs. Unknown keys are ignored. A key without {@code =} is treated as
     * having the value {@code true}. Keys ending in {@code []} add one item to
     * a list or set field.
     *
     * @param input      stream to read one line from.
     * @param descriptor descriptor of the message to build.
     * @return the built message.
     * @throws IOException if reading fails or a JSON-encoded value is invalid.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>> Message deserialize(
            @Nonnull InputStream input,
            @Nonnull PMessageDescriptor<Message> descriptor) throws IOException {
        PMessageBuilder<Message> builder = descriptor.builder();
        String line = readLine(input);
        // Set items are collected first and set at the end, which preserves
        // the ordering of the serialized set.
        Map<PField, Set<Object>> buildingSets = new HashMap<>();
        for (String part : line.split("&")) {
            if (part.isEmpty()) {
                continue;
            }
            String[] kv = part.split("=", 2);
            String key = kv[0];
            String value = kv.length == 1 ? "true" : URLDecoder.decode(kv[1], UTF_8.name());

            if (key.endsWith("[]")) {
                key = key.substring(0, key.length() - 2);
                PField field = descriptor.findFieldByName(key);
                if (field == null) {
                    continue;
                }
                if (field.getType() == PType.SET) {
                    PContainer pc = (PContainer) field.getDescriptor();
                    buildingSets.computeIfAbsent(field, f -> new LinkedHashSet<>())
                                .add(parseFieldValue(pc.itemDescriptor(), value));
                } else if (field.getType() == PType.LIST) {
                    PContainer pc = (PContainer) field.getDescriptor();
                    builder.addTo(field.getId(), parseFieldValue(pc.itemDescriptor(), value));
                } else {
                    LOGGER.info("Not a container type: {} for {}=...", field.getDescriptor().getQualifiedName(), kv[0]);
                }
                continue;
            }

            PField field = descriptor.findFieldByName(key);
            if (field != null) {
                builder.set(field.getId(), parseFieldValue(field.getDescriptor(), value));
            }
        }
        for (Map.Entry<PField, Set<Object>> entry : buildingSets.entrySet()) {
            builder.set(entry.getKey().getId(), entry.getValue());
        }
        return builder.build();
    }

    /**
     * Read a service call written by
     * {@link #serialize(OutputStream, PServiceCall)}. The {@code type}
     * defaults to CALL and {@code seq} defaults to 0 when absent.
     *
     * @param input   stream to read one line from.
     * @param service the service the call belongs to.
     * @return the parsed service call.
     * @throws IOException if reading or parsing fails. A
     *                     {@link PApplicationException} is thrown for an
     *                     unknown or missing method, a bad call type, a bad
     *                     sequence number or a missing message.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>> PServiceCall<Message> deserialize(
            @Nonnull InputStream input,
            @Nonnull PService service) throws IOException {
        String line = readLine(input);
        PServiceMethod method = null;
        int sequence = 0;
        PServiceCallType type = PServiceCallType.CALL;
        String message = null;
        for (String part : line.split("&")) {
            if (part.startsWith("method=")) {
                method = service.getMethod(part.substring(7));
                if (method == null) {
                    throw new PApplicationException("No such method " + part.substring(7), UNKNOWN_METHOD);
                }
            } else if (part.startsWith("type=")) {
                type = Optional.ofNullable(PServiceCallType.findByName(part.substring(5).toUpperCase(US)))
                               .orElseThrow(() -> new PApplicationException(
                                       "Bad call type: '" + part.substring(5) + "'", INVALID_MESSAGE_TYPE));
            } else if (part.startsWith("seq=")) {
                try {
                    sequence = Integer.parseInt(part.substring(4));
                } catch (NumberFormatException e) {
                    throw new PApplicationException("Bad sequence " + part.substring(4), BAD_SEQUENCE_ID);
                }
            } else if (part.startsWith("message=")) {
                message = URLDecoder.decode(part.substring(8), UTF_8.name());
            }
        }
        if (method == null) {
            throw new PApplicationException("No method in request", PROTOCOL_ERROR);
        }
        if (message == null) {
            throw new PApplicationException("No message in request", MISSING_RESULT);
        }
        // Exceptions carry an application exception, requests (call / oneway)
        // carry the request type, and anything else carries the response type.
        PMessageDescriptor<Message> md = (PMessageDescriptor<Message>) (
                type == PServiceCallType.EXCEPTION ?
                PApplicationException.kDescriptor :
                (type == PServiceCallType.CALL || type == PServiceCallType.ONEWAY) ?
                method.getRequestType() :
                method.getResponseType());
        if (md == null) {
            throw new PApplicationException("No type for " + type + " on " + method, INVALID_MESSAGE_TYPE);
        }
        return new PServiceCall<>(method.getName(), type, sequence, (Message) parseFieldValue(md, message));
    }

    /**
     * Read a single line of url-encoded content. A trailing carriage return
     * is dropped so that CRLF-terminated input parses the same as
     * LF-terminated input. A raw CR is never valid url-encoded data, because
     * it would be sent as {@code %0D}.
     */
    private static String readLine(@Nonnull InputStream input) throws IOException {
        String line = IOUtils.readString(input, '\n');
        if (line.endsWith("\r")) {
            return line.substring(0, line.length() - 1);
        }
        return line;
    }

    /**
     * Parse an already URL-decoded value according to its descriptor. This
     * is the inverse of {@link #writeFieldValue}. Enums accept either a
     * numeric id or a name. Containers and messages are parsed as JSON.
     *
     * @throws IOException if a JSON-encoded value cannot be parsed.
     */
    private Object parseFieldValue(PDescriptor descriptor, String value) throws IOException {
        switch (descriptor.getType()) {
            case VOID:
                return Boolean.TRUE;
            case BOOL:
                return Boolean.parseBoolean(value);
            case BYTE:
                return Byte.parseByte(value);
            case I16:
                return Short.parseShort(value);
            case I32:
                return Integer.parseInt(value);
            case I64:
                return Long.parseLong(value);
            case DOUBLE:
                return Double.parseDouble(value);
            case STRING:
                return value;
            case BINARY:
                return Binary.fromBase64(value);
            case ENUM: {
                PEnumDescriptor ed = (PEnumDescriptor) descriptor;
                if (Strings.isInteger(value)) {
                    return ed.findById(Integer.parseInt(value));
                }
                return ed.findByName(value);
            }
            case LIST:
            case SET:
            case MAP:
            case MESSAGE:
            default:
                try {
                    StringReader reader = new StringReader(value);
                    JsonTokenizer tokenizer = new JsonTokenizer(reader);
                    JsonToken first = tokenizer.expect("anything");
                    return JSON.parseTypedValue(first, tokenizer, descriptor, false);
                } catch (JsonException e) {
                    throw new IOException(e.getMessage(), e);
                }
        }
    }

    @Override
    public boolean binaryProtocol() {
        return false;
    }

    /**
     * Fail if anything other than whitespace remains after the parsed
     * content.
     *
     * @throws IOException if non-whitespace content remains on the stream.
     */
    @Override
    public void verifyEndOfContent(@Nonnull InputStream input) throws IOException {
        String content = IOUtils.readString(input).trim();
        if (!content.isEmpty()) {
            throw new IOException("After end of url-encoded content: '" + content + "'");
        }
    }

    @Nonnull
    @Override
    public String mediaType() {
        return MEDIA_TYPE;
    }
}