FlightEndpointDataQueue.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.driver.jdbc.utils;

import static java.lang.String.format;
import static java.util.Collections.synchronizedSet;
import static org.apache.arrow.util.Preconditions.checkNotNull;
import static org.apache.arrow.util.Preconditions.checkState;

import java.sql.SQLException;
import java.sql.SQLTimeoutException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.arrow.driver.jdbc.client.CloseableEndpointStreamPair;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStream;
import org.apache.calcite.avatica.AvaticaConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Auxiliary class used to handle consuming of multiple {@link FlightStream}s, each carried by a
 * {@link CloseableEndpointStreamPair}.
 *
 * <p>The usage follows this routine:
 *
 * <ol>
 *   <li>Create a <code>FlightEndpointDataQueue</code>;
 *   <li>Call <code>enqueue(CloseableEndpointStreamPair)</code> for all endpoints to be consumed;
 *   <li>Call <code>next()</code> to get an endpoint whose stream is ready to consume;
 *   <li>Consume the given stream and add its endpoint back to the queue - call <code>
 *       enqueue(CloseableEndpointStreamPair)</code>;
 *   <li>Repeat from (3) until <code>next()</code> returns null.
 * </ol>
 */
public class FlightEndpointDataQueue implements AutoCloseable {
  private static final Logger LOGGER = LoggerFactory.getLogger(FlightEndpointDataQueue.class);
  private final CompletionService<CloseableEndpointStreamPair> completionService;
  // Futures submitted through enqueue() that have not yet been handed out by next().
  private final Set<Future<CloseableEndpointStreamPair>> futures = synchronizedSet(new HashSet<>());
  // Every endpoint ever enqueued; cancelled and closed by close().
  private final Set<CloseableEndpointStreamPair> endpointsToClose =
      synchronizedSet(new HashSet<>());
  private final AtomicBoolean closed = new AtomicBoolean();

  /**
   * Instantiate a new FlightEndpointDataQueue.
   *
   * @param executorService the completion service used to run and collect the stream-polling
   *     tasks; must not be null.
   */
  protected FlightEndpointDataQueue(
      final CompletionService<CloseableEndpointStreamPair> executorService) {
    completionService = checkNotNull(executorService);
  }

  /**
   * Creates a new {@link FlightEndpointDataQueue} from the provided {@link ExecutorService}.
   *
   * @param service the service from which to create a new queue.
   * @return a new queue.
   */
  public static FlightEndpointDataQueue createNewQueue(final ExecutorService service) {
    return new FlightEndpointDataQueue(new ExecutorCompletionService<>(service));
  }

  /**
   * Gets whether this queue is closed.
   *
   * @return a boolean indicating whether this resource is closed.
   */
  public boolean isClosed() {
    return closed.get();
  }

  /** Auxiliary functional interface for getting ready-to-consume FlightStreams. */
  @FunctionalInterface
  interface EndpointStreamSupplier {
    Future<CloseableEndpointStreamPair> get() throws SQLException;
  }

  /**
   * Pulls completed futures from the given supplier until one yields a non-null endpoint or no
   * tracked futures remain.
   *
   * @param endpointStreamSupplier source of the next completed future.
   * @return the next endpoint whose task produced a result, or null when no tracked futures are
   *     left.
   * @throws SQLException if the supplier fails, or if the task failed, was cancelled or the wait
   *     was interrupted; a {@link FlightRuntimeException} cause is rethrown as-is instead.
   */
  private CloseableEndpointStreamPair next(final EndpointStreamSupplier endpointStreamSupplier)
      throws SQLException {
    checkOpen();
    while (!futures.isEmpty()) {
      final Future<CloseableEndpointStreamPair> future = endpointStreamSupplier.get();
      futures.remove(future);
      try {
        final CloseableEndpointStreamPair endpoint = future.get();
        // Get the next FlightStream that has a root with content; a null result means the task
        // ran its stream to the end without finding one, so keep looking.
        if (endpoint != null) {
          return endpoint;
        }
      } catch (final ExecutionException e) {
        // Unwrap one layer
        final Throwable cause = e.getCause();
        if (cause instanceof FlightRuntimeException) {
          throw (FlightRuntimeException) cause;
        }
        throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
      } catch (final InterruptedException e) {
        // Catching InterruptedException clears the flag; restore it so callers can still observe
        // the interruption after the SQLException is handled.
        Thread.currentThread().interrupt();
        throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
      } catch (final CancellationException e) {
        throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
      }
    }
    return null;
  }

  /**
   * Blocking request with timeout to get the next ready FlightStream in queue.
   *
   * @param timeoutValue the amount of time to be waited
   * @param timeoutUnit the timeoutValue time unit
   * @return a FlightStream that is ready to consume or null if all FlightStreams are ended.
   * @throws SQLTimeoutException if no stream became ready within the timeout, or if the waiting
   *     thread was interrupted (the interrupt status is preserved).
   * @throws SQLException if retrieving the ready stream fails for any other reason.
   */
  public CloseableEndpointStreamPair next(final long timeoutValue, final TimeUnit timeoutUnit)
      throws SQLException {
    return next(
        () -> {
          try {
            final Future<CloseableEndpointStreamPair> future =
                completionService.poll(timeoutValue, timeoutUnit);
            if (future != null) {
              return future;
            }
          } catch (final InterruptedException e) {
            // Keep the interrupt visible to the caller's thread.
            Thread.currentThread().interrupt();
            throw new SQLTimeoutException("Query was interrupted", e);
          }

          // poll() returned null: nothing completed before the timeout elapsed.
          throw new SQLTimeoutException(
              String.format("Query timed out after %d %s", timeoutValue, timeoutUnit));
        });
  }

  /**
   * Blocking request to get the next ready FlightStream in queue.
   *
   * @return a FlightStream that is ready to consume or null if all FlightStreams are ended.
   * @throws SQLException if the wait is interrupted (the interrupt status is preserved) or the
   *     ready stream cannot be retrieved.
   */
  public CloseableEndpointStreamPair next() throws SQLException {
    return next(
        () -> {
          try {
            return completionService.take();
          } catch (final InterruptedException e) {
            // Keep the interrupt visible to the caller's thread.
            Thread.currentThread().interrupt();
            throw AvaticaConnection.HELPER.wrap(e.getMessage(), e);
          }
        });
  }

  /** Checks if this queue is open, failing the {@code checkState} precondition otherwise. */
  public synchronized void checkOpen() {
    checkState(!isClosed(), format("%s closed", this.getClass().getSimpleName()));
  }

  /** Readily adds given {@link FlightStream}s to the queue. */
  public void enqueue(final Collection<CloseableEndpointStreamPair> endpointRequests) {
    endpointRequests.forEach(this::enqueue);
  }

  /**
   * Adds given {@link FlightStream} to the queue.
   *
   * <p>The endpoint is registered for cleanup on {@link #close()} and a task is submitted that
   * advances its stream until a batch with at least one row is available.
   *
   * @param endpointRequest the endpoint whose stream should be polled; must not be null.
   */
  public synchronized void enqueue(final CloseableEndpointStreamPair endpointRequest) {
    checkNotNull(endpointRequest);
    checkOpen();
    endpointsToClose.add(endpointRequest);
    futures.add(
        completionService.submit(
            () -> {
              // `FlightStream#next` will block until new data can be read or stream is over.
              while (endpointRequest.getStream().next()) {
                // Skip empty batches; only hand back the endpoint once there are rows to read.
                if (endpointRequest.getStream().getRoot().getRowCount() > 0) {
                  return endpointRequest;
                }
              }
              // Stream is over and produced no further rows.
              return null;
            }));
  }

  /** Tells whether the given exception was caused by a Flight call with CANCELLED status. */
  private static boolean isCallStatusCancelled(final Exception e) {
    return e.getCause() instanceof FlightRuntimeException
        && ((FlightRuntimeException) e.getCause()).status().code() == CallStatus.CANCELLED.code();
  }

  /** Logs a failure hit during {@link #close()} and collects it for the final report. */
  private static void recordFailure(
      final Set<SQLException> exceptions, final String errorMsg, final Exception cause) {
    LOGGER.error(errorMsg, cause);
    exceptions.add(new SQLException(errorMsg, cause));
  }

  /**
   * Cancels every enqueued stream, waits for the outstanding polling tasks to finish and then
   * closes every endpoint. Calling it on an already closed queue is a no-op.
   *
   * @throws SQLException if any stream failed to cancel or close, or any pending task failed for
   *     a reason other than cancellation; the individual failures are chained as next exceptions.
   */
  @Override
  public synchronized void close() throws SQLException {
    if (isClosed()) {
      return;
    }

    final Set<SQLException> exceptions = new HashSet<>();
    // Interrupt status is restored only after cleanup so it cannot disturb the remaining waits
    // and close calls.
    boolean interrupted = false;
    try {
      for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) {
        try {
          endpointToClose.getStream().cancel("Cancelling this FlightStream.", null);
        } catch (final Exception e) {
          recordFailure(exceptions, "Failed to cancel a FlightStream.", e);
        }
      }
      // Iterating a synchronized set requires holding its lock for the whole traversal.
      synchronized (futures) {
        for (final Future<CloseableEndpointStreamPair> future : futures) {
          try {
            // TODO: Consider adding a hardcoded timeout?
            future.get();
          } catch (final CancellationException ignored) {
            // A cancelled task has nothing left to wait for; it must not abort the cleanup of
            // the remaining futures and endpoints.
          } catch (final InterruptedException e) {
            interrupted = true;
            recordFailure(exceptions, "Failed consuming a future during close.", e);
          } catch (final ExecutionException e) {
            // Ignore if future is already cancelled
            if (!isCallStatusCancelled(e)) {
              recordFailure(exceptions, "Failed consuming a future during close.", e);
            }
          }
        }
      }
      for (final CloseableEndpointStreamPair endpointToClose : endpointsToClose) {
        try {
          endpointToClose.close();
        } catch (final Exception e) {
          recordFailure(exceptions, "Failed to close a FlightStream.", e);
        }
      }
    } finally {
      endpointsToClose.clear();
      futures.clear();
      closed.set(true);
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
    if (!exceptions.isEmpty()) {
      final SQLException sqlException = new SQLException("Failed to close streams.");
      exceptions.forEach(sqlException::setNextException);
      throw sqlException;
    }
  }
}