SocketServer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.thrift.server;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PProcessor;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallInstrumentation;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.util.concurrent.NamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import static net.morimekta.providence.PServiceCallInstrumentation.NS_IN_MILLIS;

/**
 * Based heavily on <code>org.apache.thrift.transport.TServerTransport</code>
 * and meant to be a providence replacement for it.
 * <p>
 * The server binds its {@link ServerSocket} when it is created and runs both
 * the accept loop and the per-connection processing on one fixed size thread
 * pool (the configured number of worker threads plus one). Each accepted
 * connection is served by {@link #process(long, Socket)} until the client
 * closes it or a transport error occurs.
 * <p>
 * Instances are created through {@link #builder(PProcessor)} and must be
 * closed with {@link #close()} to release the port and stop the threads.
 */
public class SocketServer implements AutoCloseable {
    /**
     * Collects the settings for a {@link SocketServer}. All setters return
     * the builder itself so calls can be chained; {@link #start()} creates
     * the running server.
     */
    public static class Builder {
        private final PProcessor                  processor;
        private       PServiceCallInstrumentation instrumentation;
        private       InetSocketAddress           bindAddress;

        private int clientTimeout       = 60000;  // 60 seconds
        private int backlog             = 50;
        private int workerThreads       = 10;
        private ThreadFactory workerThreadFactory;
        private Serializer serializer;

        /**
         * Makes a builder that by default binds to port 0 and uses a
         * {@link BinarySerializer}.
         *
         * @param processor The processor handling the service calls.
         */
        public Builder(@Nonnull PProcessor processor) {
            this.processor = processor;
            this.bindAddress = new InetSocketAddress(0);
            this.serializer = new BinarySerializer();
        }

        /**
         * Bind to the given port on the wildcard address.
         *
         * @param port The port to listen on. Must not be negative.
         * @return The builder.
         * @throws IllegalArgumentException If the port is negative.
         */
        public Builder withPort(int port) {
            if (port < 0) {
                throw new IllegalArgumentException("Negative port: " + port);
            }
            return withBindAddress(new InetSocketAddress(port));
        }

        /**
         * @param bindAddress The address the server socket is bound to.
         * @return The builder.
         */
        public Builder withBindAddress(@Nonnull InetSocketAddress bindAddress) {
            this.bindAddress = bindAddress;
            return this;
        }

        /**
         * @param maxBacklog The backlog value passed on when binding the
         *                   server socket. Must not be negative.
         * @return The builder.
         * @throws IllegalArgumentException If the backlog is negative.
         */
        public Builder withMaxBacklog(int maxBacklog) {
            if (maxBacklog < 0) {
                throw new IllegalArgumentException("Negative max backlog: " + maxBacklog);
            }
            this.backlog = maxBacklog;
            return this;
        }

        /**
         * @param instrumentation Instrumentation notified after each handled
         *                        call and on transport failures.
         * @return The builder.
         */
        public Builder withInstrumentation(@Nonnull PServiceCallInstrumentation instrumentation) {
            this.instrumentation = instrumentation;
            return this;
        }

        /**
         * @param timeoutInMs Read timeout set on each accepted client socket,
         *                    in milliseconds. Must be at least 1.
         * @return The builder.
         * @throws IllegalArgumentException If the timeout is less than 1.
         */
        public Builder withClientTimeout(int timeoutInMs) {
            if (timeoutInMs < 1) {
                throw new IllegalArgumentException("Non-positive client timeout: " + timeoutInMs);
            }
            this.clientTimeout = timeoutInMs;
            return this;
        }

        /**
         * @param numThreads Number of worker threads. The thread pool gets
         *                   one extra thread on top of this number. Must be
         *                   at least 1.
         * @return The builder.
         * @throws IllegalArgumentException If the number is less than 1.
         */
        public Builder withWorkerThreads(int numThreads) {
            if (numThreads < 1) {
                throw new IllegalArgumentException("Non-positive number of worker threads: " + numThreads);
            }

            this.workerThreads = numThreads;
            return this;
        }

        /**
         * @param factory Thread factory for the server's thread pool. If
         *                null when {@link #start()} is called, a default
         *                factory is used.
         * @return The builder.
         */
        public Builder withThreadFactory(ThreadFactory factory) {
            this.workerThreadFactory = factory;
            return this;
        }

        /**
         * @param serializer Serializer used both for reading requests and
         *                   writing responses.
         * @return The builder.
         */
        public Builder withSerializer(Serializer serializer) {
            this.serializer = serializer;
            return this;
        }

        /**
         * Creates the server. The server socket is bound and the accept loop
         * is scheduled before this method returns.
         *
         * @return The running server.
         * @throws UncheckedIOException If the server socket could not be
         *                              set up or bound.
         */
        public SocketServer start() {
            if (workerThreadFactory == null) {
                // No factory given: fall back to named, non-daemon threads.
                this.workerThreadFactory = NamedThreadFactory.builder()
                        .setNameFormat("providence-server-%d")
                        .setDaemon(false)
                        .build();
            }

            return new SocketServer(this);
        }

    }

    /**
     * @param processor The processor handling the service calls.
     * @return A new builder for a socket server using the processor.
     */
    public static Builder builder(@Nonnull PProcessor processor) {
        return new Builder(processor);
    }

    /**
     * @return The local port of the server socket, or -1 once the server
     *         has been closed.
     */
    public int getPort() {
        if (workerExecutor.isShutdown()) {
            return -1;
        }
        return serverSocket.getLocalPort();
    }

    /**
     * Stops the server: no new tasks are accepted, the server socket is
     * closed, and after a short grace period the thread pool is forcibly
     * shut down.
     */
    @Override
    public void close() {
        workerExecutor.shutdown();
        try {
            // this should trigger exception in the accept task.
            serverSocket.close();
        } catch (IOException e) {
            LOGGER.warn("Failed to close server socket", e);
        } finally {
            try {
                workerExecutor.awaitTermination(10, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                // Keep the interrupt visible to the caller; we still go on
                // to force the shutdown below.
                Thread.currentThread().interrupt();
            }
            // really really try to kill it now.
            workerExecutor.shutdownNow();
        }
    }

    private static final Logger LOGGER = LoggerFactory.getLogger(SocketServer.class);

    private final int                         clientTimeout;
    private final PProcessor                  processor;
    private final PServiceCallInstrumentation instrumentation;
    private final ServerSocket                serverSocket;
    private final ExecutorService             workerExecutor;
    private final Serializer                  serializer;

    /**
     * Binds the server socket and schedules the first accept task.
     *
     * @param builder The builder holding the server settings.
     * @throws UncheckedIOException If the server socket could not be set up
     *                              or bound.
     */
    private SocketServer(Builder builder) {
        clientTimeout = builder.clientTimeout;
        processor = builder.processor;
        instrumentation = builder.instrumentation != null
                          ? builder.instrumentation
                          : PServiceCallInstrumentation.NOOP;
        serializer = builder.serializer;

        ServerSocket boundSocket = null;
        try {
            // Make server socket.
            boundSocket = new ServerSocket();
            // Prevent 2MSL delay problem on server restarts
            boundSocket.setReuseAddress(true);
            // Bind to listening port
            boundSocket.bind(builder.bindAddress, builder.backlog);
            boundSocket.setSoTimeout(0);
        } catch (IOException e) {
            // Do not leave a half set-up socket open when construction fails.
            closeQuietly(boundSocket);
            throw new UncheckedIOException(e);
        }
        serverSocket = boundSocket;

        workerExecutor = Executors.newFixedThreadPool(builder.workerThreads + 1,
                                                      builder.workerThreadFactory);
        // Must stay the last statement: the accept task reads the fields
        // assigned above.
        workerExecutor.submit(this::accept);
    }

    /**
     * Accepts a single connection, hands it over to a {@link #process}
     * task, and then re-submits itself so the next connection can be
     * accepted. The loop ends when an accept fails, which is expected once
     * {@link #close()} has closed the server socket.
     */
    private void accept() {
        try {
            Socket socket = serverSocket.accept();
            try {
                socket.setSoTimeout(clientTimeout);
                long startTime = System.nanoTime();
                workerExecutor.submit(() -> process(startTime, socket));
            } catch (IOException | RuntimeException e) {
                // The socket was never handed over to a process task (e.g.
                // the executor rejected it during shutdown), so nobody else
                // is going to close it.
                closeQuietly(socket);
                throw e;
            }
        } catch (SocketTimeoutException e) {
            // ignore.
        } catch (IOException e) {
            if (workerExecutor.isShutdown()) {
                return;
            }
            // The exception thrown below ends up in a Future nobody reads,
            // so log it here to make the end of the accept loop visible.
            LOGGER.error("Failed to accept connection, no more connections will be accepted", e);
            throw new UncheckedIOException(e);
        } catch (Exception e) {
            if (workerExecutor.isShutdown()) {
                // Lost the race against close(); nothing more to do.
                return;
            }
            LOGGER.error("Unexpected exception in accept task, no more connections will be accepted", e);
            throw new IllegalStateException(e);
        }
        workerExecutor.submit(this::accept);
    }

    /**
     * Serves one client connection: reads service calls, passes them to the
     * processor and writes the responses back, until the client closes the
     * connection or a transport error occurs. The socket and its streams are
     * closed when the method exits.
     *
     * @param startTime {@link System#nanoTime()} value from when the
     *                  connection was accepted. Used as the start of the
     *                  first call's duration; reset after each call.
     * @param socket    The accepted client socket.
     * @throws UncheckedIOException On transport errors.
     */
    @SuppressWarnings("unchecked")
    private void process(long startTime, @Nonnull Socket socket) {
        try (Socket ignore = socket;
             BufferedInputStream in = new BufferedInputStream(socket.getInputStream());
             BufferedOutputStream out = new BufferedOutputStream(socket.getOutputStream())) {
            // The loop is left through the end-of-stream probe at the bottom
            // or through an exception.
            while (socket.isConnected()) {
                PServiceCall request = null;
                PServiceCall response = null;
                try {
                    try {
                        request = serializer.deserialize(in, processor.getDescriptor());
                        if (request.getType() == PServiceCallType.REPLY ||
                            request.getType() == PServiceCallType.EXCEPTION) {
                            // Clients must not send replies or exceptions as
                            // requests; answer with an application exception.
                            PApplicationException ex = new PApplicationException(
                                    "Invalid service request call type: " + request.getType(),
                                    PApplicationExceptionType.INVALID_MESSAGE_TYPE);
                            response = new PServiceCall<>(request.getMethod(), PServiceCallType.EXCEPTION, request.getSequence(), ex);
                        } else {
                            response = processor.handleCall(request);
                        }
                    } catch (SerializerException e) {
                        if (e.getMethodName() != null) {
                            LOGGER.error("Error when reading service call " + processor.getDescriptor().getName() + "." + e.getMethodName() + "()", e);
                        } else {
                            LOGGER.error("Error when reading service call " + processor.getDescriptor().getName(), e);
                        }
                        // Report the read failure back to the client.
                        PApplicationException ex = new PApplicationException(e.getMessage(), e.getExceptionType()).initCause(e);
                        response = new PServiceCall<>(e.getMethodName(), PServiceCallType.EXCEPTION, e.getSequenceNo(), ex);
                    }

                    // No response means there is nothing to send back.
                    if (response != null) {
                        serializer.serialize(out, response);
                        out.flush();
                    }

                    long endTime = System.nanoTime();
                    double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
                    try {
                        instrumentation.onComplete(duration, request, response);
                    } catch (Throwable th) {
                        // Instrumentation must never break the call handling.
                        LOGGER.error("Exception in service instrumentation", th);
                    }
                } catch (IOException e) {
                    long endTime = System.nanoTime();
                    double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
                    try {
                        instrumentation.onTransportException(e, duration, request, response);
                    } catch (Throwable th) {
                        LOGGER.error("Exception in service instrumentation", th);
                    }

                    throw new UncheckedIOException(e.getMessage(), e);
                }

                // Probe for end of stream: wait for the first byte of the
                // next call, and push it back if there is one.
                in.mark(1);
                if (in.read() < 0) {
                    return;
                }
                in.reset();

                startTime = System.nanoTime();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e.getMessage(), e);
        }
    }

    /**
     * Closes the resource, if any, on a failure path. Exceptions from the
     * close are only logged so they do not hide the original failure.
     *
     * @param closeable The resource to close, may be null.
     */
    private static void closeQuietly(AutoCloseable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (Exception e) {
            LOGGER.debug("Failed to close resource after error", e);
        }
    }
}