ProvidenceHttpServletWrapper.java

/*
 * Copyright 2016-2017 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.server;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PException;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PProcessor;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallInstrumentation;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.PUnion;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.descriptor.PServiceMethod;
import net.morimekta.providence.serializer.DefaultSerializerProvider;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.providence.serializer.SerializerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.activation.MimeType;
import javax.activation.MimeTypeParseException;
import javax.annotation.Nonnull;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static net.morimekta.providence.PApplicationExceptionType.INVALID_PROTOCOL;

/**
 * An HTTP POST servlet wrapper around a service using sub-path routing
 * to each service method. The request body is the method params message
 * generated for the method. If the method's response is a message, that
 * message is used as the response body; otherwise the whole generated
 * response message is used.
 * <p>
 * The servlet must be registered with a wildcard path ending, so it can
 * capture sub-paths to be used to determine which service method to route
 * to.
 *
 * <pre>{@code
 * class MyMain {
 *     public static void main(String... args) throws Exception {
 *         ProcessorProvider processorProvider = createProcessorProvider(); // application specific
 *         Server server = new Server(8080);
 *         ServletContextHandler handler = new ServletContextHandler();
 *         handler.addServlet(
 *                 new ServletHolder(new ProvidenceHttpServletWrapper(
 *                         MyService.kDescriptor, processorProvider, null, null, null)),
 *                 "/foo/*");
 *         server.setHandler(handler);
 *         server.start();
 *         server.join();
 *     }
 * }
 * }</pre>
 *
 * <p>
 * Note that this is not a full server-side handling of a thrift service,
 * as the interface may hide and obscure exceptions, so a matching client
 * does not exist. See HTTP helpers in <code>providence-core-client</code>
 * as replacement.
 *
 * <pre>{@code
 * class OtherMain {
 *     public static void main(String... args) throws IOException {
 *         HttpRequestFactory fac = new NetHttpTransport().createRequestFactory();
 *         MyResponse response = fac.buildPostRequest(
 *                 new GenericUrl("http://localhost:8080/foo/method"),
 *                 new ProvidenceHttpContent(MyService.Method$Request
 *                         .builder()
 *                         .addToArgs(args)
 *                         .build(), JsonSerializer.INSTANCE))
 *                 .setParser(new ProvidenceObjectParser(JsonSerializer.INSTANCE))
 *                 .execute()
 *                 .parseAs(MyResponse.class);
 *         System.out.println(PrettySerializer.toDebugString(response));
 *     }
 * }
 * }</pre>
 *
 * @since 2.0
 */
public class ProvidenceHttpServletWrapper extends HttpServlet {
    private static final Logger LOGGER = LoggerFactory.getLogger(ProvidenceHttpServletWrapper.class);

    private final SerializerProvider          serializerProvider;
    private final ProcessorProvider           processorProvider;
    private final Map<String, PServiceMethod> mapping;
    private final ExceptionHandler            exceptionHandler;
    private final PServiceCallInstrumentation instrumentation;

    /**
     * Create a servlet wrapper exposing every method of the service, each on
     * the sub-path {@code "/" + method.getName()}.
     *
     * @param service            The service whose methods should be routed to.
     * @param processorProvider  Provides the processor that handles each request.
     * @param exceptionHandler   Exception handler, or null to use {@link ExceptionHandler#INSTANCE}.
     * @param serializerProvider Serializer provider, or null to use {@link DefaultSerializerProvider#INSTANCE}.
     * @param instrumentation    Call instrumentation, or null to use {@link PServiceCallInstrumentation#NOOP}.
     */
    public ProvidenceHttpServletWrapper(@Nonnull PService service,
                                        @Nonnull ProcessorProvider processorProvider,
                                        ExceptionHandler exceptionHandler,
                                        SerializerProvider serializerProvider,
                                        PServiceCallInstrumentation instrumentation) {
        this(processorProvider, createMapping(service), exceptionHandler, serializerProvider, instrumentation);
    }

    /**
     * Create a servlet wrapper with an explicit routing table.
     *
     * @param processorProvider  Provides the processor that handles each request.
     * @param mapping            Map from sub-path (e.g. {@code "/method"}, without trailing slash)
     *                           to the service method it routes to. The map is used by reference,
     *                           not copied.
     * @param exceptionHandler   Exception handler, or null to use {@link ExceptionHandler#INSTANCE}.
     * @param serializerProvider Serializer provider, or null to use {@link DefaultSerializerProvider#INSTANCE}.
     * @param instrumentation    Call instrumentation, or null to use {@link PServiceCallInstrumentation#NOOP}.
     */
    public ProvidenceHttpServletWrapper(@Nonnull ProcessorProvider processorProvider,
                                        @Nonnull Map<String, PServiceMethod> mapping,
                                        ExceptionHandler exceptionHandler,
                                        SerializerProvider serializerProvider,
                                        PServiceCallInstrumentation instrumentation) {
        this.processorProvider = processorProvider;
        this.mapping = mapping;
        this.exceptionHandler = exceptionHandler == null ? ExceptionHandler.INSTANCE : exceptionHandler;
        this.serializerProvider = serializerProvider == null ? DefaultSerializerProvider.INSTANCE : serializerProvider;
        this.instrumentation = instrumentation == null ? PServiceCallInstrumentation.NOOP : instrumentation;
    }

    /**
     * Handle a single service call posted to {@code <servlet-path>/<method-name>}.
     * <p>
     * The request body is read with the serializer matching the base type of the
     * request content type, or the provider default when no content type is given.
     * The response is written with the serializer for the first supported entry in
     * the {@code Accept} header, falling back to the request serializer. Unknown
     * methods give 404, unknown content types give 400, and undecodable requests
     * and failed calls are handed to the {@link ExceptionHandler}.
     *
     * @param httpRequest  The HTTP request.
     * @param httpResponse The HTTP response.
     * @throws IOException If reading the request or writing the response failed.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected final void doPost(HttpServletRequest httpRequest, HttpServletResponse httpResponse) throws IOException {
        long start = System.nanoTime();
        try {
            String path = normalizePath(httpRequest.getPathInfo());
            PServiceMethod method = mapping.get(path);
            if (method == null) {
                LOGGER.debug("No servlet for translated path {}", path);
                httpResponse.sendError(HttpServletResponse.SC_NOT_FOUND);
                reportTransportException(
                        new PApplicationException("No servlet for translated path " + path,
                                                  PApplicationExceptionType.UNKNOWN_METHOD),
                        start);
                return;
            }

            PMessage response;
            Serializer requestSerializer = serializerProvider.getDefault();
            String contentType = httpRequest.getContentType();
            if (contentType != null) {
                try {
                    // Look up by base type, so parameters like "; charset=utf-8" do not
                    // prevent a match (same normalization as the Accept header below).
                    requestSerializer = serializerProvider.getSerializer(baseMediaType(contentType));
                } catch (IllegalArgumentException e) {
                    httpResponse.sendError(HttpServletResponse.SC_BAD_REQUEST, "Unknown content-type: " + contentType);
                    LOGGER.warn("Unknown content type in request", e);
                    reportTransportException(
                            new PApplicationException("Unknown content-type: " + contentType, INVALID_PROTOCOL),
                            start);
                    return;
                }
            } else {
                LOGGER.debug("Request is missing content type.");
            }

            Serializer responseSerializer = selectResponseSerializer(httpRequest.getHeader("Accept"), requestSerializer);

            PMessage request;
            try {
                request = requestSerializer.deserialize(httpRequest.getInputStream(), method.getRequestType());
                requestSerializer.verifyEndOfContent(httpRequest.getInputStream());
            } catch (SerializerException e) {
                LOGGER.info("Failed to deserialize request to {}: {}",
                            httpRequest.getServletPath() + path, e.displayString(), e);

                PApplicationException ex = new PApplicationException(e.getMessage(), INVALID_PROTOCOL).initCause(e);
                exceptionHandler.handleException(
                        exceptionHandler.getResponseException(ex),
                        responseSerializer, httpRequest, httpResponse);
                reportTransportException(ex, start);
                return;
            }

            PServiceCall<?> callToHandle = new PServiceCall<>(method.getName(), PServiceCallType.CALL, 1, request);
            PServiceCall<?> handledCall = null;
            PProcessor processor = processorProvider.processorForRequest(httpRequest);
            try {
                handledCall = processor.handleCall(callToHandle);
                if (handledCall.getType() == PServiceCallType.EXCEPTION) {
                    throw (PApplicationException) handledCall.getMessage();
                } else if (handledCall.getMessage().has(0)) {
                    // Field 0 is the success value; unwrap it when it is a message.
                    Object responseMessage = handledCall.getMessage().get(0);
                    if (responseMessage instanceof PMessage) {
                        response = (PMessage) responseMessage;
                    } else {
                        response = handledCall.getMessage();
                    }
                } else {
                    // No success value: the response union holds a declared exception.
                    PUnion responseUnion = (PUnion) handledCall.getMessage();
                    PException ex = (PException) responseUnion.get(responseUnion.unionField());
                    throw (Exception) ex;
                }
            } catch (Exception e) {
                try {
                    exceptionHandler.handleException(
                            exceptionHandler.getResponseException(e),
                            responseSerializer, httpRequest, httpResponse);
                    if (!httpResponse.isCommitted()) {
                        httpResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e.getMessage());
                    }
                } catch (Exception e1) {
                    LOGGER.error("Exception sending error", e1);
                    if (!httpResponse.isCommitted()) {
                        httpResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, e1.getMessage());
                    }
                }
                reportComplete(start, callToHandle, handledCall);
                return;
            }

            httpResponse.setStatus(HttpServletResponse.SC_OK);
            httpResponse.setContentType(responseSerializer.mediaType());
            responseSerializer.serialize(httpResponse.getOutputStream(), response);
            httpResponse.flushBuffer();
            reportComplete(start, callToHandle, handledCall);
        } catch (IOException e) {
            reportTransportException(e, start);
            throw e;
        } catch (Exception e) {
            LOGGER.warn("Unhandled exception in {}", httpRequest.getPathInfo(), e);
            reportTransportException(e, start);
            throw e;
        }
    }

    /**
     * Turn servlet path-info into a mapping key: null becomes "/", and one
     * trailing slash is dropped from anything longer than "/".
     */
    private static String normalizePath(String pathInfo) {
        if (pathInfo == null) {
            return "/";
        }
        if (pathInfo.length() > 1 && pathInfo.endsWith("/")) {
            return pathInfo.substring(0, pathInfo.length() - 1);
        }
        return pathInfo;
    }

    /**
     * Strip parameters from a media type (e.g. "application/json; charset=utf-8"
     * becomes "application/json"). Unparseable values are returned unchanged so
     * the serializer provider makes the final decision.
     */
    private static String baseMediaType(String mediaType) {
        try {
            return new MimeType(mediaType).getBaseType();
        } catch (MimeTypeParseException e) {
            return mediaType;
        }
    }

    /**
     * Pick the response serializer from the Accept header, in header order.
     * A wildcard entry, a missing header, or a header with no supported entry
     * all select the request serializer. Note: q-values are not weighed.
     */
    private Serializer selectResponseSerializer(String acceptHeader, Serializer requestSerializer) {
        if (acceptHeader == null) {
            return requestSerializer;
        }
        for (String entry : acceptHeader.split(",")) {
            entry = entry.trim();
            if (entry.isEmpty()) {
                continue;
            }
            if ("*/*".equals(entry)) {
                // Then responding same as request is good.
                return requestSerializer;
            }
            try {
                String baseType = new MimeType(entry).getBaseType();
                if ("*/*".equals(baseType)) {
                    // Wildcard with parameters, e.g. "*/*;q=0.8".
                    return requestSerializer;
                }
                return serializerProvider.getSerializer(baseType);
            } catch (MimeTypeParseException | IllegalArgumentException ignore) {
                // Malformed or unsupported entries (e.g. browsers asking for text/html)
                // are common; try the next entry instead of failing the request.
            }
        }
        return requestSerializer;
    }

    /** Report a transport failure to instrumentation; instrumentation errors never fail the request. */
    private void reportTransportException(Exception e, long start) {
        try {
            instrumentation.onTransportException(e, runTime(start), null, null);
        } catch (Exception e1) {
            LOGGER.warn("Exception in instrumentation onTransportException", e1);
        }
    }

    /** Report a completed call to instrumentation; instrumentation errors never fail the request. */
    private void reportComplete(long start, PServiceCall<?> call, PServiceCall<?> response) {
        try {
            instrumentation.onComplete(runTime(start), call, response);
        } catch (Exception e1) {
            LOGGER.warn("Exception in instrumentation onComplete", e1);
        }
    }

    /** Milliseconds elapsed since {@code startNano} (a {@link System#nanoTime()} value). */
    private static double runTime(long startNano) {
        long endNano = System.nanoTime();
        return ((double) (endNano - startNano)) / PServiceCallInstrumentation.NS_IN_MILLIS;
    }

    /** Map each service method to the sub-path {@code "/" + method name}. */
    private static Map<String, PServiceMethod> createMapping(PService service) {
        Map<String, PServiceMethod> mapping = new HashMap<>();
        for (PServiceMethod method : service.getMethods()) {
            mapping.put("/" + method.getName(), method);
        }
        return mapping;
    }
}