UrlEncodedSerializer.java

package net.morimekta.providence.serializer;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PMessageOrBuilder;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PContainer;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;
import net.morimekta.util.io.CountingOutputStream;
import net.morimekta.util.io.IOUtils;
import net.morimekta.util.json.JsonException;
import net.morimekta.util.json.JsonToken;
import net.morimekta.util.json.JsonTokenizer;
import net.morimekta.util.json.JsonWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.StringReader;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.US;
import static net.morimekta.providence.PApplicationExceptionType.BAD_SEQUENCE_ID;
import static net.morimekta.providence.PApplicationExceptionType.INVALID_MESSAGE_TYPE;
import static net.morimekta.providence.PApplicationExceptionType.MISSING_RESULT;
import static net.morimekta.providence.PApplicationExceptionType.PROTOCOL_ERROR;
import static net.morimekta.providence.PApplicationExceptionType.UNKNOWN_METHOD;

/**
 * Serializer for handling URL encoded form data, also commonly used in
 * open web protocols like OAuth2. Content is read as a single line of
 * {@code &}-separated {@code key=value} pairs; reading stops at the first
 * newline or at end of input. Values that cannot be written as a plain
 * url-encoded string (maps, messages, and container items that are
 * themselves containers) are first JSON serialized, then URL-encoded.
 * <p>
 * List and set fields with at least one item are written as repeated
 * {@code name[]=item} pairs; empty lists and sets are written as
 * {@code name=[]}.
 */
public class UrlEncodedSerializer extends Serializer {
    public static final String MEDIA_TYPE           = "application/x-www-form-urlencoded";
    public static final String MEDIA_TYPE_MULTIPART = "multipart/form-data";

    private static final Logger         LOGGER = LoggerFactory.getLogger(UrlEncodedSerializer.class);
    private static final JsonSerializer JSON   = new JsonSerializer().named();
    private static final Base64.Encoder B64E   = Base64.getUrlEncoder().withoutPadding();

    /**
     * Writes every set field of the message as {@code name=value} pairs
     * joined by {@code &}, in the order the descriptor lists its fields.
     *
     * @param output  stream to write the UTF-8 encoded form data to.
     * @param message the message (or builder) to serialize.
     * @return the number of bytes written to {@code output}.
     * @throws IOException if writing or value encoding fails.
     */
    @Override
    public <Message extends PMessage<Message>> int serialize(
            @Nonnull OutputStream output,
            @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        CountingOutputStream counting = new CountingOutputStream(output);
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(counting, UTF_8));
        boolean first = true;
        for (PField field : message.descriptor().getFields()) {
            if (message.has(field.getId())) {
                if (first) first = false;
                else writer.print('&');
                writeField(writer, field, message.get(field.getId()));
            }
        }
        writer.flush();
        return counting.getByteCount();
    }

    /**
     * Writes a service call as
     * {@code method=<name>&type=<type>&seq=<sequence>&message=<url-encoded JSON>}.
     *
     * @param output stream to write the UTF-8 encoded form data to.
     * @param call   the service call to serialize.
     * @return the number of bytes written to {@code output}.
     * @throws IOException if writing or value encoding fails.
     */
    @Override
    public <Message extends PMessage<Message>> int serialize(
            @Nonnull OutputStream output,
            @Nonnull PServiceCall<Message> call) throws IOException {
        CountingOutputStream counting = new CountingOutputStream(output);
        PrintWriter writer = new PrintWriter(new OutputStreamWriter(counting, UTF_8));
        writer.print("method=");
        writer.print(call.getMethod());
        writer.print("&type=");
        // Fixed locale: the default locale (e.g. Turkish) could turn 'I' into a
        // dotless 'ı', making the wire format non-ASCII. Mirrors toUpperCase(US)
        // in deserialize().
        writer.print(call.getType().asString().toLowerCase(US));
        writer.print("&seq=");
        writer.print(call.getSequence());
        writer.print("&message=");
        writeFieldValue(writer, call.getMessage().descriptor(), call.getMessage());
        writer.flush();
        return counting.getByteCount();
    }

    /**
     * Serializes the message to a form-encoded string.
     *
     * @param message the message (or builder) to serialize.
     * @return the form data, as produced by {@link #serialize(OutputStream, PMessageOrBuilder)}.
     * @throws IOException if value encoding fails.
     */
    public <Message extends PMessage<Message>> String serialize(
            @Nonnull PMessageOrBuilder<Message> message) throws IOException {
        ByteArrayOutputStream out= new ByteArrayOutputStream();
        serialize(out, message);
        return out.toString(UTF_8.toString());
    }

    /**
     * Writes a single field. Non-empty list and set fields become one
     * {@code name[]=item} pair per item, separated by {@code &}; empty ones
     * become {@code name=[]}. All other fields become {@code name=value}.
     */
    private void writeField(@Nonnull PrintWriter output,
                            @Nonnull PField field,
                            @Nonnull Object value) throws IOException {
        if (field.getType() == PType.LIST ||
            field.getType() == PType.SET) {
            Collection c  = (Collection) value;
            PContainer pc = (PContainer) field.getDescriptor();
            if (c.isEmpty()) {
                // An empty container has no items to repeat, so write it as a
                // JSON empty array instead (parsed back via the JSON path).
                output.print(field.getName());
                output.print("=[]");
                return;
            } else {
                boolean first = true;
                for (Object o : c) {
                    if (first) first = false;
                    else {
                        output.print('&');
                    }
                    output.print(field.getName());
                    output.print("[]=");
                    writeFieldValue(output, pc.itemDescriptor(), o);
                }
                return;
            }
        }
        output.print(field.getName());
        output.print('=');
        writeFieldValue(output, field.getDescriptor(), value);
    }

    /**
     * Writes a single value in its form-data representation.
     * <p>
     * Void is written as {@code true}. Numbers and booleans are written
     * as-is, strings are URL-encoded, and binary is written as URL-safe
     * base64 without padding. Enums are written by their string name.
     * Containers and messages are written as URL-encoded JSON.
     */
    private void writeFieldValue(@Nonnull PrintWriter output,
                                 @Nonnull PDescriptor descriptor,
                                 @Nonnull Object value) throws IOException {
        switch (descriptor.getType()) {
            case VOID:
                output.print(true);
                break;
            case BOOL:
            case BYTE:
            case I16:
            case I32:
            case I64:
            case DOUBLE:
                output.print(value);
                break;
            case STRING:
                output.print(URLEncoder.encode(value.toString(), UTF_8.toString()));
                break;
            case BINARY:
                output.print(B64E.encodeToString(((Binary) value).get()));
                break;
            case ENUM:
                output.print(((PEnumValue) value).asString());
                break;
            case LIST:
            case SET:
            case MAP:
            case MESSAGE: {
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                JsonWriter writer = new JsonWriter(out);
                JSON.appendTypedValue(writer, descriptor, value);
                writer.flush();
                output.print(URLEncoder.encode(out.toString(UTF_8.toString()), UTF_8.toString()));
                break;
            }
        }
    }

    /**
     * Reads one line of form data into a message of the given type.
     * <p>
     * Keys and values are both URL-decoded. A key without {@code =} gets
     * the value {@code "true"}. Keys ending in {@code []} add an item to a
     * list or set field. Keys that do not match a field are ignored.
     *
     * @param input      stream to read from; consumed up to the first newline.
     * @param descriptor descriptor of the message to build.
     * @return the built message.
     * @throws IOException if reading, decoding or value parsing fails.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>> Message deserialize(
            @Nonnull InputStream input,
            @Nonnull PMessageDescriptor<Message> descriptor) throws IOException {
        PMessageBuilder<Message> builder     = descriptor.builder();
        String                   line        = IOUtils.readString(input, '\n');
        String[]                 parts       = line.split("&");
        Map<PField, Set>         buildingSet = new HashMap<>();
        for (String part : parts) {
            if (part.isEmpty()) continue;

            String[] kv = part.split("=", 2);
            // Decode the key too: HTML forms percent-encode brackets, so a
            // list item may arrive as "tags%5B%5D=x" rather than "tags[]=x".
            String key = URLDecoder.decode(kv[0], UTF_8.toString());
            String value = kv.length == 1 ? "true" : URLDecoder.decode(kv[1], UTF_8.toString());
            if (key.endsWith("[]")) {
                key = key.substring(0, key.length() - 2);
                PField field = descriptor.findFieldByName(key);
                if (field != null) {
                    if (field.getType() == PType.SET) {
                        PContainer pc = (PContainer) field.getDescriptor();
                        // Not very elegant, but preserves the ordering from the
                        // serialized set as expected.
                        buildingSet.computeIfAbsent(field, f -> new LinkedHashSet())
                                   .add(parseFieldValue(pc.itemDescriptor(), value));
                    } else if (field.getType() == PType.LIST) {
                        PContainer pc = (PContainer) field.getDescriptor();
                        builder.addTo(field.getId(), parseFieldValue(pc.itemDescriptor(), value));
                    } else {
                        LOGGER.info("Not a container type: {} for {}=...", field.getDescriptor().getQualifiedName(), kv[0]);
                    }
                }
                continue;
            }

            PField field = descriptor.findFieldByName(key);
            if (field != null) {
                builder.set(field.getId(), parseFieldValue(field.getDescriptor(), value));
            }
        }
        // Set items are collected while parsing, then applied in one go, so
        // that each set field is set once with all of its items.
        for (Map.Entry<PField, Set> set : buildingSet.entrySet()) {
            builder.set(set.getKey(), set.getValue());
        }
        return builder.build();
    }

    /**
     * Reads one line of form data as a service call on the given service.
     * <p>
     * Expects {@code method}, {@code type}, {@code seq} and {@code message}
     * keys. {@code type} defaults to {@code CALL} and {@code seq} defaults
     * to 0. The message type is the application exception for
     * {@code EXCEPTION} calls, the request type for {@code CALL} and
     * {@code ONEWAY}, and the response type otherwise.
     *
     * @param input   stream to read from; consumed up to the first newline.
     * @param service the service the call is addressed to.
     * @return the parsed service call.
     * @throws PApplicationException on unknown method ({@code UNKNOWN_METHOD}),
     *         bad call type or missing message type ({@code INVALID_MESSAGE_TYPE}),
     *         bad sequence ({@code BAD_SEQUENCE_ID}), missing method
     *         ({@code PROTOCOL_ERROR}) or missing message ({@code MISSING_RESULT}).
     * @throws IOException if reading or message parsing fails.
     */
    @Nonnull
    @Override
    @SuppressWarnings("unchecked")
    public <Message extends PMessage<Message>> PServiceCall<Message> deserialize(
            @Nonnull InputStream input,
            @Nonnull PService service) throws IOException {
        String line = IOUtils.readString(input, '\n');
        String[] parts = line.split("&");

        PServiceMethod method = null;
        int sequence = 0;
        PServiceCallType type = PServiceCallType.CALL;
        String message = null;
        for (String part : parts) {
            if (part.startsWith("method=")) {
                method = service.getMethod(part.substring(7));
                if (method == null) {
                    throw new PApplicationException("No such method " + part.substring(7), UNKNOWN_METHOD);
                }
            } else if (part.startsWith("type=")) {
                type = Optional.ofNullable(PServiceCallType.findByName(part.substring(5).toUpperCase(US)))
                               .orElseThrow(() -> new PApplicationException(
                                       "Bad call type: '" + part.substring(5) + "'", INVALID_MESSAGE_TYPE));
            } else if (part.startsWith("seq=")) {
                try {
                    sequence = Integer.parseInt(part.substring(4));
                } catch (NumberFormatException e) {
                    throw new PApplicationException("Bad sequence " + part.substring(4), BAD_SEQUENCE_ID);
                }
            } else if (part.startsWith("message=")) {
                message = URLDecoder.decode(part.substring(8), UTF_8.toString());
            }
        }

        if (method == null) {
            throw new PApplicationException("No method in request", PROTOCOL_ERROR);
        }
        if (message == null) {
            throw new PApplicationException("No message in request", MISSING_RESULT);
        }

        PMessageDescriptor<Message> md = (PMessageDescriptor<Message>) (
                type == PServiceCallType.EXCEPTION ?
                PApplicationException.kDescriptor :
                (type == PServiceCallType.CALL || type == PServiceCallType.ONEWAY) ?
                method.getRequestType() :
                method.getResponseType());
        if (md == null) {
            throw new PApplicationException("No type for " + type + " on " + method, INVALID_MESSAGE_TYPE);
        }
        return new PServiceCall<>(method.getName(), type, sequence, (Message) parseFieldValue(md, message));
    }

    /**
     * Parses an already URL-decoded value according to its descriptor.
     * <p>
     * Void always yields {@code Boolean.TRUE}. An enum is looked up by id
     * when the value is an integer, otherwise by name. The lookup may
     * presumably return null when nothing matches (not confirmed here).
     * Containers and messages are parsed as JSON.
     *
     * @throws IOException wrapping any JSON parse error.
     */
    private Object parseFieldValue(PDescriptor descriptor, String value) throws IOException {
        switch (descriptor.getType()) {
            case VOID:
                return Boolean.TRUE;
            case BOOL:
                return Boolean.parseBoolean(value);
            case BYTE:
                return Byte.parseByte(value);
            case I16:
                return Short.parseShort(value);
            case I32:
                return Integer.parseInt(value);
            case I64:
                return Long.parseLong(value);
            case DOUBLE:
                return Double.parseDouble(value);
            case STRING:
                return value;
            case BINARY:
                return Binary.fromBase64(value);
            case ENUM:
                PEnumDescriptor ed = (PEnumDescriptor) descriptor;
                if (Strings.isInteger(value)) {
                    return ed.findById(Integer.parseInt(value));
                }
                return ed.findByName(value);
            case LIST:
            case SET:
            case MAP:
            case MESSAGE:
            default:
                try {
                    StringReader  reader    = new StringReader(value);
                    JsonTokenizer tokenizer = new JsonTokenizer(reader);
                    JsonToken     first     = tokenizer.expect("anything");
                    return JSON.parseTypedValue(first, tokenizer, descriptor, false);
                } catch (JsonException e) {
                    throw new IOException(e.getMessage(), e);
                }
        }
    }

    /** Form data is text, so this is not a binary protocol. */
    @Override
    public boolean binaryProtocol() {
        return false;
    }

    /**
     * Verifies that only whitespace remains after the parsed content.
     *
     * @param input the stream positioned after the deserialized line.
     * @throws IOException if any non-whitespace content remains.
     */
    @Override
    public void verifyEndOfContent(@Nonnull InputStream input) throws IOException {
        String content = IOUtils.readString(input).trim();
        if (!content.isEmpty()) {
            throw new IOException("After end of url-encoded content: '" + content + "'");
        }
    }

    /** @return {@value #MEDIA_TYPE}. */
    @Nonnull
    @Override
    public String mediaType() {
        return MEDIA_TYPE;
    }
}