// NonblockingSocketClientHandler.java
/*
* Copyright 2016 Providence Authors
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package net.morimekta.providence.thrift.client;
import net.morimekta.providence.PApplicationException;
import net.morimekta.providence.PApplicationExceptionType;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PServiceCall;
import net.morimekta.providence.PServiceCallHandler;
import net.morimekta.providence.PServiceCallInstrumentation;
import net.morimekta.providence.PServiceCallType;
import net.morimekta.providence.descriptor.PService;
import net.morimekta.providence.serializer.Serializer;
import net.morimekta.providence.thrift.io.FramedBufferInputStream;
import net.morimekta.providence.thrift.io.FramedBufferOutputStream;
import net.morimekta.util.concurrent.NamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.channels.SocketChannel;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static net.morimekta.providence.PServiceCallInstrumentation.NS_IN_MILLIS;
/**
* Client handler for thrift RPC using the TNonblockingServer, or similar that
* uses the TFramedTransport message wrapper. It is able to handle a true
* async-like message and response order, so even if the server sends responses
* out of order this client will match to the correct caller.
*
* The client handler is dependent on that there is a single client with unique
* sequence IDs on incoming service calls, otherwise there will be trouble with
* matching responses to the requesting thread.
*
* When using this client handler make sure to close it when no longer in use.
* Otherwise it will keep the socket channel open almost indefinitely.
*/
public class NonblockingSocketClientHandler implements PServiceCallHandler, Closeable {
private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSocketClientHandler.class);
private final Serializer serializer;
private final SocketAddress address;
private final int connect_timeout;
private final int read_timeout;
private final int response_timeout;
private final Map<Integer, CompletableFuture<PServiceCall>> responseFutures;
private final ExecutorService executorService;
private final PServiceCallInstrumentation instrumentation;
private final BlockingQueue<PServiceCall> requestQueue;
private final PService service;
private volatile SocketChannel channel;
private volatile FramedBufferOutputStream out;
public NonblockingSocketClientHandler(Serializer serializer,
SocketAddress address,
PService service) {
this(serializer, address, service, PServiceCallInstrumentation.NOOP);
}
public NonblockingSocketClientHandler(Serializer serializer,
SocketAddress address,
PService service,
PServiceCallInstrumentation instrumentation) {
this(serializer,
address,
service,
instrumentation,
10000,
10000);
}
public NonblockingSocketClientHandler(Serializer serializer,
SocketAddress address,
PService service,
int connect_timeout,
int read_timeout) {
this(serializer,
address,
service,
PServiceCallInstrumentation.NOOP,
connect_timeout,
read_timeout);
}
public NonblockingSocketClientHandler(Serializer serializer,
SocketAddress address,
PService service,
PServiceCallInstrumentation instrumentation,
int connect_timeout,
int read_timeout) {
this(serializer,
address,
service,
instrumentation,
connect_timeout,
read_timeout,
connect_timeout + 2 * read_timeout);
}
public NonblockingSocketClientHandler(Serializer serializer,
SocketAddress address,
PService service,
PServiceCallInstrumentation instrumentation,
int connect_timeout,
int read_timeout,
int response_timeout) {
this.serializer = serializer;
this.address = address;
this.service = service;
this.instrumentation = instrumentation;
this.connect_timeout = connect_timeout;
this.read_timeout = read_timeout;
this.response_timeout = response_timeout;
this.responseFutures = new ConcurrentHashMap<>();
this.requestQueue = new LinkedBlockingQueue<>();
this.executorService = Executors.newFixedThreadPool(
2, NamedThreadFactory.builder().setDaemon(true).setNameFormat("non-blocking-%d").build());
this.executorService.submit(this::handleWriteRequests);
}
@Override
public synchronized void close() throws IOException {
executorService.shutdown();
if (channel != null) {
try (SocketChannel ignore = channel;
OutputStream ignore2 = out) {
channel = null;
out = null;
}
}
try {
// triggers interrupt on children.
executorService.shutdownNow();
executorService.awaitTermination(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
@Override
@SuppressWarnings("unchecked")
public <Request extends PMessage<Request>,
Response extends PMessage<Response>>
PServiceCall<Response> handleCall(PServiceCall<Request> call, PService service)
throws IOException {
if (call.getType() == PServiceCallType.EXCEPTION || call.getType() == PServiceCallType.REPLY) {
throw new PApplicationException("Request with invalid call type: " + call.getType(),
PApplicationExceptionType.INVALID_MESSAGE_TYPE);
}
long startTime = System.nanoTime();
PServiceCall<Response> response = null;
CompletableFuture<PServiceCall> responseFuture = null;
responseFuture = new CompletableFuture<>();
// Each sequence No must be unique for the client, otherwise this will be messed up.
responseFutures.put(call.getSequence(), responseFuture);
try {
requestQueue.put(call);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e.getMessage(), e);
}
try {
try {
if (response_timeout > 0) {
response = (PServiceCall<Response>) responseFuture.get(response_timeout, TimeUnit.MILLISECONDS);
} else {
response = (PServiceCall<Response>) responseFuture.get();
}
long endTime = System.nanoTime();
double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
try {
instrumentation.onComplete(duration, call, response);
} catch (Exception ignore) {}
return response;
} catch (TimeoutException | InterruptedException e) {
responseFuture.completeExceptionally(e);
throw new IOException(e.getMessage(), e);
} catch (ExecutionException e) {
if (e.getCause() instanceof IOException) {
e.getCause().addSuppressed(e);
throw ((IOException) e.getCause());
}
throw new IOException(e.getMessage(), e);
} finally {
responseFutures.remove(call.getSequence());
}
} catch (Exception e) {
long endTime = System.nanoTime();
double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
try {
instrumentation.onTransportException(e, duration, call, response);
} catch (Exception ie) {
e.addSuppressed(ie);
}
throw e;
}
}
private synchronized FramedBufferOutputStream ensureConnected() throws IOException {
if (channel == null || !channel.isConnected()) {
channel = SocketChannel.open();
// The client channel is always in blocking mode. The read and write
// threads handle the asynchronous nature of the protocol.
channel.configureBlocking(true);
Socket socket = channel.socket();
socket.setSoLinger(false, 0);
socket.setTcpNoDelay(true);
socket.setKeepAlive(true);
socket.setSoTimeout(read_timeout);
socket.connect(address, connect_timeout);
this.out = new FramedBufferOutputStream(channel);
this.executorService.submit(() -> this.handleReadResponses(channel));
}
return out;
}
private void handleWriteRequests() {
while (!executorService.isShutdown()) {
PServiceCall<?> call;
try {
if ((call = requestQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
try {
FramedBufferOutputStream out = ensureConnected();
try {
serializer.serialize(out, call);
out.flush();
} finally {
out.completeFrame();
}
if (call.getType() == PServiceCallType.ONEWAY) {
// It's sent, do not wait for response.
CompletableFuture<?> future = responseFutures.remove(call.getSequence());
if (future != null) {
future.complete(null);
}
}
} catch (IOException e) {
try {
close();
} catch (IOException e2) {
e.addSuppressed(e2);
}
CompletableFuture future = responseFutures.remove(call.getSequence());
if (future != null) {
responseFutures.remove(call.getSequence());
future.completeExceptionally(e);
}
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return;
}
}
}
private void handleReadResponses(SocketChannel channel) {
FramedBufferInputStream in = new FramedBufferInputStream(channel);
while (this.channel == channel && channel.isOpen()) {
try {
in.nextFrame();
PServiceCall reply = serializer.deserialize(in, service);
if (reply.getType() == PServiceCallType.CALL || reply.getType() == PServiceCallType.ONEWAY) {
throw new PApplicationException("Reply with invalid call type: " + reply.getType(),
PApplicationExceptionType.INVALID_MESSAGE_TYPE);
}
CompletableFuture<PServiceCall> future = responseFutures.remove(reply.getSequence());
if (future == null) {
// The item response timed out.
LOGGER.debug("No future for sequence ID " + reply.getSequence());
continue;
}
future.complete(reply);
} catch (Exception e) {
if (this.channel != channel || !channel.isOpen()) {
// If the channel is closed. Should not trigger on disconnected.
break;
}
LOGGER.error("Exception in channel response reading", e);
}
}
if (responseFutures.size() > 0) {
LOGGER.warn("Channel closed with {} unfinished calls", responseFutures.size());
responseFutures.forEach((s, f) -> f.completeExceptionally(new IOException("Channel closed")));
responseFutures.clear();
}
}
}