// NonblockingSocketServer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.thrift.server;

import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PProcessor;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallInstrumentation;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.providence.thrift.io.FramedBufferOutputStream;
import net.morimekta.util.concurrent.NamedThreadFactory;
import net.morimekta.util.io.BigEndianBinaryReader;
import net.morimekta.util.io.ByteBufferInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import static net.morimekta.providence.PApplicationExceptionType.INTERNAL_ERROR;

/**
 * Based heavily on <code>org.apache.thrift.transport.TNonblockingServerTransport</code>
 * and meant to be a providence replacement for it.
 */
public class NonblockingSocketServer implements AutoCloseable {
    /**
     * Fluent builder that collects the configuration for a
     * {@link NonblockingSocketServer} and starts it via {@link #start()}.
     */
    public static class Builder {
        private final PProcessor                  processor;
        private       PServiceCallInstrumentation instrumentation;
        private       InetSocketAddress           bindAddress;

        private int maxFrameSizeInBytes = 16384000; // 16,384,000 bytes (roughly 16 MB).
        private int readTimeoutInMs     = 60000;    // 60 seconds
        private int backlog             = 50;
        private int workerThreads       = 10;

        private ThreadFactory receiverThreadFactory;
        private ThreadFactory workerThreadFactory;
        private Serializer    serializer;

        /**
         * Creates a builder with defaults: bind to port 0 (system-chosen port),
         * a {@link BinarySerializer}, and one shared daemon thread factory for
         * both receiver and worker threads.
         *
         * @param processor The processor that handles incoming service calls.
         */
        public Builder(@Nonnull PProcessor processor) {
            this.processor = processor;
            this.bindAddress = new InetSocketAddress(0);
            this.workerThreadFactory = NamedThreadFactory.builder()
                    .setNameFormat("providence-nonblocking-server-%d")
                    .setDaemon(true)
                    .build();
            this.receiverThreadFactory = workerThreadFactory;
            this.serializer = new BinarySerializer();
        }

        /**
         * Bind to the given port on the wildcard address.
         *
         * @param port The port to listen on; 0 lets the system choose.
         * @return The builder.
         * @throws IllegalArgumentException If the port is negative.
         */
        public Builder withPort(int port) {
            if (port < 0) {
                throw new IllegalArgumentException("Negative port: " + port);
            }
            return withBindAddress(new InetSocketAddress(port));
        }

        /**
         * @param bindAddress The socket address to bind the server socket to.
         * @return The builder.
         */
        public Builder withBindAddress(@Nonnull InetSocketAddress bindAddress) {
            this.bindAddress = bindAddress;
            return this;
        }

        /**
         * @param maxBacklog The backlog passed on when binding the server socket.
         * @return The builder.
         * @throws IllegalArgumentException If the backlog is negative.
         */
        public Builder withMaxBacklog(int maxBacklog) {
            if (maxBacklog < 0) {
                throw new IllegalArgumentException("Negative backlog: " + maxBacklog);
            }
            this.backlog = maxBacklog;
            return this;
        }

        /**
         * @param size The largest accepted frame size in bytes. Each connection
         *             gets a read buffer of this size.
         * @return The builder.
         * @throws IllegalArgumentException If the size is below 1024 bytes.
         */
        public Builder withMaxFrameSizeInBytes(int size) {
            if (size < 1024) {
                throw new IllegalArgumentException("Max frame size must be at least 1024 bytes: " + size);
            }
            this.maxFrameSizeInBytes = size;
            return this;
        }

        /**
         * @param instrumentation Receiver of call timing and transport failures.
         * @return The builder.
         */
        public Builder withInstrumentation(@Nonnull PServiceCallInstrumentation instrumentation) {
            this.instrumentation = instrumentation;
            return this;
        }

        /**
         * @param timeoutInMs Timeout in milliseconds, set as the socket timeout on
         *                    the server socket when the server is started.
         * @return The builder.
         * @throws IllegalArgumentException If the timeout is below 1 ms.
         */
        public Builder withReadTimeout(int timeoutInMs) {
            if (timeoutInMs < 1) {
                throw new IllegalArgumentException("Read timeout must be at least 1 ms: " + timeoutInMs);
            }
            this.readTimeoutInMs = timeoutInMs;
            return this;
        }

        /**
         * @param numThreads Number of threads in the pool that runs the processor.
         * @return The builder.
         * @throws IllegalArgumentException If fewer than 1 thread is requested.
         */
        public Builder withWorkerThreads(int numThreads) {
            if (numThreads < 1) {
                throw new IllegalArgumentException("Number of worker threads must be at least 1: " + numThreads);
            }

            this.workerThreads = numThreads;
            return this;
        }

        /**
         * @param factory Thread factory for the worker pool.
         * @return The builder.
         * @throws NullPointerException If the factory is null.
         */
        public Builder withWorkerThreadFactory(ThreadFactory factory) {
            // Fail here rather than while the server is half way started.
            if (factory == null) {
                throw new NullPointerException("factory");
            }
            this.workerThreadFactory = factory;
            return this;
        }

        /**
         * @param factory Thread factory for the receiver (select / cleanup) pool.
         * @return The builder.
         * @throws NullPointerException If the factory is null.
         */
        public Builder withReceiverThreadFactory(ThreadFactory factory) {
            // Fail here rather than while the server is half way started.
            if (factory == null) {
                throw new NullPointerException("factory");
            }
            this.receiverThreadFactory = factory;
            return this;
        }

        /**
         * @param serializer Serializer used both for reading calls and writing replies.
         * @return The builder.
         * @throws NullPointerException If the serializer is null.
         */
        public Builder withSerializer(Serializer serializer) {
            if (serializer == null) {
                throw new NullPointerException("serializer");
            }
            this.serializer = serializer;
            return this;
        }

        /**
         * Create the server from the current configuration. The server
         * constructor binds the socket and submits its background loops.
         *
         * @return The started server.
         * @throws UncheckedIOException If opening or binding the socket fails.
         */
        public NonblockingSocketServer start() {
            return new NonblockingSocketServer(this);
        }
    }

    /**
     * Convenience factory for a {@link Builder} bound to the given processor.
     *
     * @param processor The processor that will handle incoming calls.
     * @return A new builder with default settings.
     */
    public static Builder builder(@Nonnull PProcessor processor) {
        final Builder fresh = new Builder(processor);
        return fresh;
    }

    /**
     * @return The local port of the server socket, or -1 once the receiver
     *         executor has been shut down.
     */
    public int getPort() {
        return receiverExecutor.isShutdown()
               ? -1
               : serverSocket.getLocalPort();
    }

    /**
     * Stops the server: shuts down both executors, closes the listening
     * socket, and then closes every channel still registered with the
     * selector before closing the selector itself.
     * <p>
     * The method may be called more than once and from several threads (the
     * cleanup loop calls it when interrupted); calls are serialized, and a
     * repeated call finds everything already closed.
     */
    @Override
    public synchronized void close() {
        receiverExecutor.shutdown();
        workerExecutor.shutdown();

        boolean interrupted = false;
        try {
            // The background loops stop once the server channel is no longer open.
            serverSocket.close();
        } catch (IOException e) {
            LOGGER.warn("Unable to close server socket: {}", e.getMessage(), e);
        } finally {
            // Make a thread blocked in select() return so it re-checks its loop condition.
            selector.wakeup();

            try {
                workerExecutor.awaitTermination(10, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                interrupted = true;
            }
            try {
                receiverExecutor.awaitTermination(10, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                interrupted = true;
            }
            // really really try to kill it now.
            receiverExecutor.shutdownNow();
            workerExecutor.shutdownNow();

            // Closing the selector only deregisters channels, it does not close
            // them, so the accepted connections must be closed explicitly.
            if (selector.isOpen()) {
                try {
                    for (SelectionKey registered : selector.keys()) {
                        try {
                            registered.channel().close();
                        } catch (IOException e) {
                            LOGGER.warn("Unable to close channel", e);
                        }
                    }
                } catch (RuntimeException e) {
                    // E.g. the key set changed under us because a loop is still winding down.
                    LOGGER.warn("Unable to close all client channels: {}", e.getMessage(), e);
                }
                try {
                    selector.close();
                } catch (IOException e) {
                    LOGGER.warn("Unable to close selector: {}", e.getMessage(), e);
                }
            }

            if (interrupted) {
                // Preserve the interrupt for the caller instead of swallowing it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Logger for the server and its per-connection contexts. */
    private static final Logger LOGGER       = LoggerFactory.getLogger(NonblockingSocketServer.class);
    /** Divisor used to turn {@link System#nanoTime()} differences into the durations given to instrumentation. */
    private static final long   NS_IN_MILLIS = PServiceCallInstrumentation.NS_IN_MILLIS;

    /** Selector that the server channel and every accepted client channel is registered with. */
    private final Selector                    selector;
    /** Handles the deserialized service calls (on the worker executor). */
    private final PProcessor                  processor;
    /** Reads incoming calls and writes replies. */
    private final Serializer                  serializer;
    /** Notified about completed calls and transport failures. */
    private final PServiceCallInstrumentation instrumentation;
    /** The non-blocking listening channel. */
    private final ServerSocketChannel         serverSocketChannel;
    /** Socket view of {@link #serverSocketChannel}. */
    private final ServerSocket                serverSocket;
    /** Runs the select loop and the cleanup loop. */
    private final ExecutorService             receiverExecutor;
    /** Runs the processor for each received call. */
    private final ExecutorService             workerExecutor;
    /** Largest frame accepted from a client; also the per-connection buffer size. */
    private final int                         maxFrameSizeInBytes;

    private NonblockingSocketServer(Builder builder) {
        try {
            maxFrameSizeInBytes = builder.maxFrameSizeInBytes;

            serializer = builder.serializer;
            processor = builder.processor;
            instrumentation = builder.instrumentation != null
                              ? builder.instrumentation
                              : PServiceCallInstrumentation.NOOP;
            selector = Selector.open();

            serverSocketChannel = ServerSocketChannel.open();
            serverSocketChannel.configureBlocking(false);

            // Make server socket
            serverSocket = serverSocketChannel.socket();
            serverSocketChannel.socket().setSoTimeout(builder.readTimeoutInMs);

            // Prevent 2MSL delay problem on server restarts
            serverSocket.setReuseAddress(true);
            // Bind to listening port
            serverSocket.bind(builder.bindAddress, builder.backlog);

            // Needs one thread for each receiver, and one for each response writer.
            receiverExecutor = Executors.newFixedThreadPool(2, builder.receiverThreadFactory);
            workerExecutor = Executors.newFixedThreadPool(builder.workerThreads, builder.workerThreadFactory);

            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);

            receiverExecutor.submit(this::selectLoop);
            receiverExecutor.submit(this::closeLoop);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Main I/O loop: waits on the selector and dispatches each selected key
     * to {@link #accept()}, {@code handleRead} or {@code handleWrite}. Runs
     * until the server channel or the selector is closed.
     * <p>
     * A failure while handling one key is logged and does not stop the loop.
     * The loop is submitted to an executor whose result is never inspected,
     * so an escaping exception would silently end all I/O for the server.
     */
    private void selectLoop() {
        while (serverSocketChannel.isOpen() && selector.isOpen()) {
            try {
                selector.select();
                Iterator<SelectionKey> selectedKeys = selector.selectedKeys()
                                                              .iterator();
                while (selectedKeys.hasNext()) {
                    SelectionKey key = selectedKeys.next();
                    // The selector never removes keys from the selected set itself.
                    // Remove it up front so a failing key is not replayed with stale
                    // readiness; a still-ready channel is selected again anyway.
                    selectedKeys.remove();

                    try {
                        if (!key.isValid()) {
                            continue;
                        }

                        if (key.isAcceptable()) {
                            accept();
                        } else if (key.isReadable()) {
                            handleRead(key, (Context) key.attachment());
                        } else if (key.isWritable()) {
                            handleWrite(key, (Context) key.attachment());
                        }
                    } catch (IOException e) {
                        LOGGER.error("Exception in thread: " + e.getMessage(), e);
                    } catch (RuntimeException e) {
                        // E.g. the key was cancelled by another thread between the
                        // validity check and its use.
                        LOGGER.warn("Exception handling selected key: {}", e.getMessage(), e);
                    }
                }
            } catch (IOException e) {
                LOGGER.error("Exception in thread: " + e.getMessage(), e);
            } catch (RuntimeException e) {
                if (!selector.isOpen()) {
                    // The selector was closed under us: the server is shutting down.
                    return;
                }
                LOGGER.error("Exception in thread: " + e.getMessage(), e);
            }
        }
    }

    /**
     * Housekeeping loop: roughly every 100 ms, walks all keys registered with
     * the selector (except the server channel's own key) and closes and
     * cancels those whose key is invalid or whose channel / socket is closed.
     * Runs for as long as the server channel is open. If the thread is
     * interrupted while sleeping, the server is closed.
     */
    private void closeLoop() {
        while (serverSocketChannel.isOpen()) {
            for (SelectionKey candidate : selector.keys()) {
                if (candidate.channel() != serverSocketChannel) {
                    SocketChannel client = (SocketChannel) candidate.channel();
                    boolean alive = candidate.isValid() &&
                                    client.isOpen() &&
                                    !client.socket().isClosed();
                    if (!alive) {
                        try {
                            candidate.channel().close();
                        } catch (IOException e) {
                            LOGGER.warn("Unable to close channel", e);
                        }
                        candidate.cancel();
                    }
                }
            }

            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                LOGGER.error("Interrupted", e);
                // Shut the whole server down, then restore the interrupt flag.
                close();
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Accepts all connections currently pending on the server channel and
     * registers each with the selector for reading, with a fresh
     * {@link Context} attached.
     * <p>
     * A connection that cannot be configured or registered is closed again,
     * and accepting continues with the next pending connection.
     */
    private void accept() {
        try {
            SocketChannel socketChannel;
            while ((socketChannel = serverSocketChannel.accept()) != null) {
                try {
                    // The accepted channel must be non-blocking to be registered with the selector.
                    socketChannel.configureBlocking(false);
                    socketChannel.register(selector, SelectionKey.OP_READ, new Context(socketChannel, maxFrameSizeInBytes));
                } catch (IOException | RuntimeException e) {
                    LOGGER.error("Exception when accepting: {}", e.getMessage(), e);
                    // Nothing else references the channel yet, so close it here or it leaks.
                    try {
                        socketChannel.close();
                    } catch (IOException ce) {
                        LOGGER.warn("Unable to close channel", ce);
                    }
                }
            }
        } catch (IOException e) {
            LOGGER.error("Exception when accepting: {}", e.getMessage(), e);
        }
    }

    /**
     * Handles a readable key. Reads the 4 byte frame size (if not already
     * known for the current frame), then as much of the frame as is available.
     * When the frame is complete, the call is deserialized and handed to the
     * worker executor, which queues the reply for writing.
     * <p>
     * Returns without doing anything more when the frame is still incomplete.
     * The connection is closed and the key cancelled on end of stream, on read
     * errors and on frame sizes outside {@code 1..maxFrameSizeInBytes}.
     *
     * @param key     The selected key of the connection.
     * @param context The connection state attached to the key.
     * @throws IOException If decoding the frame size fails.
     */
    @SuppressWarnings("unchecked")
    private void handleRead(SelectionKey key, Context context) throws IOException {
        long startTime = System.nanoTime();

        // part a: read into the readBuffer.
        if (context.currentFrameSize == 0 && !readFrameSize(key, context)) {
            return;
        }

        try {
            if (context.channel.read(context.readBuffer) < 0) {
                LOGGER.warn("Closed connection while reading frame");
                dropConnection(key, context);
                return;
            }
        } catch (IOException e) {
            LOGGER.warn("Exception reading frame: {}", e.getMessage(), e);
            dropConnection(key, context);
            return;
        }

        if (context.readBuffer.position() < context.readBuffer.limit()) {
            // wait until next read, and see if remaining of frame has arrived.
            return;
        }

        // part b: if the read buffer is complete, handle the content.
        try {
            try {
                // Reset first, so the next read starts with a new frame size.
                context.currentFrameSize = 0;
                context.readBuffer.flip();

                InputStream  in   = new ByteBufferInputStream(context.readBuffer);
                PServiceCall call = serializer.deserialize(in, processor.getDescriptor());
                serializer.verifyEndOfContent(in);

                context.readBuffer.clear();
                workerExecutor.submit(() -> {
                    PServiceCall reply;
                    try {
                        reply = processor.handleCall(call);
                    } catch (PApplicationException e) {
                        reply = new PServiceCall<>(call.getMethod(),
                                                   PServiceCallType.EXCEPTION,
                                                   call.getSequence(),
                                                   e);
                    } catch (IOException e) {
                        reply = new PServiceCall<>(call.getMethod(),
                                                   PServiceCallType.EXCEPTION,
                                                   call.getSequence(),
                                                   new PApplicationException(e.getMessage(), INTERNAL_ERROR)
                                                           .initCause(e));
                    } catch (RuntimeException e) {
                        // The result of this task is never inspected, so an unchecked
                        // exception would otherwise vanish and leave the client
                        // waiting for a reply that never comes.
                        LOGGER.error("Unhandled exception in call handler: {}", e.getMessage(), e);
                        reply = new PServiceCall<>(call.getMethod(),
                                                   PServiceCallType.EXCEPTION,
                                                   call.getSequence(),
                                                   new PApplicationException(e.toString(), INTERNAL_ERROR)
                                                           .initCause(e));
                    }

                    queueReply(key, context, new WriteEntry(startTime, call, reply));
                });
            } catch (SerializerException e) {
                PServiceCall reply = new PServiceCall<>(
                        e.getMethodName(),
                        PServiceCallType.EXCEPTION,
                        e.getSequenceNo(),
                        new PApplicationException(e.getMessage(), e.getExceptionType())
                                .initCause(e));
                queueReply(key, context, new WriteEntry(startTime, null, reply));
            }
        } catch (IOException e) {
            double duration = ((double) System.nanoTime() - startTime) / NS_IN_MILLIS;
            instrumentation.onTransportException(e, duration, null, null);
        }
    }

    /**
     * Reads and validates the frame size, and prepares the read buffer for
     * the frame content.
     *
     * @return True if the frame size is known and valid. False if more size
     *         bytes are needed, or if the connection was closed.
     */
    private boolean readFrameSize(SelectionKey key, Context context) throws IOException {
        try {
            if (context.channel.read(context.sizeBuffer) < 0) {
                dropConnection(key, context);
                return false;
            }
            if (context.sizeBuffer.position() < 4) {
                // The size itself has not arrived completely yet.
                return false;
            }
        } catch (IOException e) {
            LOGGER.debug("Exception reading frame size: {}", e.getMessage(), e);
            dropConnection(key, context);
            return false;
        }

        context.sizeBuffer.flip();
        try (ByteBufferInputStream in = new ByteBufferInputStream(context.sizeBuffer);
             BigEndianBinaryReader reader = new BigEndianBinaryReader(in)) {
            context.currentFrameSize = reader.expectInt();
        }
        // Back to position 0 (limit stays at 4), ready for the next frame's size.
        context.sizeBuffer.rewind();

        if (context.currentFrameSize > maxFrameSizeInBytes) {
            LOGGER.warn("Attempting message of " + context.currentFrameSize + " > " + maxFrameSizeInBytes);
            dropConnection(key, context);
            return false;
        }
        if (context.currentFrameSize < 1) {
            LOGGER.warn("Attempting message of " + context.currentFrameSize);
            dropConnection(key, context);
            return false;
        }

        // Limit the read buffer to exactly one frame.
        context.readBuffer.rewind();
        context.readBuffer.limit(context.currentFrameSize);
        return true;
    }

    /** Closes the connection and cancels its registration with the selector. */
    private void dropConnection(SelectionKey key, Context context) {
        context.close();
        key.cancel();
    }

    /** Queues a reply for the connection and asks the selector to report it as writable. */
    private void queueReply(SelectionKey key, Context context, WriteEntry entry) {
        context.writeQueue.offer(entry);
        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
        // A blocked select() does not pick up the new interest set by itself.
        selector.wakeup();
    }

    /**
     * Handles a writable key: writes every queued reply for the connection as
     * one frame each and reports each to instrumentation. Write interest is
     * dropped once the queue is empty.
     * <p>
     * If writing a frame fails, the connection is closed, the key is cancelled
     * and the remaining queued replies are abandoned.
     *
     * @param key     The selected key of the connection.
     * @param context The connection state attached to the key.
     */
    @SuppressWarnings("unchecked")
    private void handleWrite(SelectionKey key, Context context) {
        WriteEntry entry;
        while ((entry = context.writeQueue.poll()) != null) {
            Exception ex = null;
            boolean connectionLost = false;

            // NOTE(review): the processor result is queued unchecked; presumably it
            // can be null for calls without a response - confirm. Nothing is
            // written in that case, instead of failing in the serializer.
            if (entry.reply != null) {
                try {
                    serializer.serialize(context.out, entry.reply);
                } catch (IOException e) {
                    ex = e;
                }

                try {
                    context.out.completeFrame();
                    context.out.flush();
                } catch (IOException e) {
                    LOGGER.error("Failed to write frame: {}", e.getMessage(), e);
                    context.close();
                    key.cancel();
                    connectionLost = true;
                    if (ex == null) {
                        // The reply never reached the client; do not report it as completed.
                        ex = e;
                    }
                }
            }

            double duration = ((double) System.nanoTime() - entry.startTime) / NS_IN_MILLIS;
            if (ex == null) {
                instrumentation.onComplete(duration, entry.call, entry.reply);
            } else {
                instrumentation.onTransportException(ex, duration, entry.call, entry.reply);
            }

            if (connectionLost) {
                // The channel is closed and the key cancelled: further writes would
                // fail, and touching the key's interest set would throw.
                return;
            }
        }

        if (!key.isValid()) {
            return;
        }

        if (context.writeQueue.isEmpty()) {
            key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
            // A worker may have queued a reply and set OP_WRITE between the check
            // above and the update, which then overwrote its request. Re-check
            // after clearing so such a reply is not left waiting.
            if (!context.writeQueue.isEmpty()) {
                key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
            }
        }
    }

    /**
     * A reply waiting to be written to a connection, together with what is
     * needed to report it to instrumentation. Immutable once created.
     */
    private static class WriteEntry {
        /** {@link System#nanoTime()} value taken when handling of the read started. */
        final long         startTime;
        /** The call being answered, or null when the request could not be parsed. */
        final PServiceCall call;
        /** The reply to write. */
        final PServiceCall reply;

        WriteEntry(long startTime,
                   @Nullable PServiceCall call,
                   @Nonnull PServiceCall reply) {
            this.startTime = startTime;
            this.call = call;
            this.reply = reply;
        }
    }

    /**
     * State kept per accepted connection: the channel, its framed output
     * stream, the queue of pending replies and the buffers used while
     * reading a frame.
     */
    private static class Context {
        // NOTE(review): not referenced anywhere in the visible code.
        final Object                   mutex;
        final SocketChannel            channel;
        // Replies waiting to be written; filled by worker threads, drained when writable.
        final Queue<WriteEntry>        writeQueue;
        final FramedBufferOutputStream out;

        // Receives the 4 byte frame size.
        final ByteBuffer               sizeBuffer;
        // Receives the frame content; sized to the largest allowed frame.
        final ByteBuffer               readBuffer;
        // Size of the frame being read, or 0 when the next thing to read is a frame size.
        int currentFrameSize;

        private Context(SocketChannel channel, int maxFrameSizeInBytes) {
            this.channel = channel;
            this.mutex = new Object();
            this.writeQueue = new ConcurrentLinkedQueue<>();
            this.sizeBuffer = ByteBuffer.allocate(Integer.BYTES);
            this.readBuffer = ByteBuffer.allocateDirect(maxFrameSizeInBytes);
            this.out = new FramedBufferOutputStream(channel, maxFrameSizeInBytes);
            this.currentFrameSize = 0;
        }

        /** Closes the socket and the channel; a failure is logged, not thrown. */
        void close() {
            try {
                channel.socket().close();
                channel.close();
            } catch (IOException closeFailure) {
                LOGGER.warn("Exception closing channel: {}", closeFailure.getMessage(), closeFailure);
            }
        }
    }
}