// FlightStream.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.arrow.flight.ArrowMessage.HeaderType;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
/** An adaptor between protobuf streams and flight data streams. */
public class FlightStream implements AutoCloseable {
// Sentinel objects placed on the queue to mark the end of the stream. They are AutoCloseable
// (with no-op close methods) so that #close can treat every queue entry uniformly.
// DONE is enqueued by the observer's onCompleted(); next() returns false when it sees it.
private final AutoCloseable DONE =
new AutoCloseable() {
@Override
public void close() throws Exception {}
};
// DONE_EX is enqueued when the stream ends with an error; next() then rethrows the error
// stored in the ex field.
private final AutoCloseable DONE_EX =
new AutoCloseable() {
@Override
public void close() throws Exception {}
};
// Allocator used to create the root and the dictionary vectors when the schema arrives.
private final BufferAllocator allocator;
// Hook used to cancel the call. May be null, in which case cancel() is unsupported and close()
// does not cancel or drain.
private final Cancellable cancellable;
// Messages (and the DONE/DONE_EX sentinels) added by the observer and consumed by next().
private final LinkedBlockingQueue<AutoCloseable> queue = new LinkedBlockingQueue<>();
// Completed with the root once the schema message is processed, or failed by onError().
private final SettableFuture<VectorSchemaRoot> root = SettableFuture.create();
// Completed with the descriptor carried by an incoming message, if any.
private final SettableFuture<FlightDescriptor> descriptor = SettableFuture.create();
// Level that the pending counter is topped back up to by requestOutstanding().
private final int pendingTarget;
// Callback used to request more messages from the producer.
private final Requestor requestor;
// The completion flags.
// This flag is only updated as the user iterates through the data, i.e. it tracks whether the
// user has read all the data or closed the stream. The future object also serves as the lock
// that guards the queue against concurrent cleanup.
final CompletableFuture<Void> completed;
// This flag is immediately updated when gRPC signals that the server has ended the call. This is
// used to make sure we don't block forever trying to write to a server that has rejected a
// call.
final CompletableFuture<Void> cancelled;
// Counter compared against pendingTarget; decremented by next() and reset by
// requestOutstanding().
private final AtomicInteger pending = new AtomicInteger();
// The root created from the schema message; null until the schema has been received.
private volatile VectorSchemaRoot fulfilledRoot;
// Dictionaries declared by the schema. Set to null once takeDictionaryOwnership() is called.
private DictionaryProvider.MapDictionaryProvider dictionaries;
// Loader that fills fulfilledRoot from record batches; set together with fulfilledRoot.
private volatile VectorLoader loader;
// The error to rethrow from next() once DONE_EX is reached.
private volatile Throwable ex;
// Most recent application metadata, retained by this stream until replaced or closed.
private volatile ArrowBuf applicationMetadata = null;
// Metadata version taken from the schema message; later messages are checked against it.
@VisibleForTesting volatile MetadataVersion metadataVersion = null;
/**
 * Creates a stream that buffers incoming Flight messages until the application reads them.
 *
 * @param allocator The allocator to use for creating/reallocating buffers for Vectors. Must not
 *     be null.
 * @param pendingTarget Target number of messages to receive.
 * @param cancellable Used to cancel mid-stream requests. May be null, in which case {@link
 *     #cancel(String, Throwable)} is unsupported.
 * @param requestor A callback used to request more items from the producer. Must not be null.
 * @throws NullPointerException if {@code allocator} or {@code requestor} is null.
 */
public FlightStream(
    BufferAllocator allocator, int pendingTarget, Cancellable cancellable, Requestor requestor) {
  // Validate the mandatory collaborators while assigning them.
  this.allocator = Objects.requireNonNull(allocator);
  this.requestor = Objects.requireNonNull(requestor);
  this.cancellable = cancellable;
  this.pendingTarget = pendingTarget;
  this.completed = new CompletableFuture<>();
  this.cancelled = new CompletableFuture<>();
  this.dictionaries = new DictionaryProvider.MapDictionaryProvider();
}
/**
 * Returns the schema of this stream, waiting for the root to become available first.
 *
 * @return the schema of the stream's root.
 */
public Schema getSchema() {
  final VectorSchemaRoot currentRoot = getRoot();
  return currentRoot.getSchema();
}
/**
 * Returns the dictionary provider backing this stream.
 *
 * <p>The stream keeps ownership of the dictionaries, and they may be updated as the stream is
 * read. Intended for stream processing, where the application does not hold on to values after
 * the stream is closed.
 *
 * @return the provider holding this stream's dictionaries.
 * @throws IllegalStateException if {@link #takeDictionaryOwnership()} was called
 * @see #takeDictionaryOwnership()
 */
public DictionaryProvider getDictionaryProvider() {
  if (dictionaries != null) {
    return dictionaries;
  }
  // The provider was handed over to the application and is no longer available here.
  throw new IllegalStateException("Dictionary ownership was claimed by the application.");
}
/**
 * Hands ownership of this stream's dictionaries to the caller. Should be called after the stream
 * has been fully read, but before it is closed.
 *
 * <p>After this call the caller is responsible for closing the dictionaries in the returned
 * provider. This can only be done once per stream.
 *
 * @return The dictionary provider for the stream.
 * @throws IllegalStateException if called more than once.
 */
public DictionaryProvider takeDictionaryOwnership() {
  if (dictionaries != null) {
    // Detach the provider from the stream so that close() leaves it alone.
    final DictionaryProvider released = dictionaries;
    dictionaries = null;
    return released;
  }
  throw new IllegalStateException("Dictionary ownership was claimed by the application.");
}
/**
 * Returns the descriptor for this stream, waiting until the client has sent it. Only applicable
 * on the server side of a DoPut operation.
 *
 * @return the descriptor received from the client.
 * @throws FlightRuntimeException if the wait is interrupted or the descriptor could not be
 *     obtained.
 */
public FlightDescriptor getDescriptor() {
  final FlightDescriptor received;
  try {
    // Waits for the first message from the client.
    received = descriptor.get();
  } catch (ExecutionException failure) {
    throw CallStatus.INTERNAL
        .withCause(failure)
        .withDescription("Error getting descriptor")
        .toRuntimeException();
  } catch (InterruptedException interrupted) {
    // Keep the interrupt visible to callers further up the stack.
    Thread.currentThread().interrupt();
    throw CallStatus.INTERNAL
        .withCause(interrupted)
        .withDescription("Interrupted")
        .toRuntimeException();
  }
  return received;
}
/**
* Closes the stream (freeing any existing resources).
*
* <p>If a {@link Cancellable} was supplied, the call is cancelled (unless gRPC has already
* signalled its end) and the remaining messages are drained with {@link #next()} before any
* resources are released. The root, the latest application metadata, anything still queued and
* the dictionaries (unless ownership was taken) are then closed, and the stream is marked as
* completed.
*
* @throws Exception if an error occurs while closing the held resources.
*/
@Override
public void close() throws Exception {
final List<AutoCloseable> closeables = new ArrayList<>();
// Holds a Flight error raised while draining, if any.
Throwable suppressor = null;
if (cancellable != null) {
// Client-side stream. Cancel the call, to help ensure gRPC doesn't deliver a message after
// close() ends.
// On the server side, we can't rely on draining the stream, because a gRPC bug means the
// completion callback may never run: https://github.com/grpc/grpc-java/issues/5882
try {
synchronized (cancellable) {
if (!cancelled.isDone()) {
// Only cancel if the call is not done on the gRPC side
cancellable.cancel("Stream closed before end", /* no exception to report */ null);
}
}
// Drain the stream without the lock (as next() implicitly needs the lock)
while (next()) {}
} catch (FlightRuntimeException e) {
// Not rethrown here; it is handed to AutoCloseables.close together with the resources below.
suppressor = e;
}
}
// Perform these operations under a lock. This way the observer can't enqueue new messages while
// we're in the middle of cleanup. This should only be a concern for server-side streams since
// client-side streams are drained by the loop above.
synchronized (completed) {
try {
if (fulfilledRoot != null) {
closeables.add(fulfilledRoot);
}
// May be null if no metadata is currently held.
// NOTE(review): relies on AutoCloseables.close tolerating null entries - confirm.
closeables.add(applicationMetadata);
// Includes any unread messages as well as the DONE/DONE_EX sentinels, whose close() is a no-op.
closeables.addAll(queue);
// Dictionaries are only closed here if the application has not taken ownership of them.
if (dictionaries != null) {
dictionaries
.getDictionaryIds()
.forEach(id -> closeables.add(dictionaries.lookup(id).getVector()));
}
if (suppressor != null) {
AutoCloseables.close(suppressor, closeables);
} else {
AutoCloseables.close(closeables);
}
// Remove any metadata after closing to prevent negative refcnt
applicationMetadata = null;
} finally {
// The value of this CompletableFuture is meaningless, only whether it's completed (or has
// an exception).
// No-op if already complete
completed.complete(null);
}
}
}
/**
 * Blocks until the next message is available and loads it into this stream.
 *
 * <p>A record batch is loaded into the root and a metadata-only message clears the root; in both
 * cases the latest application metadata is updated. A dictionary batch is loaded into the
 * matching dictionary and the method then moves on to the following message.
 *
 * <p>If the calling thread is interrupted while waiting, its interrupt status is restored before
 * the failure is reported.
 *
 * @return {@code true} if a data or metadata message was loaded, {@code false} once the stream
 *     has ended.
 * @throws RuntimeException if the stream ended with an error, the thread was interrupted while
 *     waiting, or a message could not be processed.
 */
public boolean next() {
  try {
    if (completed.isDone() && queue.isEmpty()) {
      return false;
    }
    pending.decrementAndGet();
    requestOutstanding();
    Object data = queue.take();
    if (DONE == data) {
      // Put the sentinel back so that later calls also observe the end of the stream.
      queue.put(DONE);
      // Other code ignores the value of this CompletableFuture, only whether it's completed (or
      // has an exception)
      completed.complete(null);
      return false;
    } else if (DONE_EX == data) {
      // Put the sentinel back so that later calls fail in the same way.
      queue.put(DONE_EX);
      if (ex instanceof Exception) {
        throw (Exception) ex;
      } else {
        throw new Exception(ex);
      }
    } else {
      try (ArrowMessage msg = ((ArrowMessage) data)) {
        if (msg.getMessageType() == HeaderType.NONE) {
          updateMetadata(msg);
          // We received a message without data, so erase any leftover data
          if (fulfilledRoot != null) {
            fulfilledRoot.clear();
          }
        } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) {
          checkMetadataVersion(msg);
          // Ensure we have the root
          root.get().clear();
          try (ArrowRecordBatch arb = msg.asRecordBatch()) {
            loader.load(arb);
          }
          updateMetadata(msg);
        } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
          checkMetadataVersion(msg);
          // Ensure we have the root
          root.get().clear();
          try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
            final long id = arb.getDictionaryId();
            if (dictionaries == null) {
              throw new IllegalStateException(
                  "Dictionary ownership was claimed by the application.");
            }
            final Dictionary dictionary = dictionaries.lookup(id);
            if (dictionary == null) {
              throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
            }
            // Load the batch straight into the dictionary's own vector via a temporary root.
            final FieldVector vector = dictionary.getVector();
            final VectorSchemaRoot dictionaryRoot =
                new VectorSchemaRoot(
                    Collections.singletonList(vector.getField()),
                    Collections.singletonList(vector),
                    0);
            final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot);
            dictionaryLoader.load(arb.getDictionary());
          }
          // A dictionary batch carries no rows for the caller; continue with the next message.
          return next();
        } else {
          throw new UnsupportedOperationException(
              "Message type is unsupported: " + msg.getMessageType());
        }
        return true;
      }
    }
  } catch (RuntimeException e) {
    throw e;
  } catch (ExecutionException e) {
    throw StatusUtils.fromThrowable(e.getCause());
  } catch (InterruptedException e) {
    // Restore the interrupt status so callers can still observe the interruption; the thrown
    // exception type is the same as before.
    Thread.currentThread().interrupt();
    throw new RuntimeException(e);
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}
/**
 * Replaces the held application metadata with the metadata of the given message.
 *
 * <p>The previously held buffer, if any, is closed. The new buffer, if any, is retained so that
 * it stays usable after the message itself is closed.
 *
 * @param msg the message whose application metadata becomes the latest metadata.
 */
private void updateMetadata(ArrowMessage msg) {
  final ArrowBuf previous = this.applicationMetadata;
  if (previous != null) {
    previous.close();
  }
  final ArrowBuf incoming = msg.getApplicationMetadata();
  this.applicationMetadata = incoming;
  if (incoming != null) {
    incoming.getReferenceManager().retain();
  }
}
/**
 * Verifies that the Arrow metadata version of a message matches the version recorded for this
 * stream. Messages without a schema message are not checked.
 *
 * @param msg the message to check.
 * @throws IllegalStateException if the message's version differs from the stream's version.
 */
private void checkMetadataVersion(ArrowMessage msg) {
  if (msg.asSchemaMessage() != null) {
    final MetadataVersion observed =
        MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
    final MetadataVersion expected = this.metadataVersion;
    if (expected != observed) {
      throw new IllegalStateException(
          "Metadata version mismatch: stream started as "
              + expected
              + " but got message with version "
              + observed);
    }
  }
}
/**
 * Get the current vector data from the stream. Blocks until the root is available.
 *
 * <p>The data in the root may change at any time. Clients should NOT modify the root, but instead
 * unload the data into their own root.
 *
 * <p>If the calling thread is interrupted while waiting, its interrupt status is restored before
 * the exception is thrown.
 *
 * @return the root that holds the most recently loaded data.
 * @throws FlightRuntimeException if there was an error reading the schema from the stream, or if
 *     the thread was interrupted while waiting for it.
 */
public VectorSchemaRoot getRoot() {
  try {
    return root.get();
  } catch (InterruptedException e) {
    // Restore the interrupt status, as getDescriptor() does, so callers can react to it.
    Thread.currentThread().interrupt();
    throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
  } catch (ExecutionException e) {
    throw StatusUtils.fromThrowable(e.getCause());
  }
}
/**
 * Reports whether the root is available without blocking, i.e. whether the other end has started
 * sending data.
 *
 * <p>Updated by calls to {@link #next()}.
 *
 * @return true if and only if the other end has started sending data.
 */
public boolean hasRoot() {
  final boolean rootResolved = root.isDone();
  return rootResolved;
}
/**
 * Returns the most recent application metadata sent from the server.
 *
 * <p>The value may be cleared by calls to {@link #next()} if the server sends a message without
 * metadata. Ownership of the buffer stays with this stream - call retain() to create a reference
 * if you need the buffer after a call to {@link #next()}.
 *
 * @return the application metadata. May be null.
 */
public ArrowBuf getLatestMetadata() {
  final ArrowBuf latest = this.applicationMetadata;
  return latest;
}
/**
 * Tops up the number of requested messages: when the pending counter has dropped below the
 * target, requests the difference from the requestor and resets the counter to the target.
 */
private synchronized void requestOutstanding() {
  if (pending.get() >= pendingTarget) {
    // Enough messages are already requested.
    return;
  }
  final int shortfall = pendingTarget - pending.get();
  requestor.request(shortfall);
  pending.set(pendingTarget);
}
private class Observer implements StreamObserver<ArrowMessage> {
Observer() {
super();
}
/**
 * Adds an item to the queue while holding the completion lock, so that it cannot interleave with
 * cleanup in close(). If the stream has already completed, the item is closed instead.
 *
 * @param message the message or sentinel to hand to the reader.
 */
private void enqueue(AutoCloseable message) {
  synchronized (completed) {
    final boolean streamOpen = !completed.isDone();
    if (streamOpen) {
      queue.add(message);
      return;
    }
    // The stream is already closed (RPC ended), discard the message
    AutoCloseables.closeNoChecked(message);
  }
}
/**
* Handles one incoming message according to its type: descriptor/metadata-only messages, the
* schema message, record and dictionary batches, and unsupported types.
*
* @param msg the message delivered by the transport.
*/
@Override
public void onNext(ArrowMessage msg) {
// Messages are added to the queue through enqueue(), which takes the completion lock so that
// nothing is queued while close() is in the middle of cleanup.
requestOutstanding();
switch (msg.getMessageType()) {
case NONE:
{
// No IPC message - pure metadata or descriptor
if (msg.getDescriptor() != null) {
descriptor.set(new FlightDescriptor(msg.getDescriptor()));
}
// Only messages that carry application metadata are passed on to the reader.
if (msg.getApplicationMetadata() != null) {
enqueue(msg);
}
break;
}
case SCHEMA:
{
Schema schema = msg.asSchema();
// if there is app metadata in the schema message, make sure
// that we don't leak it.
ArrowBuf meta = msg.getApplicationMetadata();
if (meta != null) {
meta.close();
}
// Convert each field to its in-memory form, collecting the dictionaries it references.
final List<Field> fields = new ArrayList<>();
final Map<Long, Dictionary> dictionaryMap = new HashMap<>();
for (final Field originalField : schema.getFields()) {
final Field updatedField =
DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap);
fields.add(updatedField);
}
// Register the collected dictionaries with the stream's provider.
for (final Map.Entry<Long, Dictionary> entry : dictionaryMap.entrySet()) {
dictionaries.put(entry.getValue());
}
schema = new Schema(fields, schema.getCustomMetadata());
// Remember the version so that later messages can be checked against it.
metadataVersion =
MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
try {
MetadataV4UnionChecker.checkRead(schema, metadataVersion);
} catch (IOException e) {
// Report the failure to the reader through next().
// NOTE(review): the root future is not completed on this path, so a caller blocked in
// getRoot() is not woken up by it - confirm this is intended.
ex = e;
enqueue(DONE_EX);
break;
}
// Create and publish the root under the completion lock, and only if the stream has not
// been completed in the meantime.
synchronized (completed) {
if (!completed.isDone()) {
fulfilledRoot = VectorSchemaRoot.create(schema, allocator);
loader = new VectorLoader(fulfilledRoot);
if (msg.getDescriptor() != null) {
descriptor.set(new FlightDescriptor(msg.getDescriptor()));
}
root.set(fulfilledRoot);
}
}
break;
}
case RECORD_BATCH:
case DICTIONARY_BATCH:
// Data messages are decoded later, on the reader's thread, in next().
enqueue(msg);
break;
case TENSOR:
default:
// Unsupported message type: end the stream with an error for the reader.
ex =
new UnsupportedOperationException(
"Unable to handle message of type: " + msg.getMessageType());
enqueue(DONE_EX);
}
}
/**
 * Records the failure reported by the transport, queues the error sentinel so that next()
 * rethrows it, marks the call as ended and fails the root future.
 *
 * @param t the error reported for the call.
 */
@Override
public void onError(Throwable t) {
  final Throwable failure = StatusUtils.fromThrowable(t);
  ex = failure;
  queue.add(DONE_EX);
  cancelled.complete(null);
  root.setException(failure);
}
/**
* Marks the call as ended on the gRPC side and queues the end-of-stream sentinel, so that next()
* returns false once all earlier messages have been read.
*/
@Override
public void onCompleted() {
// Depends on gRPC calling onNext and onCompleted non-concurrently
cancelled.complete(null);
queue.add(DONE);
}
}
/**
 * Cancels sending the stream to a client.
 *
 * <p>Callers should drain the stream (with {@link #next()}) afterwards, both to receive all
 * messages sent before cancellation and to wait for the underlying transport to acknowledge it.
 *
 * @param message description passed on to the cancellation callback.
 * @param exception cause passed on to the cancellation callback.
 * @throws UnsupportedOperationException if this stream was created without a {@link Cancellable}.
 */
public void cancel(String message, Throwable exception) {
  if (cancellable != null) {
    // Do not mark the stream as completed, as gRPC may still be delivering messages.
    cancellable.cancel(message, exception);
    return;
  }
  throw new UnsupportedOperationException(
      "Streams cannot be cancelled that are produced by client. "
          + "Instead, server should reject incoming messages.");
}
/**
 * Creates a new observer that feeds incoming messages into this stream.
 *
 * @return a fresh observer bound to this stream.
 */
StreamObserver<ArrowMessage> asObserver() {
  final StreamObserver<ArrowMessage> observer = new Observer();
  return observer;
}
/** Provides a callback to cancel a process that is in progress. */
@FunctionalInterface
public interface Cancellable {
/**
* Cancels the in-progress operation.
*
* @param message a description of why the operation is being cancelled.
* @param exception the error to report with the cancellation; null when there is none.
*/
void cancel(String message, Throwable exception);
}
/** Provides an interface to request more items from a stream producer. */
@FunctionalInterface
public interface Requestor {
/**
* Requests <code>count</code> more messages from the instance of this object.
*
* @param count the number of additional messages to request.
*/
void request(int count);
}
}