FlightClient.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import javax.net.ssl.SSLException;
import org.apache.arrow.flight.FlightProducer.StreamListener;
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthInterceptor;
import org.apache.arrow.flight.auth.ClientAuthWrapper;
import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter;
import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
import org.apache.arrow.flight.auth2.ClientHandshakeWrapper;
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.Empty;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
/** Client for Flight services. */
public class FlightClient implements AutoCloseable {
/**
 * Value handed to every {@link FlightStream} this client creates as its pending-request count
 * (presumably how many messages are requested ahead of consumption — confirm in FlightStream).
 */
private static final int PENDING_REQUESTS = 5;

/**
 * Upper bound on the trace events kept by the gRPC channel. A value of zero turns channel
 * tracing off.
 */
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;

/** Child allocator owned by this client; closed by {@link #close()}. */
private final BufferAllocator allocator;

/** Underlying gRPC channel; shut down by {@link #close()}. */
private final ManagedChannel channel;

// Stubs bound to the channel with the auth and middleware interceptors already applied.
private final FlightServiceBlockingStub blockingStub;
private final FlightServiceStub asyncStub;

/** Interceptor for handshake-based auth; receives its handler from {@link #authenticate}. */
private final ClientAuthInterceptor authInterceptor = new ClientAuthInterceptor();

// Streaming method descriptors built by FlightBindingService from this client's allocator.
private final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor;
private final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor;
private final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor;

/** Client middleware factories; authenticateBasicToken appends to this list after construction. */
private final List<FlightClientMiddleware.Factory> middleware;
/**
 * Create a Flight client from an allocator and a gRPC channel.
 *
 * @param incomingAllocator parent allocator; a child allocator named "flight-client" is created
 *     from it and owned (and closed) by this client.
 * @param channel the gRPC channel calls are issued on; owned (and shut down) by this client.
 * @param middleware client middleware factories. The list is copied, so this client never
 *     mutates the caller's list (e.g. a reused Builder's list, which would otherwise leak
 *     middleware added by {@link #authenticateBasicToken} into other clients), and an immutable
 *     or {@code null} list is accepted.
 */
FlightClient(
    BufferAllocator incomingAllocator,
    ManagedChannel channel,
    List<FlightClientMiddleware.Factory> middleware) {
  // Copy first so a failure here cannot leak the child allocator created below.
  this.middleware = middleware == null ? new ArrayList<>() : new ArrayList<>(middleware);
  this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
  this.channel = channel;
  // The adapter must share this client's own list instance so that middleware appended later
  // (authenticateBasicToken) takes effect on subsequent calls.
  final ClientInterceptor[] interceptors =
      new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(this.middleware)};
  // Create a channel with interceptors pre-applied for all stubs and streaming calls.
  final Channel interceptedChannel = ClientInterceptors.intercept(channel, interceptors);
  blockingStub = FlightServiceGrpc.newBlockingStub(interceptedChannel);
  asyncStub = FlightServiceGrpc.newStub(interceptedChannel);
  doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator);
  doPutDescriptor = FlightBindingService.getDoPutDescriptor(allocator);
  doExchangeDescriptor = FlightBindingService.getDoExchangeDescriptor(allocator);
}
/**
 * Lists the flights the server makes available.
 *
 * @param criteria filter the server applies when selecting flights.
 * @param options RPC-layer hints for the call.
 * @return an Iterable whose iterator converts the server's FlightInfo messages as they are read.
 */
public Iterable<FlightInfo> listFlights(Criteria criteria, CallOption... options) {
  final Iterator<Flight.FlightInfo> protoInfos;
  try {
    final FlightServiceBlockingStub stub = CallOptions.wrapStub(blockingStub, options);
    protoInfos = stub.listFlights(criteria.asCriteria());
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
  return () -> StatusUtils.wrapIterator(protoInfos, FlightClient::toFlightInfo);
}

/** Converts a protocol FlightInfo, rethrowing the (unexpected) URISyntaxException unchecked. */
private static FlightInfo toFlightInfo(Flight.FlightInfo proto) {
  try {
    return new FlightInfo(proto);
  } catch (URISyntaxException e) {
    // Conforming Flight servers never produce this; a Java server, for example, could not
    // even construct an invalid Location.
    throw new RuntimeException(e);
  }
}
/**
 * Lists the action types the Flight service advertises.
 *
 * @param options RPC-layer hints for the call.
 * @return an Iterable over the advertised action types.
 */
public Iterable<ActionType> listActions(CallOption... options) {
  final Iterator<Flight.ActionType> protoActions;
  try {
    final FlightServiceBlockingStub stub = CallOptions.wrapStub(blockingStub, options);
    protoActions = stub.listActions(Empty.getDefaultInstance());
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
  return () -> StatusUtils.wrapIterator(protoActions, ActionType::new);
}
/**
 * Performs an action on the Flight service.
 *
 * <p>Like the other blocking calls on this client, a gRPC failure raised while starting the call
 * is translated into a Flight exception. Failures raised later, while iterating, are translated by
 * the returned iterator.
 *
 * @param action The action to perform.
 * @param options RPC-layer hints for this call.
 * @return An iterator of results.
 */
public Iterator<Result> doAction(Action action, CallOption... options) {
  final Iterator<Flight.Result> protoResults;
  try {
    protoResults = CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol());
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
  return StatusUtils.wrapIterator(protoResults, Result::new);
}
/**
 * Authenticates with a username and password by running {@link #authenticate} with a
 * {@link BasicClientAuthHandler}.
 *
 * @param username the username.
 * @param password the password.
 */
public void authenticateBasic(String username, String password) {
  authenticate(new BasicClientAuthHandler(username, password));
}
/**
 * Authenticates against the Flight service with the given handler and, on success, installs the
 * handler on this client's auth interceptor.
 *
 * @param handler The auth mechanism to use.
 * @param options RPC-layer hints for this call.
 * @throws IllegalArgumentException if a handler has already been installed.
 */
public void authenticate(ClientAuthHandler handler, CallOption... options) {
  final boolean alreadyAuthenticated = authInterceptor.hasAuthHandler();
  Preconditions.checkArgument(!alreadyAuthenticated, "Auth already completed.");
  final FlightServiceStub stub = CallOptions.wrapStub(asyncStub, options);
  ClientAuthWrapper.doClientAuth(handler, stub);
  // Only reached when the auth exchange above returned without throwing.
  authInterceptor.setAuthHandler(handler);
}
/**
 * Authenticates with a username and password via the handshake RPC.
 *
 * <p>An incoming-auth-header middleware factory is registered on this client before the
 * handshake. It stays registered afterwards, and each call to this method registers another one.
 *
 * @param username the username.
 * @param password the password.
 * @return a CredentialCallOption containing a bearer token if the server emitted one, or empty if
 *     no bearer token was returned. This can be used in subsequent API calls.
 */
public Optional<CredentialCallOption> authenticateBasicToken(String username, String password) {
  final ClientBearerHeaderHandler bearerHandler = new ClientBearerHeaderHandler();
  final ClientIncomingAuthHeaderMiddleware.Factory authMiddleware =
      new ClientIncomingAuthHeaderMiddleware.Factory(bearerHandler);
  // Must be registered before the handshake so it sees that call (presumably to capture the
  // bearer token from the response headers).
  middleware.add(authMiddleware);
  final BasicAuthCredentialWriter credentials = new BasicAuthCredentialWriter(username, password);
  handshake(new CredentialCallOption(credentials));
  return Optional.ofNullable(authMiddleware.getCredentialCallOption());
}
/**
 * Runs the handshake RPC against the Flight service.
 *
 * @param options RPC-layer hints for this call (for example, a credential option).
 */
public void handshake(CallOption... options) {
  final FlightServiceStub stub = CallOptions.wrapStub(asyncStub, options);
  ClientHandshakeWrapper.doClientHandshake(stub);
}
/**
 * Create or append a descriptor with another stream, using no dictionaries.
 *
 * @param descriptor FlightDescriptor the descriptor for the data
 * @param root VectorSchemaRoot the root containing data
 * @param metadataListener A handler for metadata messages from the server. This will be passed
 *     buffers that will be freed after {@link StreamListener#onNext(Object)} is called!
 * @param options RPC-layer hints for this call.
 * @return ClientStreamListener an interface to control uploading data; it has already been
 *     started with {@code root}.
 */
public ClientStreamListener startPut(
    FlightDescriptor descriptor,
    VectorSchemaRoot root,
    PutListener metadataListener,
    CallOption... options) {
  // No dictionaries in this overload: delegate with an empty provider.
  final DictionaryProvider noDictionaries = new MapDictionaryProvider();
  return startPut(descriptor, root, noDictionaries, metadataListener, options);
}
/**
 * Create or append a descriptor with another stream, starting the writer immediately.
 *
 * @param descriptor FlightDescriptor the descriptor for the data
 * @param root VectorSchemaRoot the root containing data
 * @param provider dictionaries referenced by {@code root}
 * @param metadataListener A handler for metadata messages from the server.
 * @param options RPC-layer hints for this call.
 * @return ClientStreamListener an interface to control uploading data. {@link
 *     ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will already have been
 *     called.
 * @throws NullPointerException if {@code root} or {@code provider} is null.
 */
public ClientStreamListener startPut(
    FlightDescriptor descriptor,
    VectorSchemaRoot root,
    DictionaryProvider provider,
    PutListener metadataListener,
    CallOption... options) {
  Preconditions.checkNotNull(root, "root must not be null");
  Preconditions.checkNotNull(provider, "provider must not be null");
  // Open the call, then start the writer on the caller's behalf.
  final ClientStreamListener startedWriter = startPut(descriptor, metadataListener, options);
  startedWriter.start(root, provider);
  return startedWriter;
}
/**
 * Create or append a descriptor with another stream, without starting the writer.
 *
 * @param descriptor FlightDescriptor the descriptor for the data
 * @param metadataListener A handler for metadata messages from the server.
 * @param options RPC-layer hints for this call.
 * @return ClientStreamListener an interface to control uploading data. {@link
 *     ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will NOT already have
 *     been called.
 * @throws NullPointerException if {@code descriptor} or {@code metadataListener} is null.
 */
public ClientStreamListener startPut(
    FlightDescriptor descriptor, PutListener metadataListener, CallOption... options) {
  Preconditions.checkNotNull(descriptor, "descriptor must not be null");
  Preconditions.checkNotNull(metadataListener, "metadataListener must not be null");
  try {
    final ClientCall<ArrowMessage, Flight.PutResult> putCall =
        asyncStubNewCall(doPutDescriptor, options);
    // Server-sent PutResults are converted and forwarded to the listener.
    final SetStreamObserver responseHandler = new SetStreamObserver(allocator, metadataListener);
    final ClientCallStreamObserver<ArrowMessage> requestStream =
        (ClientCallStreamObserver<ArrowMessage>)
            ClientCalls.asyncBidiStreamingCall(putCall, responseHandler);
    final BooleanSupplier cancelled = metadataListener::isCancelled;
    final Runnable awaitResult = metadataListener::getResult;
    return new PutObserver(descriptor, requestStream, cancelled, awaitResult);
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}
/**
 * Get info on a stream.
 *
 * @param descriptor The descriptor for the stream.
 * @param options RPC-layer hints for this call.
 * @return the server's FlightInfo for {@code descriptor}.
 */
public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) {
  try {
    final FlightServiceBlockingStub stub = CallOptions.wrapStub(blockingStub, options);
    final Flight.FlightInfo proto = stub.getFlightInfo(descriptor.toProtocol());
    return new FlightInfo(proto);
  } catch (URISyntaxException e) {
    // Conforming Flight servers never produce this; a Java server, for example, could not
    // even construct an invalid Location.
    throw new RuntimeException(e);
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}
/**
 * Start or get info on execution of a long-running query.
 *
 * @param descriptor The descriptor for the stream.
 * @param options RPC-layer hints for this call.
 * @return Metadata about execution.
 */
public PollInfo pollInfo(FlightDescriptor descriptor, CallOption... options) {
  try {
    final FlightServiceBlockingStub stub = CallOptions.wrapStub(blockingStub, options);
    final Flight.PollInfo proto = stub.pollFlightInfo(descriptor.toProtocol());
    return new PollInfo(proto);
  } catch (URISyntaxException e) {
    // Not expected from conforming servers (an invalid Location in the response).
    throw new RuntimeException(e);
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}
/**
 * Get schema for a stream.
 *
 * @param descriptor The descriptor for the stream.
 * @param options RPC-layer hints for this call.
 * @return the schema result reported by the server.
 */
public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) {
  try {
    final FlightServiceBlockingStub stub = CallOptions.wrapStub(blockingStub, options);
    final Flight.SchemaResult proto = stub.getSchema(descriptor.toProtocol());
    return SchemaResult.fromProtocol(proto);
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}
/**
 * Retrieve a stream from the server.
 *
 * <p>The call is started with automatic inbound flow control disabled; the returned
 * {@link FlightStream} requests messages itself through the call.
 *
 * @param ticket The ticket granting access to the data stream.
 * @param options RPC-layer hints for this call.
 * @return a stream that yields the server's data as it arrives.
 * @throws NullPointerException if {@code ticket} is null (checked before any call is created).
 */
public FlightStream getStream(Ticket ticket, CallOption... options) {
  // Fail fast, before a call and a FlightStream are created for nothing.
  Preconditions.checkNotNull(ticket, "ticket must not be null");
  try {
    final ClientCall<Flight.Ticket, ArrowMessage> call =
        asyncStubNewCall(doGetDescriptor, options);
    final FlightStream stream =
        new FlightStream(allocator, PENDING_REQUESTS, call::cancel, call::request);
    final StreamObserver<ArrowMessage> delegate = stream.asObserver();
    final ClientResponseObserver<Flight.Ticket, ArrowMessage> clientResponseObserver =
        new ClientResponseObserver<Flight.Ticket, ArrowMessage>() {
          @Override
          public void beforeStart(ClientCallStreamObserver<Flight.Ticket> requestStream) {
            // The FlightStream drives message requests (see call::request above).
            requestStream.disableAutoInboundFlowControl();
          }

          @Override
          public void onNext(ArrowMessage value) {
            delegate.onNext(value);
          }

          @Override
          public void onError(Throwable t) {
            delegate.onError(StatusUtils.toGrpcException(t));
          }

          @Override
          public void onCompleted() {
            delegate.onCompleted();
          }
        };
    ClientCalls.asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver);
    return stream;
  } catch (StatusRuntimeException sre) {
    // Same translation as startPut/doExchange for failures raised while starting the call.
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}
/**
 * Initiate a bidirectional data exchange with the server.
 *
 * <p>The descriptor is written as the first message of the call before this method returns.
 *
 * @param descriptor A descriptor for the data stream.
 * @param options RPC call options.
 * @return A pair of a readable stream and a writable stream.
 */
public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, CallOption... options) {
  Preconditions.checkNotNull(descriptor, "descriptor must not be null");
  try {
    final ClientCall<ArrowMessage, ArrowMessage> exchangeCall =
        asyncStubNewCall(doExchangeDescriptor, options);
    final FlightStream reader =
        new FlightStream(allocator, PENDING_REQUESTS, exchangeCall::cancel, exchangeCall::request);
    final ClientCallStreamObserver<ArrowMessage> requestStream =
        (ClientCallStreamObserver<ArrowMessage>)
            ClientCalls.asyncBidiStreamingCall(exchangeCall, reader.asObserver());
    final ClientStreamListener writer =
        new PutObserver(
            descriptor,
            requestStream,
            reader.cancelled::isDone,
            () -> awaitExchangeCompletion(reader));
    sendExchangeDescriptor(requestStream, descriptor);
    return new ExchangeReaderWriter(reader, writer);
  } catch (StatusRuntimeException sre) {
    throw StatusUtils.fromGrpcRuntimeException(sre);
  }
}

/** Blocks until the exchange's read side completes, mapping failures to INTERNAL errors. */
private static void awaitExchangeCompletion(FlightStream reader) {
  try {
    reader.completed.get();
  } catch (InterruptedException e) {
    // Restore the interrupt flag for the caller.
    Thread.currentThread().interrupt();
    throw CallStatus.INTERNAL
        .withDescription("Client error: interrupted while completing call")
        .withCause(e)
        .toRuntimeException();
  } catch (ExecutionException e) {
    throw CallStatus.INTERNAL
        .withDescription("Client error: internal while completing call")
        .withCause(e)
        .toRuntimeException();
  }
}

/** Writes the descriptor as the opening message of a DoExchange call. */
private static void sendExchangeDescriptor(
    ClientCallStreamObserver<ArrowMessage> requestStream, FlightDescriptor descriptor) {
  try (final ArrowMessage opening = new ArrowMessage(descriptor.toProtocol())) {
    requestStream.onNext(opening);
  } catch (Exception e) {
    throw CallStatus.INTERNAL
        .withCause(e)
        .withDescription("Could not write descriptor " + descriptor)
        .toRuntimeException();
  }
}
/** The reading and writing halves of a single DoExchange call. */
public static class ExchangeReaderWriter implements AutoCloseable {
  private final FlightStream reader;
  private final ClientStreamListener writer;

  ExchangeReaderWriter(FlightStream reader, ClientStreamListener writer) {
    this.reader = reader;
    this.writer = writer;
  }

  /** Returns the stream of data sent by the server. */
  public FlightStream getReader() {
    return reader;
  }

  /** Returns the listener used to send data to the server. */
  public ClientStreamListener getWriter() {
    return writer;
  }

  /**
   * Drains the reader. Call this after `getWriter().completed()` and instead of
   * `getWriter().getResult()`; it is how errors that occur after the exchange completes are
   * reported.
   */
  public void getResult() {
    // Consume whatever the server still sends; errors surface from reader.next().
    boolean hasMore;
    do {
      hasMore = reader.next();
    } while (hasMore);
  }

  /** Closes the reader side of the call. */
  @Override
  public void close() throws Exception {
    reader.close();
  }
}
/** Adapts server-sent Flight.PutResult messages to a {@link StreamListener} of PutResult. */
private static class SetStreamObserver implements StreamObserver<Flight.PutResult> {
  private final BufferAllocator allocator;
  private final StreamListener<PutResult> listener;

  /**
   * @param allocator allocator used to materialize each PutResult.
   * @param listener destination for results; {@code null} means results are discarded.
   */
  SetStreamObserver(BufferAllocator allocator, StreamListener<PutResult> listener) {
    this.allocator = allocator;
    if (listener == null) {
      this.listener = NoOpStreamListener.getInstance();
    } else {
      this.listener = listener;
    }
  }

  @Override
  public void onNext(Flight.PutResult value) {
    // The converted result is closed as soon as the listener returns.
    try (final PutResult converted = PutResult.fromProtocol(allocator, value)) {
      listener.onNext(converted);
    }
  }

  @Override
  public void onError(Throwable t) {
    final Throwable translated = StatusUtils.fromThrowable(t);
    listener.onError(translated);
  }

  @Override
  public void onCompleted() {
    listener.onCompleted();
  }
}
/** The implementation of a {@link ClientStreamListener} for writing data to a Flight server. */
static class PutObserver extends OutboundStreamListenerImpl implements ClientStreamListener {
  private final BooleanSupplier isCancelled;
  private final Runnable getResult;

  /**
   * Create a new client stream listener.
   *
   * @param descriptor The descriptor for the stream.
   * @param observer The write-side gRPC StreamObserver.
   * @param isCancelled Reports whether the call has been cancelled; polled while waiting for the
   *     stream to become ready.
   * @param getResult Action that blocks until the overall call completes.
   */
  PutObserver(
      FlightDescriptor descriptor,
      ClientCallStreamObserver<ArrowMessage> observer,
      BooleanSupplier isCancelled,
      Runnable getResult) {
    super(descriptor, observer);
    Preconditions.checkNotNull(descriptor, "descriptor must be provided");
    Preconditions.checkNotNull(isCancelled, "isCancelled must be provided");
    Preconditions.checkNotNull(getResult, "getResult must be provided");
    this.isCancelled = isCancelled;
    this.getResult = getResult;
    // NOTE(review): field inherited from OutboundStreamListenerImpl; explicitly cleared here.
    this.unloader = null;
  }

  /**
   * Spins until gRPC reports the outbound stream ready or the call is cancelled.
   *
   * <p>The cancellation check keeps this from blocking forever, provided the supplied
   * {@code isCancelled} is implemented properly.
   */
  @Override
  protected void waitUntilStreamReady() {
    while (!responseObserver.isReady() && !isCancelled.getAsBoolean()) {
      // Give up the CPU on each iteration instead of spinning a core at 100%.
      Thread.yield();
    }
  }

  @Override
  public void getResult() {
    getResult.run();
  }
}
/**
 * Cancel execution of a distributed query.
 *
 * @param request The query to cancel.
 * @param options Call options.
 * @return The server response.
 */
public CancelFlightInfoResult cancelFlightInfo(
    CancelFlightInfoRequest request, CallOption... options) {
  return doSingleResultAction(
      FlightConstants.CANCEL_FLIGHT_INFO.getType(),
      request.serialize().array(),
      CancelFlightInfoResult::deserialize,
      options);
}

/**
 * Request the server to extend the lifetime of a query result set.
 *
 * @param request The result set partition.
 * @param options Call options.
 * @return The new endpoint with an updated expiration time.
 */
public FlightEndpoint renewFlightEndpoint(
    RenewFlightEndpointRequest request, CallOption... options) {
  return doSingleResultAction(
      FlightConstants.RENEW_FLIGHT_ENDPOINT.getType(),
      request.serialize().array(),
      FlightEndpoint::deserialize,
      options);
}

/**
 * Set server session option(s) by name/value.
 *
 * <p>Sessions are generally persisted via HTTP cookies.
 *
 * @param request The session options to set on the server.
 * @param options Call options.
 * @return The result containing per-value error statuses, if any.
 */
public SetSessionOptionsResult setSessionOptions(
    SetSessionOptionsRequest request, CallOption... options) {
  return doSingleResultAction(
      FlightConstants.SET_SESSION_OPTIONS.getType(),
      request.serialize().array(),
      SetSessionOptionsResult::deserialize,
      options);
}

/**
 * Get the current server session options.
 *
 * <p>The session is generally accessed via an HTTP cookie.
 *
 * @param request The (empty) GetSessionOptionsRequest.
 * @param options Call options.
 * @return The result containing the set of session options configured on the server.
 */
public GetSessionOptionsResult getSessionOptions(
    GetSessionOptionsRequest request, CallOption... options) {
  return doSingleResultAction(
      FlightConstants.GET_SESSION_OPTIONS.getType(),
      request.serialize().array(),
      GetSessionOptionsResult::deserialize,
      options);
}

/**
 * Close/invalidate the current server session.
 *
 * <p>The session is generally accessed via an HTTP cookie.
 *
 * @param request The (empty) CloseSessionRequest.
 * @param options Call options.
 * @return The result containing the status of the close operation.
 */
public CloseSessionResult closeSession(CloseSessionRequest request, CallOption... options) {
  return doSingleResultAction(
      FlightConstants.CLOSE_SESSION.getType(),
      request.serialize().array(),
      CloseSessionResult::deserialize,
      options);
}

/**
 * Parses the body of the first {@link Result} returned by a well-known action.
 *
 * @param <T> the parsed response type.
 */
@FunctionalInterface
private interface ActionResultParser<T> {
  T parse(ByteBuffer body) throws IOException, URISyntaxException;
}

/**
 * Runs an action whose response is a single serialized message, parses the first result, and
 * drains any remaining results so the call runs to completion.
 *
 * @param actionType the action type string.
 * @param actionBody the serialized request.
 * @param parser converts the first result's body into the response object.
 * @param options call options.
 * @return the parsed response.
 * @throws FlightRuntimeException with status INTERNAL if the server sends no result or the first
 *     result cannot be parsed.
 */
private <T> T doSingleResultAction(
    String actionType, byte[] actionBody, ActionResultParser<T> parser, CallOption... options) {
  final Iterator<Result> results = doAction(new Action(actionType, actionBody), options);
  if (!results.hasNext()) {
    throw CallStatus.INTERNAL
        .withDescription("Server did not return a response")
        .toRuntimeException();
  }
  final T parsed;
  try {
    parsed = parser.parse(ByteBuffer.wrap(results.next().getBody()));
  } catch (IOException | URISyntaxException e) {
    throw CallStatus.INTERNAL
        .withDescription("Failed to parse server response: " + e)
        .withCause(e)
        .toRuntimeException();
  }
  // Consume (and ignore) any extra results so the call completes.
  results.forEachRemaining((ignored) -> {});
  return parsed;
}
/** Writer side of an Arrow data stream sent from this client to the server. */
public interface ClientStreamListener extends OutboundStreamListener {
  /**
   * Blocks until the server has finished with the stream. Call this to be told about any error
   * that occurred during the upload.
   */
  void getResult();
}
/**
 * Receives the application metadata the server sends during a Flight DoPut.
 *
 * <p>Rather than implementing this directly, you will usually want {@link AsyncPutListener} or
 * {@link SyncPutListener}.
 */
public interface PutListener extends StreamListener<PutResult> {
  /**
   * Handles one metadata message from the server.
   *
   * @param val The application metadata. Its buffer is reclaimed as soon as this method
   *     returns; retain a reference if it must be used afterwards.
   */
  @Override
  void onNext(PutResult val);

  /**
   * Blocks until the server has finished the stream. Call this to be told about any error that
   * occurred during the upload.
   */
  void getResult();

  /**
   * Reports whether the call has been cancelled.
   *
   * <p>The default implementation always answers {@code false}. Implementations should override
   * it; otherwise a DoPut waiting for the stream to become ready may block forever.
   */
  default boolean isCancelled() {
    return false;
  }
}
/**
 * Shut down this client: the gRPC channel first, then the client's child allocator.
 *
 * <p>In-flight calls get up to 5 seconds to finish. Any still running after that are forcibly
 * cancelled via {@code shutdownNow()} and given up to another 5 seconds to wind down. Without
 * that step, the channel's transport would outlive the client and the allocator would be closed
 * underneath live calls.
 *
 * @throws InterruptedException if interrupted while waiting for the channel to terminate.
 */
@Override
public void close() throws InterruptedException {
  if (!channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)) {
    // Best effort: the result of this second wait is intentionally not checked.
    channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
  }
  allocator.close();
}
/**
 * Create an empty builder for a Flight client. An allocator and a location must be set before
 * calling {@link Builder#build()}.
 */
public static Builder builder() {
  final Builder fresh = new Builder();
  return fresh;
}
/**
 * Create a builder for a Flight client, pre-populated with an allocator and a location.
 *
 * @param allocator The allocator to use for the client (must not be null).
 * @param location The location to connect to (must not be null).
 */
public static Builder builder(BufferAllocator allocator, Location location) {
  final Builder populated = new Builder(allocator, location);
  return populated;
}
/** A builder for Flight clients. */
public static final class Builder {
private BufferAllocator allocator;
private Location location;
private boolean forceTls = false;
private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
private InputStream trustedCertificates = null;
private InputStream clientCertificate = null;
private InputStream clientKey = null;
private String overrideHostname = null;
private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
private boolean verifyServer = true;
private Builder() {}
private Builder(BufferAllocator allocator, Location location) {
this.allocator = Preconditions.checkNotNull(allocator);
this.location = Preconditions.checkNotNull(location);
}
/** Force the client to connect over TLS. */
public Builder useTls() {
this.forceTls = true;
return this;
}
/** Override the hostname checked for TLS. Use with caution in production. */
public Builder overrideHostname(final String hostname) {
this.overrideHostname = hostname;
return this;
}
/** Set the maximum inbound message size. */
public Builder maxInboundMessageSize(int maxSize) {
Preconditions.checkArgument(maxSize > 0);
this.maxInboundMessageSize = maxSize;
return this;
}
/** Set the trusted TLS certificates. */
public Builder trustedCertificates(final InputStream stream) {
this.trustedCertificates = Preconditions.checkNotNull(stream);
return this;
}
/** Set the trusted TLS certificates. */
public Builder clientCertificate(
final InputStream clientCertificate, final InputStream clientKey) {
Preconditions.checkNotNull(clientKey);
this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
this.clientKey = Preconditions.checkNotNull(clientKey);
return this;
}
public Builder allocator(BufferAllocator allocator) {
this.allocator = Preconditions.checkNotNull(allocator);
return this;
}
public Builder location(Location location) {
this.location = Preconditions.checkNotNull(location);
return this;
}
public Builder intercept(FlightClientMiddleware.Factory factory) {
middleware.add(factory);
return this;
}
public Builder verifyServer(boolean verifyServer) {
this.verifyServer = verifyServer;
return this;
}
/** Create the client from this builder. */
public FlightClient build() {
final NettyChannelBuilder builder;
switch (location.getUri().getScheme()) {
case LocationSchemes.GRPC:
case LocationSchemes.GRPC_INSECURE:
case LocationSchemes.GRPC_TLS:
{
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
break;
}
case LocationSchemes.GRPC_DOMAIN_SOCKET:
{
// The implementation is platform-specific, so we have to find the classes at runtime
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
try {
try {
// Linux
builder.channelType(
Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
} catch (ClassNotFoundException e) {
// BSD
builder.channelType(
Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
}
} catch (ClassNotFoundException
| InstantiationException
| IllegalAccessException
| NoSuchMethodException
| InvocationTargetException e) {
throw new UnsupportedOperationException(
"Could not find suitable Netty native transport implementation for domain socket address.");
}
break;
}
default:
throw new IllegalArgumentException(
"Scheme is not supported: " + location.getUri().getScheme());
}
if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
builder.useTransportSecurity();
final boolean hasTrustedCerts = this.trustedCertificates != null;
final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
throw new IllegalArgumentException(
"FlightClient has been configured to disable server verification, "
+ "but certificate options have been specified.");
}
final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
if (!this.verifyServer) {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
} else if (this.trustedCertificates != null
|| this.clientCertificate != null
|| this.clientKey != null) {
if (this.trustedCertificates != null) {
sslContextBuilder.trustManager(this.trustedCertificates);
}
if (this.clientCertificate != null && this.clientKey != null) {
sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
}
}
try {
builder.sslContext(sslContextBuilder.build());
} catch (SSLException e) {
throw new RuntimeException(e);
}
if (this.overrideHostname != null) {
builder.overrideAuthority(this.overrideHostname);
}
} else {
builder.usePlaintext();
}
builder
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
.maxInboundMessageSize(maxInboundMessageSize)
.maxInboundMetadataSize(maxInboundMessageSize);
return new FlightClient(allocator, builder.build(), middleware);
}
}
/**
 * Creates a new, unstarted call for the given method. The call is created on the async stub's
 * channel, which has this client's interceptors applied, and uses the stub's call options after
 * {@code options} have been applied.
 */
private <RequestT, ResponseT> ClientCall<RequestT, ResponseT> asyncStubNewCall(
    MethodDescriptor<RequestT, ResponseT> descriptor, CallOption... options) {
  final FlightServiceStub configuredStub = CallOptions.wrapStub(asyncStub, options);
  final Channel stubChannel = configuredStub.getChannel();
  return stubChannel.newCall(descriptor, configuredStub.getCallOptions());
}
}