FlightServer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.Server;
import io.grpc.ServerInterceptors;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import javax.net.ssl.SSLException;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthInterceptor;
import org.apache.arrow.flight.auth2.Auth2Constants;
import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
import org.apache.arrow.flight.auth2.ServerCallHeaderAuthMiddleware;
import org.apache.arrow.flight.grpc.ServerBackpressureThresholdInterceptor;
import org.apache.arrow.flight.grpc.ServerInterceptorAdapter;
import org.apache.arrow.flight.grpc.ServerInterceptorAdapter.KeyFactory;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;
/**
* Generic server of flight data that is customized via construction with delegate classes for the
* actual logic. The server currently uses GRPC as its transport mechanism.
*/
public class FlightServer implements AutoCloseable {
// Class logger; close() uses it to report shutdown progress and failures.
private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(FlightServer.class);
// The location this server was built with. getLocation() substitutes the real port when the
// configured port is 0.
private final Location location;
// The underlying gRPC server; the lifecycle methods of this class delegate to it.
private final Server server;
// The executor used by the gRPC server. We don't use it here, but we do need to clean it up with
// the server.
// May be null, if a user-supplied executor was provided (as we do not want to clean that up)
@VisibleForTesting final ExecutorService grpcExecutor;
/** The maximum size of an individual gRPC message. This effectively disables the limit. */
static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;
/** The default number of bytes that can be queued on an output stream before blocking. */
public static final int DEFAULT_BACKPRESSURE_THRESHOLD = 10 * 1024 * 1024; // 10MB
/**
 * Wrap an already configured gRPC server. For internal use only; instances are created by
 * {@link Builder#build()}.
 *
 * @param location the location the server was configured with
 * @param server the underlying gRPC server
 * @param grpcExecutor the executor to shut down together with the server, or {@code null} when
 *     the executor is not owned by this instance
 */
private FlightServer(Location location, Server server, ExecutorService grpcExecutor) {
  this.grpcExecutor = grpcExecutor;
  this.server = server;
  this.location = location;
}
/**
 * Start the underlying gRPC server.
 *
 * @return this server, so that calls can be chained
 * @throws IOException propagated from the underlying gRPC server's {@code start()}
 */
public FlightServer start() throws IOException {
  this.server.start();
  return this;
}
/**
 * Get the port the server is running on (if applicable).
 *
 * @return the port reported by the underlying gRPC server
 */
public int getPort() {
  final int port = this.server.getPort();
  return port;
}
/**
 * Get the location for this server.
 *
 * <p>When the configured location uses port 0, the returned location carries the port reported
 * by {@link #getPort()} instead; every other URI component is copied unchanged.
 *
 * @return the configured location, or a copy of it with the real port filled in
 */
public Location getLocation() {
  final URI configured = location.getUri();
  if (configured.getPort() != 0) {
    return location;
  }
  // The server was bound to port 0, so replace the port in the location with the real port.
  try {
    final URI resolved =
        new URI(
            configured.getScheme(),
            configured.getUserInfo(),
            configured.getHost(),
            getPort(),
            configured.getPath(),
            configured.getQuery(),
            configured.getFragment());
    return new Location(resolved);
  } catch (URISyntaxException e) {
    // We don't expect this to happen
    throw new RuntimeException(e);
  }
}
/**
 * Block until the underlying gRPC server has terminated.
 *
 * @throws InterruptedException if the calling thread is interrupted while waiting
 */
public void awaitTermination() throws InterruptedException {
  this.server.awaitTermination();
}
/**
 * Request that the server shut down. The executor is shut down as well, but only when this
 * instance holds one (it is {@code null} for user-supplied executors).
 */
public void shutdown() {
  server.shutdown();
  if (grpcExecutor == null) {
    // Nothing of ours to clean up.
    return;
  }
  grpcExecutor.shutdown();
}
/**
 * Wait for the server to shut down, giving up after the given timeout.
 *
 * @param timeout the maximum time to wait
 * @param unit the unit of {@code timeout}
 * @return true if the server shut down successfully.
 * @throws InterruptedException if the calling thread is interrupted while waiting
 */
public boolean awaitTermination(final long timeout, final TimeUnit unit)
    throws InterruptedException {
  final boolean terminated = this.server.awaitTermination(timeout, unit);
  return terminated;
}
/**
 * Shut the server down, waiting for up to 6 seconds in total before returning.
 *
 * <p>A graceful shutdown is requested first and given 3 seconds. If the server is still running
 * after that, a forced shutdown is issued and the server is polled every 100 ms, at most 30
 * times. A warning is logged when the server has not terminated by then.
 *
 * @throws InterruptedException if the calling thread is interrupted while waiting
 */
@Override
public void close() throws InterruptedException {
  shutdown();
  if (awaitTermination(3000, TimeUnit.MILLISECONDS)) {
    logger.debug("Server was terminated within 3s");
    return;
  }

  // The graceful attempt timed out, so get more aggressive in termination.
  server.shutdownNow();
  for (int attempt = 0; !server.isTerminated() && attempt < 30; attempt++) {
    logger.debug("Waiting for termination");
    Thread.sleep(100);
  }

  if (server.isTerminated()) {
    return;
  }
  logger.warn("Couldn't shutdown server, resources likely will be leaked.");
}
/**
 * Create an empty builder for a Flight server. The allocator, location, and producer are not set
 * on the returned builder.
 *
 * @return a new builder
 */
public static Builder builder() {
  final Builder emptyBuilder = new Builder();
  return emptyBuilder;
}
/**
 * Create a builder for a Flight server with its allocator, location, and producer already set.
 *
 * @param allocator the allocator handed to the Flight service
 * @param location the location the server is configured for
 * @param producer the producer handed to the Flight service
 * @return a new builder
 * @throws NullPointerException if any argument is {@code null}
 */
public static Builder builder(
    BufferAllocator allocator, Location location, FlightProducer producer) {
  final Builder configured = new Builder(allocator, location, producer);
  return configured;
}
/** A builder for Flight servers. */
public static final class Builder {
// Passed to the FlightBindingService created in build().
private BufferAllocator allocator;
// Its URI scheme selects the transport in build(); it is also handed to the built server.
private Location location;
// Passed to the FlightBindingService created in build().
private FlightProducer producer;
// Transport hints recorded by transportHint() and read by build().
private final Map<String, Object> builderOptions;
private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
// build() only registers the header-auth middleware when this is not the NO_OP instance.
private CallHeaderAuthenticator headerAuthenticator = CallHeaderAuthenticator.NO_OP;
// User-supplied executor; when null, build() creates a default executor.
private ExecutorService executor = null;
private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE;
private int backpressureThreshold = DEFAULT_BACKPRESSURE_THRESHOLD;
// TLS material. These streams are closed when they are replaced and after build() has attempted
// to create the SSL context.
private InputStream certChain;
private InputStream key;
private InputStream mTlsCACert;
// Assigned by build() when a certificate chain is present.
private SslContext sslContext;
// Middleware factories in registration order; handed to ServerInterceptorAdapter in build().
private final List<KeyFactory<?>> interceptors;
// Keep track of inserted interceptors
private final Set<String> interceptorKeys;
/** Create a builder with no transport hints and no middleware registered. */
Builder() {
  this.interceptorKeys = new HashSet<>();
  this.interceptors = new ArrayList<>();
  this.builderOptions = new HashMap<>();
}
Builder(BufferAllocator allocator, Location location, FlightProducer producer) {
this();
this.allocator = Preconditions.checkNotNull(allocator);
this.location = Preconditions.checkNotNull(location);
this.producer = Preconditions.checkNotNull(producer);
}
/**
 * Create the server for this builder.
 *
 * <p>Registers the header-authentication middleware (only when a non-default authenticator was
 * configured) and the header middleware, selects a Netty transport from the scheme of the
 * location's URI, applies TLS when a certificate chain was supplied, and wires the Flight
 * service into the gRPC server. The certificate, key, and CA streams are closed once the attempt
 * to build the SSL context has finished, whether or not it succeeded. This method does not start
 * the server.
 *
 * @return a new server for the configured location
 * @throws NullPointerException if the allocator, location, or producer has not been set
 * @throws IllegalArgumentException if the location's scheme is not supported, if the TLS scheme
 *     is used without a certificate chain, or if a middleware key is already registered
 * @throws UnsupportedOperationException if a domain socket location is used but no Netty native
 *     transport could be loaded; the underlying reflective failure is attached as the cause
 */
public FlightServer build() {
  // A builder created without arguments may be missing required collaborators. Fail fast with a
  // descriptive message before any builder state is modified.
  Preconditions.checkNotNull(allocator, "allocator must be set before calling build()");
  Preconditions.checkNotNull(location, "location must be set before calling build()");
  Preconditions.checkNotNull(producer, "producer must be set before calling build()");

  // Add the auth middleware if applicable.
  if (headerAuthenticator != CallHeaderAuthenticator.NO_OP) {
    this.middleware(
        FlightServerMiddleware.Key.of(Auth2Constants.AUTHORIZATION_HEADER),
        new ServerCallHeaderAuthMiddleware.Factory(headerAuthenticator));
  }
  this.middleware(FlightConstants.HEADER_KEY, new ServerHeaderMiddleware.Factory());

  final NettyServerBuilder builder;
  switch (location.getUri().getScheme()) {
    case LocationSchemes.GRPC_DOMAIN_SOCKET:
      {
        // The implementation is platform-specific, so we have to find the classes at runtime
        builder = NettyServerBuilder.forAddress(location.toSocketAddress());
        try {
          try {
            // Linux
            useNativeDomainSocketTransport(
                builder,
                "io.netty.channel.epoll.EpollServerDomainSocketChannel",
                "io.netty.channel.epoll.EpollEventLoopGroup");
          } catch (ClassNotFoundException e) {
            // BSD
            useNativeDomainSocketTransport(
                builder,
                "io.netty.channel.kqueue.KQueueServerDomainSocketChannel",
                "io.netty.channel.kqueue.KQueueEventLoopGroup");
          }
        } catch (ClassNotFoundException
            | InstantiationException
            | IllegalAccessException
            | NoSuchMethodException
            | InvocationTargetException e) {
          // Keep the reflective failure as the cause so the real reason is not lost.
          throw new UnsupportedOperationException(
              "Could not find suitable Netty native transport implementation for domain socket address.",
              e);
        }
        break;
      }
    case LocationSchemes.GRPC:
    case LocationSchemes.GRPC_INSECURE:
      {
        builder = NettyServerBuilder.forAddress(location.toSocketAddress());
        break;
      }
    case LocationSchemes.GRPC_TLS:
      {
        if (certChain == null) {
          throw new IllegalArgumentException(
              "Must provide a certificate and key to serve gRPC over TLS");
        }
        builder = NettyServerBuilder.forAddress(location.toSocketAddress());
        break;
      }
    default:
      throw new IllegalArgumentException(
          "Scheme is not supported: " + location.getUri().getScheme());
  }

  if (certChain != null) {
    SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(certChain, key);
    if (mTlsCACert != null) {
      sslContextBuilder.clientAuth(ClientAuth.REQUIRE).trustManager(mTlsCACert);
    }
    try {
      sslContext = sslContextBuilder.build();
    } catch (SSLException e) {
      throw new RuntimeException(e);
    } finally {
      // The streams are no longer needed once the context build has been attempted.
      closeMTlsCACert();
      closeCertChain();
      closeKey();
    }
    builder.sslContext(sslContext);
  }

  // Share one executor between the gRPC service, DoPut, and Handshake
  final ExecutorService exec;
  // We only want to have FlightServer close the gRPC executor if we created it here. We should
  // not close user-supplied executors.
  final ExecutorService grpcExecutor;
  if (executor != null) {
    exec = executor;
    grpcExecutor = null;
  } else {
    exec =
        Executors.newCachedThreadPool(
            // Name threads for better debuggability
            new ThreadFactoryBuilder()
                .setNameFormat("flight-server-default-executor-%d")
                .build());
    grpcExecutor = exec;
  }

  final FlightBindingService flightService =
      new FlightBindingService(allocator, producer, authHandler, exec);
  builder
      .executor(exec)
      .maxInboundMessageSize(maxInboundMessageSize)
      .addService(
          ServerInterceptors.intercept(
              flightService,
              new ServerBackpressureThresholdInterceptor(backpressureThreshold),
              new ServerAuthInterceptor(authHandler)));

  // Allow hooking into the gRPC builder. This is not guaranteed to be available on all Arrow
  // versions or Flight implementations.
  // Each remapping function below returns null, which makes Map.computeIfPresent remove the
  // entry after it has been applied.
  builderOptions.computeIfPresent(
      "grpc.builderConsumer",
      (hint, builderConsumer) -> {
        @SuppressWarnings("unchecked") // The hint is documented to carry a builder consumer.
        final Consumer<NettyServerBuilder> consumer =
            (Consumer<NettyServerBuilder>) builderConsumer;
        consumer.accept(builder);
        return null;
      });

  // Allow explicitly setting some Netty-specific options
  builderOptions.computeIfPresent(
      "netty.channelType",
      (hint, channelType) -> {
        @SuppressWarnings("unchecked") // The hint is expected to carry a ServerChannel class.
        final Class<? extends ServerChannel> serverChannelType =
            (Class<? extends ServerChannel>) channelType;
        builder.channelType(serverChannelType);
        return null;
      });
  builderOptions.computeIfPresent(
      "netty.bossEventLoopGroup",
      (hint, elg) -> {
        builder.bossEventLoopGroup((EventLoopGroup) elg);
        return null;
      });
  builderOptions.computeIfPresent(
      "netty.workerEventLoopGroup",
      (hint, elg) -> {
        builder.workerEventLoopGroup((EventLoopGroup) elg);
        return null;
      });

  builder.intercept(new ServerInterceptorAdapter(interceptors));
  return new FlightServer(location, builder.build(), grpcExecutor);
}

/**
 * Configure the builder with a Netty native transport whose classes are looked up by name at
 * runtime. One event loop group instance is used for both the boss and worker roles.
 */
private static void useNativeDomainSocketTransport(
    NettyServerBuilder builder, String channelClassName, String eventLoopGroupClassName)
    throws ClassNotFoundException,
        InstantiationException,
        IllegalAccessException,
        NoSuchMethodException,
        InvocationTargetException {
  builder.channelType(Class.forName(channelClassName).asSubclass(ServerChannel.class));
  final EventLoopGroup elg =
      Class.forName(eventLoopGroupClassName)
          .asSubclass(EventLoopGroup.class)
          .getConstructor()
          .newInstance();
  builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg);
}
/**
 * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying
 * transport.
 *
 * @param maxMessageSize the maximum inbound message size passed to the gRPC builder
 * @return this builder
 */
public Builder maxInboundMessageSize(int maxMessageSize) {
  final int requestedLimit = maxMessageSize;
  this.maxInboundMessageSize = requestedLimit;
  return this;
}
/**
 * Set the number of bytes that may be queued on a server output stream before writes are
 * blocked.
 *
 * @param backpressureThreshold the threshold in bytes; must be greater than zero
 * @return this builder
 * @throws IllegalArgumentException if {@code backpressureThreshold} is not positive
 */
public Builder backpressureThreshold(int backpressureThreshold) {
  final boolean positive = backpressureThreshold > 0;
  Preconditions.checkArgument(positive);
  this.backpressureThreshold = backpressureThreshold;
  return this;
}
/**
 * A small utility function to ensure that InputStream attributes are closed if they are not
 * null. An {@link IOException} raised while closing is deliberately ignored.
 *
 * @param stream The InputStream to close (if it is not null).
 */
private void closeInputStreamIfNotNull(InputStream stream) {
  if (stream == null) {
    return;
  }
  try {
    stream.close();
  } catch (IOException ignored) {
    // stream closes gracefully, doesn't expect an exception.
  }
}
/** Close the current certificate chain stream, if there is one, and clear the attribute. */
private void closeCertChain() {
  final InputStream stale = this.certChain;
  closeInputStreamIfNotNull(stale);
  this.certChain = null;
}
/** Close the current private key stream, if there is one, and clear the attribute. */
private void closeKey() {
  final InputStream stale = this.key;
  closeInputStreamIfNotNull(stale);
  this.key = null;
}
/** Close the current mTLS CA certificate stream, if there is one, and clear the attribute. */
private void closeMTlsCACert() {
  final InputStream stale = this.mTlsCACert;
  closeInputStreamIfNotNull(stale);
  this.mTlsCACert = null;
}
/**
 * Enable TLS on the server.
 *
 * <p>Both files are opened before any existing TLS configuration is touched. If either file
 * cannot be opened, no stream opened by this call is left open and the previously configured
 * certificate chain and key (if any) are kept.
 *
 * @param certChain The certificate chain to use.
 * @param key The private key to use.
 * @return this builder
 * @throws IOException if either file cannot be opened for reading
 */
public Builder useTls(final File certChain, final File key) throws IOException {
  final InputStream certChainStream = new FileInputStream(certChain);
  final InputStream keyStream;
  try {
    keyStream = new FileInputStream(key);
  } catch (IOException | RuntimeException e) {
    // Do not leak the certificate stream when the key file cannot be opened.
    closeInputStreamIfNotNull(certChainStream);
    throw e;
  }
  // Both files opened successfully, so the previous streams can now be released and replaced.
  closeCertChain();
  this.certChain = certChainStream;
  closeKey();
  this.key = keyStream;
  return this;
}
/**
 * Enable Client Verification via mTLS on the server. A previously configured CA certificate
 * stream is closed before the file is opened.
 *
 * @param mTlsCACert The CA certificate to use for verifying clients.
 * @return this builder
 * @throws IOException if the file cannot be opened for reading
 */
public Builder useMTlsClientVerification(final File mTlsCACert) throws IOException {
  closeMTlsCACert();
  final InputStream caCertStream = new FileInputStream(mTlsCACert);
  this.mTlsCACert = caCertStream;
  return this;
}
/**
 * Enable TLS on the server using caller-supplied streams. Streams set by an earlier call are
 * closed before being replaced.
 *
 * @param certChain The certificate chain to use.
 * @param key The private key to use.
 * @return this builder
 * @throws IOException declared for callers; nothing in this method's body raises it
 */
public Builder useTls(final InputStream certChain, final InputStream key) throws IOException {
  // Release the previous certificate chain, then adopt the new one.
  this.closeCertChain();
  this.certChain = certChain;

  // Same for the private key.
  this.closeKey();
  this.key = key;

  return this;
}
/**
 * Enable mTLS on the server using a caller-supplied stream. A stream set by an earlier call is
 * closed before being replaced.
 *
 * @param mTlsCACert The CA certificate to use for verifying clients.
 * @return this builder
 * @throws IOException declared for callers; nothing in this method's body raises it
 */
public Builder useMTlsClientVerification(final InputStream mTlsCACert) throws IOException {
  // Release the previous CA certificate, then adopt the new one.
  this.closeMTlsCACert();
  this.mTlsCACert = mTlsCACert;

  return this;
}
/**
 * Set the executor used by the server.
 *
 * <p>Flight will NOT take ownership of the executor. The application must clean it up if one is
 * provided. (If not provided, Flight will use a default executor which it will clean up.)
 *
 * @param executor the executor to use, or {@code null} to fall back to the default executor
 * @return this builder
 */
public Builder executor(ExecutorService executor) {
  final ExecutorService supplied = executor;
  this.executor = supplied;
  return this;
}
/**
 * Set the authentication handler.
 *
 * @param authHandler the handler passed to the Flight service and the auth interceptor
 * @return this builder
 */
public Builder authHandler(ServerAuthHandler authHandler) {
  final ServerAuthHandler handler = authHandler;
  this.authHandler = handler;
  return this;
}
/**
 * Set the header-based authentication mechanism.
 *
 * @param headerAuthenticator the authenticator wrapped in a middleware when the server is built
 * @return this builder
 */
public Builder headerAuthenticator(CallHeaderAuthenticator headerAuthenticator) {
  final CallHeaderAuthenticator authenticator = headerAuthenticator;
  this.headerAuthenticator = authenticator;
  return this;
}
/**
 * Provide a transport-specific option. Not guaranteed to have any effect.
 *
 * <p>The keys read when the server is built are {@code grpc.builderConsumer}, {@code
 * netty.channelType}, {@code netty.bossEventLoopGroup}, and {@code netty.workerEventLoopGroup}.
 *
 * @param key the name of the option
 * @param option the value of the option
 * @return this builder
 */
public Builder transportHint(final String key, Object option) {
  this.builderOptions.put(key, option);
  return this;
}
/**
 * Register a Flight middleware component that can inspect and modify requests to this service.
 *
 * @param key An identifier for this middleware component. Service implementations can retrieve
 *     the middleware instance for the current call using {@link
 *     org.apache.arrow.flight.FlightProducer.CallContext}.
 * @param factory A factory for the middleware.
 * @param <T> The middleware type.
 * @return this builder
 * @throws IllegalArgumentException if the key already exists
 */
public <T extends FlightServerMiddleware> Builder middleware(
    final FlightServerMiddleware.Key<T> key, final FlightServerMiddleware.Factory<T> factory) {
  // Set.add reports whether the key was new, which doubles as the duplicate check.
  final boolean newlyRegistered = interceptorKeys.add(key.key);
  if (!newlyRegistered) {
    throw new IllegalArgumentException("Key already exists: " + key.key);
  }
  interceptors.add(new KeyFactory<>(key, factory));
  return this;
}
/**
 * Set the allocator handed to the Flight service.
 *
 * @param allocator the allocator to use
 * @return this builder
 * @throws NullPointerException if {@code allocator} is {@code null}
 */
public Builder allocator(BufferAllocator allocator) {
  final BufferAllocator checked = Preconditions.checkNotNull(allocator);
  this.allocator = checked;
  return this;
}
/**
 * Set the location the server is configured for.
 *
 * @param location the location to use
 * @return this builder
 * @throws NullPointerException if {@code location} is {@code null}
 */
public Builder location(Location location) {
  final Location checked = Preconditions.checkNotNull(location);
  this.location = checked;
  return this;
}
/**
 * Set the producer handed to the Flight service.
 *
 * @param producer the producer to use
 * @return this builder
 * @throws NullPointerException if {@code producer} is {@code null}
 */
public Builder producer(FlightProducer producer) {
  final FlightProducer checked = Preconditions.checkNotNull(producer);
  this.producer = checked;
  return this;
}
}
}