ClientInterceptorAdapter.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.grpc;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;

/**
 * An adapter between Flight client middleware and gRPC interceptors.
 *
 * <p>This is implemented as a single gRPC interceptor that runs all Flight client middleware
 * sequentially.
 */
public class ClientInterceptorAdapter implements ClientInterceptor {

  private final List<FlightClientMiddleware.Factory> factories;

  public ClientInterceptorAdapter(List<FlightClientMiddleware.Factory> factories) {
    this.factories = factories;
  }

  @Override
  public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
      MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
    final List<FlightClientMiddleware> middleware = new ArrayList<>();
    final CallInfo info = new CallInfo(FlightMethod.fromProtocol(method.getFullMethodName()));

    try {
      for (final FlightClientMiddleware.Factory factory : factories) {
        middleware.add(factory.onCallStarted(info));
      }
    } catch (FlightRuntimeException e) {
      // Explicitly propagate
      throw e;
    } catch (StatusRuntimeException e) {
      throw StatusUtils.fromGrpcRuntimeException(e);
    } catch (RuntimeException e) {
      throw StatusUtils.fromThrowable(e);
    }
    return new FlightClientCall<>(next.newCall(method, callOptions), middleware);
  }

  /**
   * The ClientCallListener which hooks into the gRPC request cycle and actually runs middleware at
   * certain points.
   */
  private static class FlightClientCallListener<RespT>
      extends SimpleForwardingClientCallListener<RespT> {

    private final List<FlightClientMiddleware> middleware;
    boolean receivedHeaders;

    public FlightClientCallListener(
        ClientCall.Listener<RespT> responseListener, List<FlightClientMiddleware> middleware) {
      super(responseListener);
      this.middleware = middleware;
      receivedHeaders = false;
    }

    @Override
    public void onHeaders(Metadata headers) {
      receivedHeaders = true;
      final MetadataAdapter adapter = new MetadataAdapter(headers);
      try {
        middleware.forEach(m -> m.onHeadersReceived(adapter));
      } finally {
        // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
        super.onHeaders(headers);
      }
    }

    @Override
    public void onClose(Status status, Metadata trailers) {
      try {
        if (!receivedHeaders) {
          // gRPC doesn't always send response headers if the call errors or completes immediately,
          // but instead
          // consolidates them with the trailers. If we never got headers, assume this happened and
          // run the header
          // callback with the trailers.
          final MetadataAdapter adapter = new MetadataAdapter(trailers);
          middleware.forEach(m -> m.onHeadersReceived(adapter));
        }
        final CallStatus flightStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers);
        middleware.forEach(m -> m.onCallCompleted(flightStatus));
      } finally {
        // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
        super.onClose(status, trailers);
      }
    }
  }

  /**
   * The gRPC ClientCall which hooks into the gRPC request cycle and injects our ClientCallListener.
   */
  private static class FlightClientCall<ReqT, RespT>
      extends SimpleForwardingClientCall<ReqT, RespT> {

    private final List<FlightClientMiddleware> middleware;

    public FlightClientCall(
        ClientCall<ReqT, RespT> clientCall, List<FlightClientMiddleware> middleware) {
      super(clientCall);
      this.middleware = middleware;
    }

    @Override
    public void start(Listener<RespT> responseListener, Metadata headers) {
      final MetadataAdapter metadataAdapter = new MetadataAdapter(headers);
      middleware.forEach(m -> m.onBeforeSendingHeaders(metadataAdapter));

      super.start(new FlightClientCallListener<>(responseListener, middleware), headers);
    }

    @Override
    public void cancel(String message, Throwable cause) {
      final CallStatus flightStatus =
          new CallStatus(FlightStatusCode.CANCELLED, cause, message, null);
      middleware.forEach(m -> m.onCallCompleted(flightStatus));
      super.cancel(message, cause);
    }
  }
}