NonblockingSocketClientHandler.java

  1. /*
  2.  * Copyright 2016 Providence Authors
  3.  *
  4.  * Licensed to the Apache Software Foundation (ASF) under one
  5.  * or more contributor license agreements. See the NOTICE file
  6.  * distributed with this work for additional information
  7.  * regarding copyright ownership. The ASF licenses this file
  8.  * to you under the Apache License, Version 2.0 (the
  9.  * "License"); you may not use this file except in compliance
  10.  * with the License. You may obtain a copy of the License at
  11.  *
  12.  *   http://www.apache.org/licenses/LICENSE-2.0
  13.  *
  14.  * Unless required by applicable law or agreed to in writing,
  15.  * software distributed under the License is distributed on an
  16.  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  17.  * KIND, either express or implied. See the License for the
  18.  * specific language governing permissions and limitations
  19.  * under the License.
  20.  */
  21. package net.morimekta.providence.thrift.client;

  22. import net.morimekta.providence.PApplicationException;
  23. import net.morimekta.providence.PApplicationExceptionType;
  24. import net.morimekta.providence.PMessage;
  25. import net.morimekta.providence.PServiceCall;
  26. import net.morimekta.providence.PServiceCallHandler;
  27. import net.morimekta.providence.PServiceCallInstrumentation;
  28. import net.morimekta.providence.PServiceCallType;
  29. import net.morimekta.providence.descriptor.PService;
  30. import net.morimekta.providence.serializer.Serializer;
  31. import net.morimekta.providence.thrift.io.FramedBufferInputStream;
  32. import net.morimekta.providence.thrift.io.FramedBufferOutputStream;
  33. import net.morimekta.util.concurrent.NamedThreadFactory;
  34. import org.slf4j.Logger;
  35. import org.slf4j.LoggerFactory;

  36. import java.io.Closeable;
  37. import java.io.IOException;
  38. import java.io.OutputStream;
  39. import java.net.Socket;
  40. import java.net.SocketAddress;
  41. import java.nio.channels.SocketChannel;
  42. import java.util.Map;
  43. import java.util.concurrent.BlockingQueue;
  44. import java.util.concurrent.CompletableFuture;
  45. import java.util.concurrent.ConcurrentHashMap;
  46. import java.util.concurrent.ExecutionException;
  47. import java.util.concurrent.ExecutorService;
  48. import java.util.concurrent.Executors;
  49. import java.util.concurrent.LinkedBlockingQueue;
  50. import java.util.concurrent.TimeUnit;
  51. import java.util.concurrent.TimeoutException;

  52. import static net.morimekta.providence.PServiceCallInstrumentation.NS_IN_MILLIS;

  53. /**
  54.  * Client handler for thrift RPC using the TNonblockingServer, or similar that
  55.  * uses the TFramedTransport message wrapper. It is able to handle a true
  56.  * async-like message and response order, so even if the server sends responses
  57.  * out of order this client will match to the correct caller.
  58.  *
  59.  * The client handler is dependent on that there is a single client with unique
  60.  * sequence IDs on incoming service calls, otherwise there will be trouble with
  61.  * matching responses to the requesting thread.
  62.  *
  63.  * When using this client handler make sure to close it when no longer in use.
  64.  * Otherwise it will keep the socket channel open almost indefinitely.
  65.  */
  66. public class NonblockingSocketClientHandler implements PServiceCallHandler, Closeable {
  67.     private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSocketClientHandler.class);

  68.     private final Serializer    serializer;
  69.     private final SocketAddress address;
  70.     private final int           connect_timeout;
  71.     private final int           read_timeout;
  72.     private final int           response_timeout;

  73.     private final Map<Integer, CompletableFuture<PServiceCall>> responseFutures;
  74.     private final ExecutorService                               executorService;
  75.     private final PServiceCallInstrumentation                   instrumentation;
  76.     private final BlockingQueue<PServiceCall>                   requestQueue;
  77.     private final PService                                      service;

  78.     private volatile SocketChannel channel;
  79.     private volatile FramedBufferOutputStream out;

  80.     public NonblockingSocketClientHandler(Serializer serializer,
  81.                                           SocketAddress address,
  82.                                           PService service) {
  83.         this(serializer, address, service, PServiceCallInstrumentation.NOOP);
  84.     }

  85.     public NonblockingSocketClientHandler(Serializer serializer,
  86.                                           SocketAddress address,
  87.                                           PService service,
  88.                                           PServiceCallInstrumentation instrumentation) {
  89.         this(serializer,
  90.              address,
  91.              service,
  92.              instrumentation,
  93.              10000,
  94.              10000);
  95.     }

  96.     public NonblockingSocketClientHandler(Serializer serializer,
  97.                                           SocketAddress address,
  98.                                           PService service,
  99.                                           int connect_timeout,
  100.                                           int read_timeout) {
  101.         this(serializer,
  102.              address,
  103.              service,
  104.              PServiceCallInstrumentation.NOOP,
  105.              connect_timeout,
  106.              read_timeout);
  107.     }

  108.     public NonblockingSocketClientHandler(Serializer serializer,
  109.                                           SocketAddress address,
  110.                                           PService service,
  111.                                           PServiceCallInstrumentation instrumentation,
  112.                                           int connect_timeout,
  113.                                           int read_timeout) {
  114.         this(serializer,
  115.              address,
  116.              service,
  117.              instrumentation,
  118.              connect_timeout,
  119.              read_timeout,
  120.              connect_timeout + 2 * read_timeout);
  121.     }

  122.     public NonblockingSocketClientHandler(Serializer serializer,
  123.                                           SocketAddress address,
  124.                                           PService service,
  125.                                           PServiceCallInstrumentation instrumentation,
  126.                                           int connect_timeout,
  127.                                           int read_timeout,
  128.                                           int response_timeout) {
  129.         this.serializer = serializer;
  130.         this.address = address;
  131.         this.service = service;
  132.         this.instrumentation = instrumentation;
  133.         this.connect_timeout = connect_timeout;
  134.         this.read_timeout = read_timeout;
  135.         this.response_timeout = response_timeout;
  136.         this.responseFutures = new ConcurrentHashMap<>();
  137.         this.requestQueue = new LinkedBlockingQueue<>();
  138.         this.executorService = Executors.newFixedThreadPool(
  139.                 2, NamedThreadFactory.builder().setDaemon(true).setNameFormat("non-blocking-%d").build());
  140.         this.executorService.submit(this::handleWriteRequests);
  141.     }

  142.     @Override
  143.     public synchronized void close() throws IOException {
  144.         executorService.shutdown();
  145.         if (channel != null) {
  146.             try (SocketChannel ignore = channel;
  147.                  OutputStream ignore2 = out) {
  148.                 channel = null;
  149.                 out = null;
  150.             }
  151.         }
  152.         try {
  153.             // triggers interrupt on children.
  154.             executorService.shutdownNow();
  155.             executorService.awaitTermination(10, TimeUnit.SECONDS);
  156.         } catch (InterruptedException e) {
  157.             Thread.currentThread().interrupt();
  158.         }
  159.     }

  160.     @Override
  161.     @SuppressWarnings("unchecked")
  162.     public <Request extends PMessage<Request>,
  163.             Response extends PMessage<Response>>
  164.     PServiceCall<Response> handleCall(PServiceCall<Request> call, PService service)
  165.             throws IOException {
  166.         if (call.getType() == PServiceCallType.EXCEPTION || call.getType() == PServiceCallType.REPLY) {
  167.             throw new PApplicationException("Request with invalid call type: " + call.getType(),
  168.                                             PApplicationExceptionType.INVALID_MESSAGE_TYPE);
  169.         }

  170.         long                            startTime      = System.nanoTime();
  171.         PServiceCall<Response>          response       = null;
  172.         CompletableFuture<PServiceCall> responseFuture = null;
  173.         responseFuture = new CompletableFuture<>();
  174.         // Each sequence No must be unique for the client, otherwise this will be messed up.
  175.         responseFutures.put(call.getSequence(), responseFuture);
  176.         try {
  177.             requestQueue.put(call);
  178.         } catch (InterruptedException e) {
  179.             Thread.currentThread().interrupt();
  180.             throw new RuntimeException(e.getMessage(), e);
  181.         }

  182.         try {
  183.             try {
  184.                 if (response_timeout > 0) {
  185.                     response = (PServiceCall<Response>) responseFuture.get(response_timeout, TimeUnit.MILLISECONDS);
  186.                 } else {
  187.                     response = (PServiceCall<Response>) responseFuture.get();
  188.                 }

  189.                 long   endTime  = System.nanoTime();
  190.                 double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
  191.                 try {
  192.                     instrumentation.onComplete(duration, call, response);
  193.                 } catch (Exception ignore) {}

  194.                 return response;
  195.             } catch (TimeoutException | InterruptedException e) {
  196.                 responseFuture.completeExceptionally(e);
  197.                 throw new IOException(e.getMessage(), e);
  198.             } catch (ExecutionException e) {
  199.                 if (e.getCause() instanceof IOException) {
  200.                     e.getCause().addSuppressed(e);
  201.                     throw ((IOException) e.getCause());
  202.                 }

  203.                 throw new IOException(e.getMessage(), e);
  204.             } finally {
  205.                 responseFutures.remove(call.getSequence());
  206.             }
  207.         } catch (Exception e) {
  208.             long endTime = System.nanoTime();
  209.             double duration = ((double) (endTime - startTime)) / NS_IN_MILLIS;
  210.             try {
  211.                 instrumentation.onTransportException(e, duration, call, response);
  212.             } catch (Exception ie) {
  213.                 e.addSuppressed(ie);
  214.             }

  215.             throw e;
  216.         }
  217.     }

  218.     private synchronized FramedBufferOutputStream ensureConnected() throws IOException {
  219.         if (channel == null || !channel.isConnected()) {
  220.             channel = SocketChannel.open();
  221.             // The client channel is always in blocking mode. The read and write
  222.             // threads handle the asynchronous nature of the protocol.
  223.             channel.configureBlocking(true);

  224.             Socket socket = channel.socket();
  225.             socket.setSoLinger(false, 0);
  226.             socket.setTcpNoDelay(true);
  227.             socket.setKeepAlive(true);
  228.             socket.setSoTimeout(read_timeout);
  229.             socket.connect(address, connect_timeout);

  230.             this.out = new FramedBufferOutputStream(channel);
  231.             this.executorService.submit(() -> this.handleReadResponses(channel));
  232.         }
  233.         return out;
  234.     }

  235.     private void handleWriteRequests() {
  236.         while (!executorService.isShutdown()) {
  237.             PServiceCall<?> call;
  238.             try {
  239.                 if ((call = requestQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
  240.                     try {
  241.                         FramedBufferOutputStream out = ensureConnected();
  242.                         try {
  243.                             serializer.serialize(out, call);
  244.                             out.flush();
  245.                         } finally {
  246.                             out.completeFrame();
  247.                         }
  248.                         if (call.getType() == PServiceCallType.ONEWAY) {
  249.                             // It's sent, do not wait for response.
  250.                             CompletableFuture<?> future = responseFutures.remove(call.getSequence());
  251.                             if (future != null) {
  252.                                 future.complete(null);
  253.                             }
  254.                         }
  255.                     } catch (IOException e) {
  256.                         try {
  257.                             close();
  258.                         } catch (IOException e2) {
  259.                             e.addSuppressed(e2);
  260.                         }
  261.                         CompletableFuture future = responseFutures.remove(call.getSequence());
  262.                         if (future != null) {
  263.                             responseFutures.remove(call.getSequence());
  264.                             future.completeExceptionally(e);
  265.                         }
  266.                     }
  267.                 }
  268.             } catch (InterruptedException e) {
  269.                 Thread.currentThread().interrupt();
  270.                 return;
  271.             }
  272.         }
  273.     }

  274.     private void handleReadResponses(SocketChannel channel) {
  275.         FramedBufferInputStream in = new FramedBufferInputStream(channel);
  276.         while (this.channel == channel && channel.isOpen()) {
  277.             try {
  278.                 in.nextFrame();
  279.                 PServiceCall reply = serializer.deserialize(in, service);

  280.                 if (reply.getType() == PServiceCallType.CALL || reply.getType() == PServiceCallType.ONEWAY) {
  281.                     throw new PApplicationException("Reply with invalid call type: " + reply.getType(),
  282.                                                     PApplicationExceptionType.INVALID_MESSAGE_TYPE);
  283.                 }

  284.                 CompletableFuture<PServiceCall> future = responseFutures.remove(reply.getSequence());
  285.                 if (future == null) {
  286.                     // The item response timed out.
  287.                     LOGGER.debug("No future for sequence ID " + reply.getSequence());
  288.                     continue;
  289.                 }

  290.                 future.complete(reply);
  291.             } catch (Exception e) {
  292.                 if (this.channel != channel || !channel.isOpen()) {
  293.                     // If the channel is closed. Should not trigger on disconnected.
  294.                     break;
  295.                 }
  296.                 LOGGER.error("Exception in channel response reading", e);
  297.             }
  298.         }

  299.         if (responseFutures.size() > 0) {
  300.             LOGGER.warn("Channel closed with {} unfinished calls", responseFutures.size());
  301.             responseFutures.forEach((s, f) -> f.completeExceptionally(new IOException("Channel closed")));
  302.             responseFutures.clear();
  303.         }
  304.     }
  305. }