// FlightService.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import com.google.common.base.Strings;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
import org.apache.arrow.flight.auth.AuthConstants;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthWrapper;
import org.apache.arrow.flight.auth2.Auth2Constants;
import org.apache.arrow.flight.grpc.ContextPropagatingExecutorService;
import org.apache.arrow.flight.grpc.RequestContextAdapter;
import org.apache.arrow.flight.grpc.ServerInterceptorAdapter;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** GRPC service implementation for a flight server. */
class FlightService extends FlightServiceImplBase {
// Class-wide SLF4J logger; also used by the nested GetListener to log client cancellations.
private static final Logger logger = LoggerFactory.getLogger(FlightService.class);
// Passed as the second constructor argument of every FlightStream created by doPutCustom and
// doExchangeCustom.
// NOTE(review): presumably the number of inbound messages that may be outstanding at once -
// confirm against FlightStream.
private static final int PENDING_REQUESTS = 5;
// Allocator handed to the FlightStream instances that receive client uploads.
private final BufferAllocator allocator;
// The user-supplied implementation that every RPC in this service delegates to.
private final FlightProducer producer;
// Legacy (non-auth2) authentication handler; only used by handshake().
private final ServerAuthHandler authHandler;
// Executor used to run handshake, acceptPut and doExchange work; the constructor wraps the
// caller's executor in a ContextPropagatingExecutorService.
private final ExecutorService executors;
/**
 * Creates the gRPC service that bridges incoming Flight calls to a {@link FlightProducer}.
 *
 * @param allocator allocator used for the streams that receive client uploads
 * @param producer the implementation every RPC is delegated to
 * @param authHandler handler used by the legacy handshake RPC
 * @param executors executor for background work; it is wrapped in a {@link
 *     ContextPropagatingExecutorService} before being stored
 */
FlightService(
    BufferAllocator allocator,
    FlightProducer producer,
    ServerAuthHandler authHandler,
    ExecutorService executors) {
  // Wrap the executor first; the remaining assignments are plain field captures.
  this.executors = new ContextPropagatingExecutorService(executors);
  this.authHandler = authHandler;
  this.producer = producer;
  this.allocator = allocator;
}
/**
 * Builds the {@link CallContext} handed to the producer for the current call.
 *
 * <p>The peer identity is looked up in the request context first (auth2 interfaces); when that
 * yields null or an empty string, the legacy auth context key is consulted instead.
 *
 * @param responseObserver the call's response observer, used to answer cancellation queries
 * @return a context carrying the resolved peer identity and a cancellation probe
 */
private CallContext makeContext(ServerCallStreamObserver<?> responseObserver) {
  final RequestContext requestContext = RequestContextAdapter.REQUEST_CONTEXT_KEY.get();
  // Prefer the identity published by middleware (auth2 interfaces), if a context is present.
  String identity =
      requestContext == null ? null : requestContext.get(Auth2Constants.PEER_IDENTITY_KEY);
  if (Strings.isNullOrEmpty(identity)) {
    // Fall back to the legacy auth interface, which defaults to empty string.
    identity = AuthConstants.PEER_IDENTITY_KEY.get();
  }
  return new CallContext(identity, responseObserver::isCancelled);
}
/**
 * Handles the legacy handshake RPC by delegating to {@link ServerAuthWrapper}.
 *
 * <p>This method is not meaningful with the auth2 interfaces: with those classes,
 * authentication has already happened via headers/middleware before this point.
 *
 * @param responseObserver observer for handshake responses sent to the client
 * @return the observer that receives the client's handshake requests
 */
@Override
public StreamObserver<Flight.HandshakeRequest> handshake(
    StreamObserver<Flight.HandshakeResponse> responseObserver) {
  final StreamObserver<Flight.HandshakeRequest> requestObserver =
      ServerAuthWrapper.wrapHandshake(authHandler, responseObserver, executors);
  return requestObserver;
}
/**
 * Lists available flights by delegating to {@link FlightProducer#listFlights}.
 *
 * <p>Results are converted with {@code FlightInfo::toProtocol} on their way to the client. Any
 * exception thrown synchronously while building the context or invoking the producer is
 * forwarded to the stream's {@code onError}.
 *
 * @param criteria the protocol-level filter criteria sent by the client
 * @param responseObserver the gRPC observer that receives the converted results
 */
@Override
public void listFlights(
    Flight.Criteria criteria, StreamObserver<Flight.FlightInfo> responseObserver) {
  final StreamPipe<FlightInfo, Flight.FlightInfo> pipe =
      StreamPipe.wrap(
          responseObserver, FlightInfo::toProtocol, this::handleExceptionWithMiddleware);
  try {
    producer.listFlights(
        makeContext((ServerCallStreamObserver<?>) responseObserver),
        new Criteria(criteria),
        pipe);
  } catch (Exception e) {
    pipe.onError(e);
  }
  // The stream is deliberately NOT completed here: the FlightProducer implementation may be
  // asynchronous and is responsible for finishing it.
}
/**
 * Serves a DoGet call by delegating to {@link FlightProducer#getStream}.
 *
 * <p>An exception thrown synchronously by the producer (or while building the call context) is
 * reported through the listener's {@code error} method.
 *
 * @param ticket the protocol-level ticket identifying the stream to fetch
 * @param responseObserverSimple the response observer; it must be a {@link
 *     ServerCallStreamObserver}, otherwise the cast fails with a {@link ClassCastException}
 */
public void doGetCustom(
    Flight.Ticket ticket, StreamObserver<ArrowMessage> responseObserverSimple) {
  final ServerCallStreamObserver<ArrowMessage> responseObserver =
      (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple;
  final GetListener streamListener =
      new GetListener(responseObserver, this::handleExceptionWithMiddleware);
  try {
    final CallContext callContext = makeContext(responseObserver);
    final Ticket wrappedTicket = new Ticket(ticket);
    producer.getStream(callContext, wrappedTicket, streamListener);
  } catch (Exception e) {
    streamListener.error(e);
  }
  // GetListener#completed is deliberately NOT called here: getStream may be asynchronous and
  // is responsible for finishing the stream.
}
/**
 * Executes an action by delegating to {@link FlightProducer#doAction}.
 *
 * <p>Results are converted with {@code Result::toProtocol} on their way to the client. Any
 * exception thrown synchronously while building the context or invoking the producer is
 * forwarded to the stream's {@code onError}.
 *
 * @param request the protocol-level action requested by the client
 * @param responseObserver the gRPC observer that receives the converted results
 */
@Override
public void doAction(Flight.Action request, StreamObserver<Flight.Result> responseObserver) {
  final StreamPipe<Result, Flight.Result> pipe =
      StreamPipe.wrap(responseObserver, Result::toProtocol, this::handleExceptionWithMiddleware);
  try {
    producer.doAction(
        makeContext((ServerCallStreamObserver<?>) responseObserver), new Action(request), pipe);
  } catch (Exception e) {
    pipe.onError(e);
  }
  // The stream is deliberately NOT completed here: the FlightProducer implementation may be
  // asynchronous and is responsible for finishing it.
}
/**
 * Lists the supported action types by delegating to {@link FlightProducer#listActions}.
 *
 * <p>Results are converted with {@code ActionType::toProtocol} on their way to the client. Any
 * exception thrown synchronously while building the context or invoking the producer is
 * forwarded to the stream's {@code onError}.
 *
 * @param request the (empty) protocol request; it is not inspected
 * @param responseObserver the gRPC observer that receives the converted action types
 */
@Override
public void listActions(
    Flight.Empty request, StreamObserver<Flight.ActionType> responseObserver) {
  final StreamPipe<org.apache.arrow.flight.ActionType, Flight.ActionType> pipe =
      StreamPipe.wrap(
          responseObserver, ActionType::toProtocol, this::handleExceptionWithMiddleware);
  try {
    producer.listActions(makeContext((ServerCallStreamObserver<?>) responseObserver), pipe);
  } catch (Exception e) {
    pipe.onError(e);
  }
  // The stream is deliberately NOT completed here: the FlightProducer implementation may be
  // asynchronous and is responsible for finishing it.
}
/**
 * Server-side stream listener handed to the producer for outbound data.
 *
 * <p>It relays gRPC cancel/ready notifications to optional user-installed handlers and makes
 * sure the underlying call is finished at most once: a second {@link #error} or {@link
 * #completed} is reported to the supplied error handler instead of the stream.
 */
private static class GetListener extends OutboundStreamListenerImpl
    implements ServerStreamListener {
  private final ServerCallStreamObserver<ArrowMessage> serverCallResponseObserver;
  private final Consumer<Throwable> errorHandler;
  private Runnable onCancelHandler = null;
  private Runnable onReadyHandler = null;
  private boolean completed;

  /**
   * Wires this listener into the given observer: installs the cancel and ready callbacks and
   * disables automatic inbound flow control.
   *
   * @param responseObserver the gRPC observer that outbound messages are written to
   * @param errorHandler receives errors raised after the call has already been finished
   */
  public GetListener(
      ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) {
    super(null, responseObserver);
    this.serverCallResponseObserver = responseObserver;
    this.errorHandler = errorHandler;
    this.completed = false;
    responseObserver.setOnCancelHandler(this::onCancel);
    responseObserver.setOnReadyHandler(this::onReady);
    responseObserver.disableAutoInboundFlowControl();
  }

  /** Logs the cancellation, then runs the user's cancel handler if one was installed. */
  private void onCancel() {
    logger.debug("Stream cancelled by client.");
    if (onCancelHandler == null) {
      return;
    }
    onCancelHandler.run();
  }

  /** Runs the user's ready handler if one was installed. */
  private void onReady() {
    if (onReadyHandler == null) {
      return;
    }
    onReadyHandler.run();
  }

  @Override
  public void setOnCancelHandler(Runnable handler) {
    this.onCancelHandler = handler;
  }

  @Override
  public void setOnReadyHandler(Runnable handler) {
    this.onReadyHandler = handler;
  }

  @Override
  public boolean isCancelled() {
    return serverCallResponseObserver.isCancelled();
  }

  @Override
  protected void waitUntilStreamReady() {
    // Intentionally empty: service implementations are expected to manage backpressure
    // themselves.
  }

  /**
   * Fails the call with {@code ex}, unless the call was already finished, in which case the
   * exception is handed to the error handler instead.
   */
  @Override
  public void error(Throwable ex) {
    if (completed) {
      errorHandler.accept(ex);
      return;
    }
    completed = true;
    super.error(ex);
  }

  /**
   * Completes the call, unless it was already finished, in which case an {@link
   * IllegalStateException} is handed to the error handler instead.
   */
  @Override
  public void completed() {
    if (completed) {
      errorHandler.accept(new IllegalStateException("Tried to complete already-completed call"));
      return;
    }
    completed = true;
    super.completed();
  }
}
/**
 * Serves a DoPut call: client uploads are fed into a {@link FlightStream} while {@link
 * FlightProducer#acceptPut} runs on the executor and acknowledges through a {@link StreamPipe}.
 *
 * <p>Automatic inbound flow control is disabled and the first message is requested explicitly;
 * further requests are issued by the {@link FlightStream} through {@code
 * responseObserver::request}.
 *
 * <p>If the executor refuses the task (for example because it has been shut down), the failure
 * is reported through the acknowledgement stream instead of escaping from this method. This
 * mirrors the handling in {@code doExchangeCustom}.
 *
 * @param responseObserverSimple the observer for acknowledgements; it must be a {@link
 *     ServerCallStreamObserver}, otherwise the cast fails with a {@link ClassCastException}
 * @return the observer that receives the client's uploaded messages
 */
public StreamObserver<ArrowMessage> doPutCustom(
    final StreamObserver<Flight.PutResult> responseObserverSimple) {
  ServerCallStreamObserver<Flight.PutResult> responseObserver =
      (ServerCallStreamObserver<Flight.PutResult>) responseObserverSimple;
  responseObserver.disableAutoInboundFlowControl();
  responseObserver.request(1);
  final StreamPipe<PutResult, Flight.PutResult> ackStream =
      StreamPipe.wrap(
          responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware);
  final FlightStream fs =
      new FlightStream(
          allocator,
          PENDING_REQUESTS,
          /* server-upload streams are not cancellable */ null,
          responseObserver::request);
  // When the ackStream is completed, the FlightStream will be closed with it
  ackStream.setAutoCloseable(fs);
  final StreamObserver<ArrowMessage> observer = fs.asObserver();
  try {
    Future<?> unused =
        executors.submit(
            () -> {
              try {
                producer.acceptPut(makeContext(responseObserver), fs, ackStream).run();
              } catch (Throwable ex) {
                ackStream.onError(ex);
              } finally {
                // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself
                // We don't do this for other streams since the implementation may be
                // asynchronous
                ackStream.ensureCompleted();
              }
            });
  } catch (Exception ex) {
    // The task never started, so nothing else will finish the ack stream or release the
    // FlightStream attached to it; fail the call explicitly.
    ackStream.onError(ex);
  }
  return observer;
}
/**
 * Answers a GetFlightInfo call with the single result of {@link FlightProducer#getFlightInfo}.
 *
 * <p>Exceptions from building the context or from the producer are converted with {@link
 * StatusUtils#toGrpcException} and sent via {@code onError}. The response is emitted outside
 * the guarded section on purpose, so that a failure in {@code onNext}/{@code onCompleted} never
 * triggers a follow-up {@code onError}.
 *
 * @param request the protocol-level descriptor of the requested flight
 * @param responseObserver the gRPC observer that receives the flight info
 */
@Override
public void getFlightInfo(
    Flight.FlightDescriptor request, StreamObserver<Flight.FlightInfo> responseObserver) {
  final FlightInfo flightInfo;
  try {
    final CallContext callContext = makeContext((ServerCallStreamObserver<?>) responseObserver);
    flightInfo = producer.getFlightInfo(callContext, new FlightDescriptor(request));
  } catch (Exception e) {
    // Only producer-side failures are handled here; onNext/onCompleted stay outside this
    // block because onError could not be called after they fail.
    responseObserver.onError(StatusUtils.toGrpcException(e));
    return;
  }
  responseObserver.onNext(flightInfo.toProtocol());
  responseObserver.onCompleted();
}
/**
 * Answers a PollFlightInfo call with the single result of {@link
 * FlightProducer#pollFlightInfo}.
 *
 * <p>Exceptions from building the context or from the producer are converted with {@link
 * StatusUtils#toGrpcException} and sent via {@code onError}. The response is emitted outside
 * the guarded section on purpose, so that a failure in {@code onNext}/{@code onCompleted} never
 * triggers a follow-up {@code onError}.
 *
 * @param request the protocol-level descriptor of the flight being polled
 * @param responseObserver the gRPC observer that receives the poll info
 */
@Override
public void pollFlightInfo(
    Flight.FlightDescriptor request, StreamObserver<Flight.PollInfo> responseObserver) {
  final PollInfo pollInfo;
  try {
    final CallContext callContext = makeContext((ServerCallStreamObserver<?>) responseObserver);
    pollInfo = producer.pollFlightInfo(callContext, new FlightDescriptor(request));
  } catch (Exception e) {
    // Only producer-side failures are handled here; onNext/onCompleted stay outside this
    // block because onError could not be called after they fail.
    responseObserver.onError(StatusUtils.toGrpcException(e));
    return;
  }
  responseObserver.onNext(pollInfo.toProtocol());
  responseObserver.onCompleted();
}
/**
 * Broadcasts the given exception to all registered middleware.
 *
 * <p>When no middleware map is available in the current context, or it is empty, the exception
 * is logged at error level instead.
 *
 * @param t the exception to report
 */
private void handleExceptionWithMiddleware(Throwable t) {
  final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> registered =
      ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
  if (registered != null && !registered.isEmpty()) {
    for (final FlightServerMiddleware instance : registered.values()) {
      instance.onCallErrored(t);
    }
    return;
  }
  logger.error("Uncaught exception in Flight method body", t);
}
/**
 * Answers a GetSchema call with the single result of {@link FlightProducer#getSchema}.
 *
 * <p>Exceptions from building the context, from the producer, or from converting the result to
 * its protocol form are translated with {@link StatusUtils#toGrpcException} and sent via {@code
 * onError}. The response itself is emitted outside the guarded section, matching {@code
 * getFlightInfo} and {@code pollFlightInfo}: if {@code onNext} or {@code onCompleted} fails,
 * {@code onError} must not be called afterwards.
 *
 * @param request the protocol-level descriptor whose schema is requested
 * @param responseObserver the gRPC observer that receives the schema result
 */
@Override
public void getSchema(
    Flight.FlightDescriptor request, StreamObserver<Flight.SchemaResult> responseObserver) {
  final Flight.SchemaResult response;
  try {
    SchemaResult result =
        producer.getSchema(
            makeContext((ServerCallStreamObserver<?>) responseObserver),
            new FlightDescriptor(request));
    // Conversion stays guarded so serialization problems are still reported to the client.
    response = result.toProtocol();
  } catch (Exception ex) {
    // Don't capture exceptions from onNext or onCompleted with this block - because then we
    // can't call onError
    responseObserver.onError(StatusUtils.toGrpcException(ex));
    return;
  }
  responseObserver.onNext(response);
  responseObserver.onCompleted();
}
/**
 * Listener for DoExchange calls that also releases an attached resource when the call ends,
 * whether through {@link #completed}, {@link #error} or cancellation.
 */
private static class ExchangeListener extends GetListener {
  private AutoCloseable resource = null;
  private boolean closed = false;
  private Runnable onCancelHandler = null;

  /**
   * Creates the listener and installs a cancel callback on the parent that runs the user's
   * cancel handler (if any) and then always performs cleanup.
   *
   * @param responseObserver the gRPC observer that outbound messages are written to
   * @param errorHandler receives errors raised after the call has already been finished
   */
  public ExchangeListener(
      ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) {
    super(responseObserver, errorHandler);
    super.setOnCancelHandler(this::runCancelHandlerThenCleanup);
  }

  /** Runs the user-supplied cancel handler, if present, and cleans up even if it throws. */
  private void runCancelHandlerThenCleanup() {
    try {
      if (onCancelHandler != null) {
        onCancelHandler.run();
      }
    } finally {
      cleanup();
    }
  }

  /**
   * Closes the attached resource at most once.
   *
   * <p>The {@code closed} guard prevents a double-free: gRPC will call the OnCancelHandler even
   * on a normal call end, so without it the resource would be closed twice.
   */
  private void cleanup() {
    if (!closed) {
      closed = true;
      try {
        AutoCloseables.close(resource);
      } catch (Exception e) {
        throw CallStatus.INTERNAL
            .withCause(e)
            .withDescription("Server internal error cleaning up resources")
            .toRuntimeException();
      }
    }
  }

  /** Releases the resource, then fails the call; the call is failed even if cleanup throws. */
  @Override
  public void error(Throwable ex) {
    try {
      cleanup();
    } finally {
      super.error(ex);
    }
  }

  /** Releases the resource, then completes the call; completion runs even if cleanup throws. */
  @Override
  public void completed() {
    try {
      cleanup();
    } finally {
      super.completed();
    }
  }

  /** Stores the handler locally so the cleanup-aware callback on the parent stays in place. */
  @Override
  public void setOnCancelHandler(Runnable handler) {
    this.onCancelHandler = handler;
  }
}
/**
 * Serves a DoExchange call: client messages are fed into a {@link FlightStream} while {@link
 * FlightProducer#doExchange} runs on the executor and writes to an {@link ExchangeListener}.
 *
 * <p>Automatic inbound flow control is disabled and the first message is requested explicitly.
 * Failures, either thrown by the producer or raised when submitting the task, are reported via
 * the listener's {@code error} method.
 *
 * @param responseObserverSimple the response observer; it must be a {@link
 *     ServerCallStreamObserver}, otherwise the cast fails with a {@link ClassCastException}
 * @return the observer that receives the client's messages
 */
public StreamObserver<ArrowMessage> doExchangeCustom(
    StreamObserver<ArrowMessage> responseObserverSimple) {
  final ServerCallStreamObserver<ArrowMessage> responseObserver =
      (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple;
  final ExchangeListener exchangeListener =
      new ExchangeListener(responseObserver, this::handleExceptionWithMiddleware);
  final FlightStream inbound =
      new FlightStream(
          allocator,
          PENDING_REQUESTS,
          /* server-upload streams are not cancellable */ null,
          responseObserver::request);
  // Attach the stream so that finishing the call through the listener also cleans it up.
  exchangeListener.resource = inbound;
  responseObserver.disableAutoInboundFlowControl();
  responseObserver.request(1);
  final StreamObserver<ArrowMessage> requestObserver = inbound.asObserver();

  // Nothing is cleaned up or closed by the task itself, to allow long-running asynchronous
  // implementations. It is the service's responsibility to call completed() or error(), which
  // will then clean up the FlightStream.
  final Runnable exchangeTask =
      () -> {
        try {
          producer.doExchange(makeContext(responseObserver), inbound, exchangeListener);
        } catch (Exception e) {
          exchangeListener.error(e);
        }
      };
  try {
    Future<?> unused = executors.submit(exchangeTask);
  } catch (Exception e) {
    exchangeListener.error(e);
  }
  return requestObserver;
}
/**
 * Call context for the service: exposes the peer identity, a cancellation probe and the
 * middleware registered for the current call.
 */
static class CallContext implements FlightProducer.CallContext {
  private final String peerIdentity;
  private final BooleanSupplier isCancelled;

  /**
   * @param peerIdentity the identity reported by {@link #peerIdentity()}
   * @param isCancelled queried each time {@link #isCancelled()} is called
   */
  CallContext(final String peerIdentity, BooleanSupplier isCancelled) {
    this.isCancelled = isCancelled;
    this.peerIdentity = peerIdentity;
  }

  @Override
  public String peerIdentity() {
    return this.peerIdentity;
  }

  @Override
  public boolean isCancelled() {
    return isCancelled.getAsBoolean();
  }

  /**
   * Looks up a single middleware instance by key.
   *
   * @return the instance registered under {@code key}, or null when no middleware map is
   *     available or the key is not present
   */
  @Override
  public <T extends FlightServerMiddleware> T getMiddleware(FlightServerMiddleware.Key<T> key) {
    final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> registered =
        ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
    final FlightServerMiddleware found = registered == null ? null : registered.get(key);
    // A null lookup result simply stays null through the cast.
    @SuppressWarnings("unchecked")
    final T typed = (T) found;
    return typed;
  }

  /**
   * Returns every middleware registered for the current call.
   *
   * @return the registered map, or an empty map when none is available
   */
  @Override
  public Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> getMiddleware() {
    final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> registered =
        ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
    if (registered != null) {
      // This is an unmodifiable map
      return registered;
    }
    return Collections.emptyMap();
  }
}
}