SocketServer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.providence.thrift.server;
import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PProcessor;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallInstrumentation;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.serializer.BinarySerializer;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.serializer.SerializerException;
import net.morimekta.util.concurrent.NamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import static net.morimekta.providence.PServiceCallInstrumentation.NS_IN_MILLIS;
/**
 * Based heavily on <code>org.apache.thrift.transport.TServerTransport</code>
 * and meant to be a providence replacement for it.
 * <p>
 * The server owns one {@link ServerSocket} and one fixed size thread pool.
 * The pool runs both the accept task and the per-connection processing
 * tasks, which is why it is created with one more thread than the
 * configured number of worker threads.
 * <p>
 * Instances are created through {@link #builder(PProcessor)} and are
 * already bound and accepting connections when {@link Builder#start()}
 * returns. Call {@link #close()} to stop the server.
 */
public class SocketServer implements AutoCloseable {
    private static final Logger LOGGER = LoggerFactory.getLogger(SocketServer.class);

    private final int                         clientTimeout;
    private final PProcessor                  processor;
    private final PServiceCallInstrumentation instrumentation;
    private final ServerSocket                serverSocket;
    private final ExecutorService             workerExecutor;
    private final Serializer                  serializer;

    /**
     * Builder collecting the server configuration. Defaults: bind to an
     * ephemeral port (port 0), 60 second client timeout, backlog of 50,
     * 10 worker threads and a {@link BinarySerializer}.
     */
    public static class Builder {
        private final PProcessor processor;
        private PServiceCallInstrumentation instrumentation;
        private InetSocketAddress bindAddress;
        private int clientTimeout = 60000;  // 60 seconds
        private int backlog = 50;
        private int workerThreads = 10;
        private ThreadFactory workerThreadFactory;
        private Serializer serializer;

        /**
         * @param processor The processor handling the incoming service calls.
         */
        public Builder(@Nonnull PProcessor processor) {
            this.processor = processor;
            this.bindAddress = new InetSocketAddress(0);
            this.serializer = new BinarySerializer();
        }

        /**
         * Bind to the given port on the wildcard address.
         *
         * @param port The port to listen on. Use 0 for an ephemeral port.
         * @return The builder.
         * @throws IllegalArgumentException If the port is negative, or is
         *         rejected by {@link InetSocketAddress#InetSocketAddress(int)}.
         */
        public Builder withPort(int port) {
            if (port < 0) {
                throw new IllegalArgumentException("Negative port: " + port);
            }
            return withBindAddress(new InetSocketAddress(port));
        }

        /**
         * @param bindAddress The socket address to bind the server socket to.
         * @return The builder.
         */
        public Builder withBindAddress(@Nonnull InetSocketAddress bindAddress) {
            this.bindAddress = bindAddress;
            return this;
        }

        /**
         * @param maxBacklog The backlog value passed on when binding the
         *                   server socket.
         * @return The builder.
         * @throws IllegalArgumentException If the backlog is negative.
         */
        public Builder withMaxBacklog(int maxBacklog) {
            if (maxBacklog < 0) {
                throw new IllegalArgumentException("Negative backlog: " + maxBacklog);
            }
            this.backlog = maxBacklog;
            return this;
        }

        /**
         * @param instrumentation Instrumentation notified after each handled
         *                        call and on transport failures.
         * @return The builder.
         */
        public Builder withInstrumentation(@Nonnull PServiceCallInstrumentation instrumentation) {
            this.instrumentation = instrumentation;
            return this;
        }

        /**
         * @param timeoutInMs Socket timeout, in milliseconds, set on each
         *                    accepted client socket.
         * @return The builder.
         * @throws IllegalArgumentException If the timeout is less than 1.
         */
        public Builder withClientTimeout(int timeoutInMs) {
            if (timeoutInMs < 1) {
                throw new IllegalArgumentException("Client timeout must be at least 1 ms: " + timeoutInMs);
            }
            this.clientTimeout = timeoutInMs;
            return this;
        }

        /**
         * @param numThreads Number of threads available for processing
         *                   client connections. One extra thread is added
         *                   to the pool for the accept task.
         * @return The builder.
         * @throws IllegalArgumentException If less than 1.
         */
        public Builder withWorkerThreads(int numThreads) {
            if (numThreads < 1) {
                throw new IllegalArgumentException("Worker threads must be at least 1: " + numThreads);
            }
            this.workerThreads = numThreads;
            return this;
        }

        /**
         * @param factory Thread factory for the worker pool. If null when
         *                {@link #start()} is called, a non-daemon factory
         *                naming threads "providence-server-%d" is used.
         * @return The builder.
         */
        public Builder withThreadFactory(ThreadFactory factory) {
            this.workerThreadFactory = factory;
            return this;
        }

        /**
         * @param serializer Serializer used for both reading requests and
         *                   writing responses.
         * @return The builder.
         */
        public Builder withSerializer(Serializer serializer) {
            this.serializer = serializer;
            return this;
        }

        /**
         * Create the server, bind the server socket and start accepting
         * connections.
         *
         * @return The running server.
         * @throws UncheckedIOException If the server socket could not be
         *         created, configured or bound.
         */
        public SocketServer start() {
            if (workerThreadFactory == null) {
                this.workerThreadFactory = NamedThreadFactory.builder()
                                                             .setNameFormat("providence-server-%d")
                                                             .setDaemon(false)
                                                             .build();
            }
            return new SocketServer(this);
        }
    }

    /**
     * @param processor The processor handling the incoming service calls.
     * @return A new server builder.
     */
    public static Builder builder(@Nonnull PProcessor processor) {
        return new Builder(processor);
    }

    /**
     * @return The local port of the server socket, or -1 if the server has
     *         been closed.
     */
    public int getPort() {
        if (workerExecutor.isShutdown()) {
            return -1;
        }
        return serverSocket.getLocalPort();
    }

    /**
     * Stop the server: no new tasks are accepted by the worker pool, the
     * server socket is closed, and after waiting at most 10 milliseconds
     * for running tasks the pool is forcibly shut down.
     */
    @Override
    public void close() {
        // Shut down first, so the accept task can tell a deliberate close
        // from a real failure via workerExecutor.isShutdown().
        workerExecutor.shutdown();
        try {
            // this should trigger exception in the accept task.
            serverSocket.close();
        } catch (IOException e) {
            LOGGER.warn("Error closing server socket", e);
        } finally {
            try {
                workerExecutor.awaitTermination(10, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                // Keep the interrupt visible to the caller instead of
                // silently consuming it.
                Thread.currentThread().interrupt();
            }
            // really really try to kill it now.
            workerExecutor.shutdownNow();
        }
    }

    private SocketServer(Builder builder) {
        clientTimeout = builder.clientTimeout;
        processor = builder.processor;
        instrumentation = builder.instrumentation != null
                          ? builder.instrumentation
                          : PServiceCallInstrumentation.NOOP;
        serializer = builder.serializer;
        serverSocket = openServerSocket(builder.bindAddress, builder.backlog);
        workerExecutor = Executors.newFixedThreadPool(builder.workerThreads + 1,
                                                      builder.workerThreadFactory);
        workerExecutor.submit(this::accept);
    }

    /**
     * Create and bind the server socket. If configuring or binding fails,
     * the socket is closed again before the failure is propagated, so no
     * descriptor is left open.
     *
     * @param address The address to bind to.
     * @param backlog The listen backlog.
     * @return The bound server socket.
     * @throws UncheckedIOException On any I/O failure.
     */
    private static ServerSocket openServerSocket(InetSocketAddress address, int backlog) {
        final ServerSocket socket;
        try {
            socket = new ServerSocket();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        try {
            // Prevent 2MSL delay problem on server restarts
            socket.setReuseAddress(true);
            // Bind to listening port
            socket.bind(address, backlog);
            socket.setSoTimeout(0);
            return socket;
        } catch (IOException e) {
            closeAfterFailure(socket, e);
            throw new UncheckedIOException(e);
        } catch (RuntimeException e) {
            closeAfterFailure(socket, e);
            throw e;
        }
    }

    /** Close the server socket, attaching any close failure to the original failure. */
    private static void closeAfterFailure(ServerSocket socket, Exception failure) {
        try {
            socket.close();
        } catch (IOException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Accept a single client connection, hand it off to the worker pool and
     * re-submit this task to accept the next one.
     * <p>
     * Exceptions thrown out of a task passed to
     * {@link ExecutorService#submit(Runnable)} end up in the returned future,
     * which nobody reads here, so all failures are logged instead of thrown.
     */
    private void accept() {
        final Socket socket;
        try {
            socket = serverSocket.accept();
        } catch (SocketTimeoutException e) {
            // Nothing to accept within the timeout, just try again.
            scheduleAccept();
            return;
        } catch (IOException | RuntimeException e) {
            // Closing the server socket in close() is expected to end up
            // here; only report the failure if we were not asked to stop.
            if (!workerExecutor.isShutdown()) {
                LOGGER.error("Unable to accept client connections, server is no longer accepting", e);
            }
            return;
        }

        try {
            socket.setSoTimeout(clientTimeout);
            long startTime = System.nanoTime();
            // From here on the processing task owns (and closes) the socket.
            workerExecutor.submit(() -> process(startTime, socket));
        } catch (IOException | RuntimeException e) {
            // The hand-off failed, so nobody else will close this socket.
            closeQuietly(socket);
            if (workerExecutor.isShutdown()) {
                return;
            }
            // A single bad connection should not stop the server.
            LOGGER.warn("Unable to hand off client connection", e);
        }

        scheduleAccept();
    }

    /** Queue the next accept task, unless the server is shutting down. */
    private void scheduleAccept() {
        try {
            workerExecutor.submit(this::accept);
        } catch (RuntimeException e) {
            // The pool rejects new tasks once close() has been called.
            if (!workerExecutor.isShutdown()) {
                LOGGER.error("Unable to schedule accept task, server is no longer accepting", e);
            }
        }
    }

    /** Close a client socket that will not be processed. */
    private static void closeQuietly(Socket socket) {
        try {
            socket.close();
        } catch (IOException e) {
            LOGGER.debug("Error closing client socket", e);
        }
    }

    /**
     * @param startTime A {@link System#nanoTime()} reading.
     * @return Time passed since the start time, in milliseconds.
     */
    private static double elapsedMillis(long startTime) {
        long endTime = System.nanoTime();
        return ((double) (endTime - startTime)) / NS_IN_MILLIS;
    }

    /**
     * Handle all service calls arriving on one client connection. The
     * socket and its streams are closed when this method returns or throws.
     *
     * @param startTime {@link System#nanoTime()} reading from when the
     *                  connection was accepted; reset before each
     *                  subsequent call on the same connection.
     * @param socket    The accepted client socket.
     * @throws UncheckedIOException On transport failures, after notifying
     *         the instrumentation if the failure happened during a call.
     */
    @SuppressWarnings("unchecked")
    private void process(long startTime, @Nonnull Socket socket) {
        try (Socket ignore = socket;
             BufferedInputStream in = new BufferedInputStream(socket.getInputStream());
             BufferedOutputStream out = new BufferedOutputStream(socket.getOutputStream())) {
            while (socket.isConnected()) {
                PServiceCall request = null;
                PServiceCall response = null;
                try {
                    try {
                        request = serializer.deserialize(in, processor.getDescriptor());
                        if (request.getType() == PServiceCallType.REPLY ||
                            request.getType() == PServiceCallType.EXCEPTION) {
                            // Clients may not send replies or exceptions as requests.
                            PApplicationException ex = new PApplicationException(
                                    "Invalid service request call type: " + request.getType(),
                                    PApplicationExceptionType.INVALID_MESSAGE_TYPE);
                            response = new PServiceCall<>(request.getMethod(), PServiceCallType.EXCEPTION, request.getSequence(), ex);
                        } else {
                            response = processor.handleCall(request);
                        }
                    } catch (SerializerException e) {
                        if (e.getMethodName() != null) {
                            LOGGER.error("Error when reading service call " + processor.getDescriptor().getName() + "." + e.getMethodName() + "()", e);
                        } else {
                            LOGGER.error("Error when reading service call " + processor.getDescriptor().getName(), e);
                        }
                        // Report the read failure back to the client as an exception reply.
                        PApplicationException ex = new PApplicationException(e.getMessage(), e.getExceptionType()).initCause(e);
                        response = new PServiceCall<>(e.getMethodName(), PServiceCallType.EXCEPTION, e.getSequenceNo(), ex);
                    }

                    // No response object means there is nothing to write back.
                    if (response != null) {
                        serializer.serialize(out, response);
                        out.flush();
                    }

                    double duration = elapsedMillis(startTime);
                    try {
                        instrumentation.onComplete(duration, request, response);
                    } catch (Throwable th) {
                        // Instrumentation must never break the connection.
                        LOGGER.error("Exception in service instrumentation", th);
                    }
                } catch (IOException e) {
                    double duration = elapsedMillis(startTime);
                    try {
                        instrumentation.onTransportException(e, duration, request, response);
                    } catch (Throwable th) {
                        LOGGER.error("Exception in service instrumentation", th);
                    }
                    throw new UncheckedIOException(e.getMessage(), e);
                }

                // Peek one byte to find out whether the client has closed the
                // connection or has another call coming, without consuming it.
                in.mark(1);
                if (in.read() < 0) {
                    return;
                }
                in.reset();

                // The next call is timed from when its first byte arrived.
                startTime = System.nanoTime();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e.getMessage(), e);
        }
    }
}