ServerInterceptorAdapter.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.grpc;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightConstants;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightProducer.CallContext;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.FlightServerMiddleware.Key;

/**
 * An adapter between Flight middleware and a gRPC interceptor.
 *
 * <p>This is implemented as a single gRPC interceptor that runs all Flight server middleware
 * sequentially. Flight middleware instances are stored in the gRPC Context so their state is
 * accessible later.
 *
 * <p>Calls whose service name is not {@link FlightConstants#SERVICE} are passed through without
 * creating any Flight middleware.
 */
public class ServerInterceptorAdapter implements ServerInterceptor {

  /**
   * A combination of a middleware Key and factory.
   *
   * @param <T> The middleware type.
   */
  public static class KeyFactory<T extends FlightServerMiddleware> {

    private final FlightServerMiddleware.Key<T> key;
    private final FlightServerMiddleware.Factory<T> factory;

    /**
     * Pair a middleware key with the factory that produces instances for it.
     *
     * @param key the key under which the created middleware instance is published in the call's
     *     middleware map
     * @param factory the factory whose {@code onCallStarted} is invoked for each intercepted
     *     Flight call
     */
    public KeyFactory(
        FlightServerMiddleware.Key<T> key, FlightServerMiddleware.Factory<T> factory) {
      this.key = key;
      this.factory = factory;
    }
  }

  /**
   * The {@link Context.Key} that stores the Flight middleware active for a particular call.
   *
   * <p>Applications should not use this directly. Instead, see {@link
   * CallContext#getMiddleware(Key)}.
   */
  public static final Context.Key<Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware>>
      SERVER_MIDDLEWARE_KEY = Context.key("arrow.flight.server_middleware");

  private final List<KeyFactory<?>> factories;

  /**
   * Create an interceptor that runs the given middleware factories, in list order, for every
   * Flight call.
   *
   * @param factories the middleware factories; the list is copied, so later modifications made by
   *     the caller do not affect this interceptor
   */
  public ServerInterceptorAdapter(List<KeyFactory<?>> factories) {
    // Defensive copy: this list is iterated on every intercepted call and must not change
    // underneath the interceptor.
    this.factories = Collections.unmodifiableList(new ArrayList<>(factories));
  }

  @Override
  public <ReqT, RespT> Listener<ReqT> interceptCall(
      ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
    if (!FlightConstants.SERVICE.equals(call.getMethodDescriptor().getServiceName())) {
      // Not a Flight RPC: run it without any Flight middleware.
      return Contexts.interceptCall(Context.current(), call, headers, next);
    }

    final CallInfo info =
        new CallInfo(FlightMethod.fromProtocol(call.getMethodDescriptor().getFullMethodName()));
    // Every instance created for this call, in factory order; used to deliver callbacks.
    final List<FlightServerMiddleware> middleware = new ArrayList<>(factories.size());
    // Use LinkedHashMap to preserve insertion order
    final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> middlewareMap =
        new LinkedHashMap<>();
    final MetadataAdapter headerAdapter = new MetadataAdapter(headers);
    final RequestContextAdapter requestContextAdapter = new RequestContextAdapter();
    for (final KeyFactory<?> factory : factories) {
      final FlightServerMiddleware m;
      try {
        m = factory.factory.onCallStarted(info, headerAdapter, requestContextAdapter);
      } catch (FlightRuntimeException e) {
        // A factory rejected the call: cancel it with the factory's status.
        final CallStatus rejection = e.status();
        call.close(StatusUtils.toGrpcStatus(rejection), new Metadata());
        // The call was closed directly rather than through the forwarding call below, so the
        // middleware that were already started would otherwise never receive onCallCompleted.
        notifyCallCompleted(middleware, rejection);
        return new Listener<ReqT>() {};
      }
      middleware.add(m);
      middlewareMap.put(factory.key, m);
    }

    // Inject the middleware into the context so RPC method implementations can communicate with
    // middleware instances
    final Context contextWithMiddlewareAndRequestsOptions =
        Context.current()
            .withValue(SERVER_MIDDLEWARE_KEY, Collections.unmodifiableMap(middlewareMap))
            .withValue(RequestContextAdapter.REQUEST_CONTEXT_KEY, requestContextAdapter);

    final SimpleForwardingServerCall<ReqT, RespT> forwardingServerCall =
        new SimpleForwardingServerCall<ReqT, RespT>(call) {
          private boolean headersSent = false;

          @Override
          public void sendHeaders(Metadata responseHeaders) {
            headersSent = true;
            try {
              notifyBeforeSendingHeaders(middleware, responseHeaders);
            } finally {
              // Make sure to always call the gRPC callback to avoid interrupting the gRPC request
              // cycle
              super.sendHeaders(responseHeaders);
            }
          }

          @Override
          public void close(Status status, Metadata trailers) {
            try {
              if (!headersSent) {
                // gRPC doesn't always send response headers if the call errors or completes
                // immediately, so middleware are given the trailers to write into instead.
                notifyBeforeSendingHeaders(middleware, trailers);
              }
            } finally {
              // Make sure to always call the gRPC callback to avoid interrupting the gRPC request
              // cycle
              super.close(status, trailers);
            }

            notifyCallCompleted(middleware, StatusUtils.fromGrpcStatus(status));
          }
        };
    return Contexts.interceptCall(
        contextWithMiddlewareAndRequestsOptions, forwardingServerCall, headers, next);
  }

  /**
   * Invoke {@code onBeforeSendingHeaders} on each middleware instance, in creation order, with a
   * shared adapter over {@code outgoing}.
   */
  private static void notifyBeforeSendingHeaders(
      List<FlightServerMiddleware> middleware, Metadata outgoing) {
    final MetadataAdapter outgoingAdapter = new MetadataAdapter(outgoing);
    for (final FlightServerMiddleware m : middleware) {
      m.onBeforeSendingHeaders(outgoingAdapter);
    }
  }

  /** Invoke {@code onCallCompleted} on each middleware instance, in creation order. */
  private static void notifyCallCompleted(
      List<FlightServerMiddleware> middleware, CallStatus status) {
    for (final FlightServerMiddleware m : middleware) {
      m.onCallCompleted(status);
    }
  }
}