ArrowMessage.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.io.ByteStreams;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.WireFormat;
import io.grpc.Drainable;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.protobuf.ProtoUtils;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.grpc.AddWritableBuffer;
import org.apache.arrow.flight.grpc.GetReadableBuffer;
import org.apache.arrow.flight.impl.Flight.FlightData;
import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.Schema;
/** The in-memory representation of FlightData used to manage a stream of Arrow messages. */
class ArrowMessage implements AutoCloseable {
// Zero-copy read: when true, Arrow data is deserialized by handing Arrow a reference to the
// underlying gRPC buffer rather than copying the bytes. On unless explicitly set to "false".
public static final boolean ENABLE_ZERO_COPY_READ;
// Zero-copy write: when true, Arrow data is serialized by handing gRPC a reference to the
// underlying Arrow buffer rather than copying the bytes. Off unless explicitly set to "true".
public static final boolean ENABLE_ZERO_COPY_WRITE;

static {
  // For each flag the JVM system property wins; the environment variable is only consulted
  // when the property is absent.
  String readSetting = System.getProperty("arrow.flight.enable_zero_copy_read");
  readSetting =
      readSetting != null ? readSetting : System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_READ");
  String writeSetting = System.getProperty("arrow.flight.enable_zero_copy_write");
  writeSetting =
      writeSetting != null ? writeSetting : System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE");
  // Anything other than (case-insensitive) "false" leaves zero-copy reads enabled.
  ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(readSetting);
  // Only (case-insensitive) "true" enables zero-copy writes.
  ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(writeSetting);
}
// Precomputed Protobuf tags for the FlightData fields that frame() parses by hand. Each tag is
// (field number << 3) | wire type. They are built from constants only because frame() uses them
// as switch case labels.
// Tag of the flight_descriptor field.
private static final int DESCRIPTOR_TAG =
(FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
// Tag of the data_body field (the Arrow buffers).
private static final int BODY_TAG =
(FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
// Tag of the data_header field (the serialized Arrow IPC message metadata).
private static final int HEADER_TAG =
(FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
// Tag of the app_metadata field.
private static final int APP_METADATA_TAG =
(FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
// Marshaller for the generated FlightData class; asInputStream() uses it to serialize messages
// that carry no body (schema, descriptor-only and metadata-only messages).
private static final Marshaller<FlightData> NO_BODY_MARSHALLER =
ProtoUtils.marshaller(FlightData.getDefaultInstance());
/**
 * Return the application-defined metadata attached to this message.
 *
 * <p>Ownership of the buffer stays with this ArrowMessage; callers must not release it.
 *
 * @return the metadata buffer, or null if the message carries none
 */
public ArrowBuf getApplicationMetadata() {
  return this.appMetadata;
}
/** The kinds of message an ArrowMessage can carry. */
public enum HeaderType {
  NONE,
  SCHEMA,
  DICTIONARY_BATCH,
  RECORD_BATCH,
  TENSOR;

  // Explicit lookup table: the position in this array is the header-type byte (0 through 4).
  private static final HeaderType[] BY_CODE = {
    NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH, TENSOR
  };

  /**
   * Map a header-type byte to its enum constant.
   *
   * @param b the header-type byte
   * @return the matching constant
   * @throws UnsupportedOperationException if {@code b} is not one of the known codes 0-4
   */
  public static HeaderType getHeader(byte b) {
    if (b < 0 || b >= BY_CODE.length) {
      throw new UnsupportedOperationException("unknown type: " + b);
    }
    return BY_CODE[b];
  }
}
// Pre-allocated buffers for padding serialized ArrowMessages.
private static final List<ByteBuf> PADDING_BUFFERS =
Arrays.asList(
null,
Unpooled.copiedBuffer(new byte[] {0}),
Unpooled.copiedBuffer(new byte[] {0, 0}),
Unpooled.copiedBuffer(new byte[] {0, 0, 0}),
Unpooled.copiedBuffer(new byte[] {0, 0, 0, 0}),
Unpooled.copiedBuffer(new byte[] {0, 0, 0, 0, 0}),
Unpooled.copiedBuffer(new byte[] {0, 0, 0, 0, 0, 0}),
Unpooled.copiedBuffer(new byte[] {0, 0, 0, 0, 0, 0, 0}));
// IPC options passed to MessageSerializer when the constructors serialize message metadata.
private final IpcOption writeOption;
// Flight descriptor; null unless supplied to a constructor. asInputStream() rejects a non-null
// descriptor on record batch and dictionary batch messages.
private final FlightDescriptor descriptor;
// Serialized Arrow IPC message metadata; null for descriptor-only and metadata-only messages.
private final MessageMetadataResult message;
// Application-defined metadata; may be null. Closed by close().
private final ArrowBuf appMetadata;
// Body buffers of the message (empty for messages without a body). Closed by close().
private final List<ArrowBuf> bufs;
// Whether asInputStream() should hand the Arrow buffers to gRPC without copying.
private final boolean tryZeroCopyWrite;
/**
 * Create an ArrowMessage that carries a schema and, optionally, a descriptor.
 *
 * @param descriptor The descriptor to send along with the schema. May be null.
 * @param schema The schema to serialize.
 * @param option The IPC options used to serialize the schema.
 */
public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) {
  this.writeOption = option;
  this.descriptor = descriptor;
  this.appMetadata = null;
  // A schema message never has a body.
  this.bufs = ImmutableList.of();
  this.tryZeroCopyWrite = false;
  final ByteBuffer metadata = MessageSerializer.serializeMetadata(schema, this.writeOption);
  final ByteBuffer metadataView = metadata.slice();
  this.message = MessageMetadataResult.create(metadataView, metadata.remaining());
}
/**
 * Create an ArrowMessage that carries a record batch and, optionally, application metadata.
 *
 * @param batch The record batch whose metadata is serialized and whose buffers form the body.
 * @param appMetadata The app metadata. May be null. Otherwise this message takes ownership of
 *     the buffer.
 * @param tryZeroCopy Whether to enable the zero-copy optimization when writing.
 * @param option The IPC options used to serialize the batch metadata.
 */
public ArrowMessage(
    ArrowRecordBatch batch, ArrowBuf appMetadata, boolean tryZeroCopy, IpcOption option) {
  this.writeOption = option;
  this.descriptor = null;
  this.appMetadata = appMetadata;
  this.tryZeroCopyWrite = tryZeroCopy;
  final ByteBuffer metadata = MessageSerializer.serializeMetadata(batch, this.writeOption);
  final ByteBuffer metadataView = metadata.slice();
  this.message = MessageMetadataResult.create(metadataView, metadata.remaining());
  this.bufs = ImmutableList.copyOf(batch.getBuffers());
}
/**
 * Create an ArrowMessage that carries a dictionary batch.
 *
 * <p>Each buffer of the dictionary is retained once, so this message holds its own reference.
 *
 * @param batch The dictionary batch to send.
 * @param option The IPC options used to serialize the batch metadata.
 */
public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) {
  this.writeOption = option;
  this.descriptor = null;
  this.appMetadata = null;
  this.tryZeroCopyWrite = false;
  final ByteBuffer metadataView =
      MessageSerializer.serializeMetadata(batch, this.writeOption).slice();
  this.message = MessageMetadataResult.create(metadataView, metadataView.remaining());
  final List<ArrowBuf> dictionaryBuffers = batch.getDictionary().getBuffers();
  // asInputStream will free the buffers implicitly, so take an extra reference on each one.
  for (ArrowBuf dictionaryBuffer : dictionaryBuffers) {
    dictionaryBuffer.getReferenceManager().retain();
  }
  this.bufs = ImmutableList.copyOf(dictionaryBuffers);
}
/**
 * Create an ArrowMessage that carries nothing but application metadata.
 *
 * @param appMetadata The application-provided metadata buffer.
 */
public ArrowMessage(ArrowBuf appMetadata) {
  this.appMetadata = appMetadata;
  this.descriptor = null;
  this.message = null;
  this.bufs = ImmutableList.of();
  this.tryZeroCopyWrite = false;
  // There is no IPC message to serialize here, so the caller does not need to pick an option.
  this.writeOption = IpcOption.DEFAULT;
}
/**
 * Create an ArrowMessage that carries nothing but a flight descriptor.
 *
 * @param descriptor The descriptor to send.
 */
public ArrowMessage(FlightDescriptor descriptor) {
  this.descriptor = descriptor;
  this.appMetadata = null;
  this.message = null;
  this.bufs = ImmutableList.of();
  this.tryZeroCopyWrite = false;
  // There is no IPC message to serialize here, so the caller does not need to pick an option.
  this.writeOption = IpcOption.DEFAULT;
}
private ArrowMessage(
FlightDescriptor descriptor,
MessageMetadataResult message,
ArrowBuf appMetadata,
ArrowBuf buf) {
// No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire.
this.writeOption =
message != null
?
// avoid writing legacy ipc format by default
new IpcOption(false, MetadataVersion.fromFlatbufID(message.getMessage().version()))
: IpcOption.DEFAULT;
this.message = message;
this.descriptor = descriptor;
this.appMetadata = appMetadata;
this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf);
this.tryZeroCopyWrite = false;
}
/**
 * Return the raw IPC message metadata held by this message.
 *
 * @return the message metadata, or null for descriptor-only and metadata-only messages
 */
public MessageMetadataResult asSchemaMessage() {
  return this.message;
}
/**
 * Return the flight descriptor attached to this message.
 *
 * @return the descriptor, or null if the message has none
 */
public FlightDescriptor getDescriptor() {
  return this.descriptor;
}
/**
 * Return the kind of IPC message this ArrowMessage carries.
 *
 * @return {@link HeaderType#NONE} when there is no IPC message, otherwise the type encoded in
 *     the message header
 */
public HeaderType getMessageType() {
  // A null message occurs for metadata-only messages (in DoExchange).
  return message == null ? HeaderType.NONE : HeaderType.getHeader(message.headerType());
}
/**
 * Interpret this message as a schema.
 *
 * @return the deserialized schema
 * @throws IllegalArgumentException if the message has body buffers or is not a schema message
 */
public Schema asSchema() {
  // A schema message must not carry a body.
  Preconditions.checkArgument(bufs.isEmpty());
  final HeaderType type = getMessageType();
  Preconditions.checkArgument(type == HeaderType.SCHEMA);
  return MessageSerializer.deserializeSchema(message);
}
/**
 * Interpret this message as a record batch.
 *
 * <p>An extra reference is taken on the body buffer before it is passed to the deserializer.
 *
 * @return the deserialized record batch
 * @throws IOException if the batch cannot be deserialized
 * @throws IllegalArgumentException if the message does not hold exactly one buffer or is not a
 *     record batch message
 */
public ArrowRecordBatch asRecordBatch() throws IOException {
  final boolean hasSingleBody = bufs.size() == 1;
  Preconditions.checkArgument(
      hasSingleBody, "A batch can only be consumed if it contains a single ArrowBuf.");
  final HeaderType type = getMessageType();
  Preconditions.checkArgument(type == HeaderType.RECORD_BATCH);
  final ArrowBuf body = bufs.get(0);
  body.getReferenceManager().retain();
  return MessageSerializer.deserializeRecordBatch(message, body);
}
/**
 * Interpret this message as a dictionary batch.
 *
 * @return the deserialized dictionary batch
 * @throws IOException if the batch cannot be deserialized
 * @throws IllegalArgumentException if the message does not hold exactly one buffer or is not a
 *     dictionary batch message
 */
public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
  final boolean hasSingleBody = bufs.size() == 1;
  Preconditions.checkArgument(
      hasSingleBody, "A batch can only be consumed if it contains a single ArrowBuf.");
  final HeaderType type = getMessageType();
  Preconditions.checkArgument(type == HeaderType.DICTIONARY_BATCH);
  final ArrowBuf body = bufs.get(0);
  // Take an extra reference so the batch stays alive after this message is closed; the
  // message keeps its own reference, which close() still releases.
  body.getReferenceManager().retain();
  return MessageSerializer.deserializeDictionaryBatch(message, body);
}
/**
 * Return a read-only view of the body buffers of this message.
 *
 * @return an unmodifiable iterable over the buffers; empty if the message has no body
 */
public Iterable<ArrowBuf> getBufs() {
  final Iterable<ArrowBuf> readOnlyView = Iterables.unmodifiableIterable(bufs);
  return readOnlyView;
}
/**
 * Parse a serialized FlightData message from {@code stream} into an ArrowMessage.
 *
 * <p>The message is decoded by hand, field by field, so that the body and application metadata
 * can be read directly into Arrow buffers obtained from {@code allocator}. If a field occurs
 * more than once, the last occurrence wins. Fields with an unrecognized tag are skipped.
 *
 * <p>Any buffer allocated here is released again if parsing fails, so a malformed or truncated
 * message does not leak allocator memory.
 *
 * @param allocator The allocator used for the body and application metadata buffers.
 * @param stream The stream holding exactly one serialized FlightData message.
 * @return the parsed message
 * @throws RuntimeException wrapping the underlying failure if the stream cannot be parsed
 */
private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) {
  // Declared outside the try block so the catch block can release them on failure.
  ArrowBuf body = null;
  ArrowBuf appMetadata = null;
  try {
    FlightDescriptor descriptor = null;
    MessageMetadataResult header = null;
    while (stream.available() > 0) {
      int tag = readRawVarint32(stream);
      switch (tag) {
        case DESCRIPTOR_TAG:
          {
            int size = readRawVarint32(stream);
            byte[] bytes = new byte[size];
            ByteStreams.readFully(stream, bytes);
            descriptor = FlightDescriptor.parseFrom(bytes);
            break;
          }
        case HEADER_TAG:
          {
            int size = readRawVarint32(stream);
            byte[] bytes = new byte[size];
            ByteStreams.readFully(stream, bytes);
            header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size);
            break;
          }
        case APP_METADATA_TAG:
          {
            if (appMetadata != null) {
              // Only keep the last app metadata; drop the earlier buffer instead of leaking it.
              // Clear the variable first so the catch block cannot release it a second time.
              ArrowBuf previous = appMetadata;
              appMetadata = null;
              previous.getReferenceManager().release();
            }
            int size = readRawVarint32(stream);
            appMetadata = allocator.buffer(size);
            GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ);
            break;
          }
        case BODY_TAG:
          {
            if (body != null) {
              // Only read the last body. Clear the variable first so the catch block cannot
              // release it a second time.
              ArrowBuf previous = body;
              body = null;
              previous.getReferenceManager().release();
            }
            int size = readRawVarint32(stream);
            body = allocator.buffer(size);
            GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ);
            break;
          }
        default:
          // Ignore unknown fields, but consume their payload so the next read starts at a tag.
          skipUnknownField(stream, tag);
          break;
      }
    }
    // Protobuf implementations can omit empty fields, such as body; for some message types,
    // like RecordBatch, this will fail later as we still expect an empty buffer. In those cases
    // only, fill in an empty buffer here - in other cases, like Schema, having an unexpected
    // empty buffer will also cause failures. We don't fill in defaults for fields like header,
    // for which there is no reasonable default, or for appMetadata or descriptor, which are
    // intended to be empty in some cases.
    if (header != null) {
      switch (HeaderType.getHeader(header.headerType())) {
        case SCHEMA:
          // Ignore 0-length buffers in case a Protobuf implementation wrote it out
          if (body != null && body.capacity() == 0) {
            ArrowBuf emptyBody = body;
            body = null;
            emptyBody.close();
          }
          break;
        case DICTIONARY_BATCH:
        case RECORD_BATCH:
          // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here
          if (body == null) {
            body = allocator.getEmpty();
          }
          break;
        case NONE:
        case TENSOR:
        default:
          // Do nothing
          break;
      }
    }
    return new ArrowMessage(descriptor, header, appMetadata, body);
  } catch (Exception ioe) {
    // No ArrowMessage took ownership of the buffers, so release them here.
    try {
      AutoCloseables.close(body, appMetadata);
    } catch (Exception closeFailure) {
      ioe.addSuppressed(closeFailure);
    }
    throw new RuntimeException(ioe);
  }
}

/**
 * Consume and discard the payload of a field whose tag {@code frame} does not recognize.
 *
 * @param stream The stream positioned immediately after the tag.
 * @param tag The tag that was just read; its low three bits hold the wire type.
 * @throws IOException if the payload is truncated or malformed, or the wire type is unsupported
 */
private static void skipUnknownField(InputStream stream, int tag) throws IOException {
  switch (WireFormat.getTagWireType(tag)) {
    case WireFormat.WIRETYPE_VARINT:
      // A varint is at most 10 bytes long; a clear high bit marks its final byte.
      for (int i = 0; i < 10; i++) {
        int next = stream.read();
        if (next < 0) {
          throw new IOException("Truncated varint in unknown FlightData field, tag: " + tag);
        }
        if ((next & 0x80) == 0) {
          return;
        }
      }
      throw new IOException("Malformed varint in unknown FlightData field, tag: " + tag);
    case WireFormat.WIRETYPE_FIXED64:
      ByteStreams.skipFully(stream, 8);
      return;
    case WireFormat.WIRETYPE_LENGTH_DELIMITED:
      {
        int length = readRawVarint32(stream);
        if (length < 0) {
          throw new IOException("Negative length in unknown FlightData field, tag: " + tag);
        }
        ByteStreams.skipFully(stream, length);
        return;
      }
    case WireFormat.WIRETYPE_FIXED32:
      ByteStreams.skipFully(stream, 4);
      return;
    default:
      // Groups and undefined wire types cannot be skipped safely.
      throw new IOException("Unsupported wire type in FlightData field, tag: " + tag);
  }
}
/**
 * Read one raw Protobuf varint32 from the stream.
 *
 * @param is The stream to read from.
 * @return the decoded value
 * @throws IOException if reading from the stream or decoding the varint fails
 */
private static int readRawVarint32(InputStream is) throws IOException {
  // The first byte is read here and handed to the decoder, which reads any remaining bytes.
  return CodedInputStream.readRawVarint32(is.read(), is);
}
/**
* Convert the ArrowMessage to an InputStream.
*
* <p>Implicitly, this transfers ownership of the contained buffers to the InputStream.
*
* <p>Messages without an IPC message (descriptor-only or metadata-only) and schema messages are
* serialized through the FlightData Protobuf marshaller. Record batch and dictionary batch
* messages are framed by hand: the header, the optional app metadata and the body tag and length
* are written to a small buffer, and the body buffers are then appended as components of a
* CompositeByteBuf, each padded to a multiple of 8 bytes.
*
* @return InputStream
* @throws RuntimeException if serialization fails for any reason; the cause is attached
*/
private InputStream asInputStream() {
if (message == null) {
// If we have no IPC message, it's a pure-metadata message
final FlightData.Builder builder = FlightData.newBuilder();
if (descriptor != null) {
builder.setFlightDescriptor(descriptor);
}
if (appMetadata != null) {
builder.setAppMetadata(ByteString.copyFrom(appMetadata.nioBuffer()));
}
return NO_BODY_MARSHALLER.stream(builder.build());
}
try {
// Copy of the serialized IPC message metadata; becomes the data_header field.
final ByteString bytes =
ByteString.copyFrom(message.getMessageBuffer(), message.bytesAfterMessage());
if (getMessageType() == HeaderType.SCHEMA) {
// Schema messages have no body, so the regular Protobuf marshaller can be used.
final FlightData.Builder builder = FlightData.newBuilder().setDataHeader(bytes);
if (descriptor != null) {
builder.setFlightDescriptor(descriptor);
}
Preconditions.checkArgument(bufs.isEmpty());
return NO_BODY_MARSHALLER.stream(builder.build());
}
// Everything below hand-frames a message with a body; only batches are supported.
Preconditions.checkArgument(
getMessageType() == HeaderType.RECORD_BATCH
|| getMessageType() == HeaderType.DICTIONARY_BATCH);
// There may be no buffers in the case that we write only a null array
Preconditions.checkArgument(
descriptor == null, "Descriptor should only be included in the schema message.");
// Write the leading fields (header, app metadata, body tag and body length) to a heap buffer.
ByteArrayOutputStream baos = new ByteArrayOutputStream();
CodedOutputStream cos = CodedOutputStream.newInstance(baos);
cos.writeBytes(FlightData.DATA_HEADER_FIELD_NUMBER, bytes);
if (appMetadata != null && appMetadata.capacity() > 0) {
// Must call slice() as CodedOutputStream#writeByteBuffer writes -capacity- bytes, not
// -limit- bytes
cos.writeByteBuffer(FlightData.APP_METADATA_FIELD_NUMBER, appMetadata.nioBuffer().slice());
}
cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED);
// Total body length in bytes, including padding; written after the body tag.
int size = 0;
// Netty views of the Arrow buffers, interleaved with any padding buffers.
List<ByteBuf> allBufs = new ArrayList<>();
for (ArrowBuf b : bufs) {
// [ARROW-11066] This creates a Netty buffer whose refcnt is INDEPENDENT of the backing
// Arrow buffer. This is susceptible to use-after-free, so we subclass CompositeByteBuf
// below to tie the Arrow buffer refcnt to the Netty buffer refcnt
allBufs.add(Unpooled.wrappedBuffer(b.nioBuffer()).retain());
size += (int) b.readableBytes();
// [ARROW-4213] These buffers must be aligned to an 8-byte boundary in order to be readable
// from C++.
if (b.readableBytes() % 8 != 0) {
int paddingBytes = (int) (8 - (b.readableBytes() % 8));
assert paddingBytes > 0 && paddingBytes < 8;
size += paddingBytes;
// The padding buffers are shared, so take a reference on the one being reused here.
allBufs.add(PADDING_BUFFERS.get(paddingBytes).retain());
}
}
// rawvarint is used for length definition.
cos.writeUInt32NoTag(size);
cos.flush();
// The leading fields become the first component of the composite buffer.
ByteBuf initialBuf = Unpooled.buffer(baos.size());
initialBuf.writeBytes(baos.toByteArray());
final CompositeByteBuf bb;
final ImmutableList<ByteBuf> byteBufs =
ImmutableList.<ByteBuf>builder().add(initialBuf).addAll(allBufs).build();
// See: https://github.com/apache/arrow/issues/40039
// CompositeByteBuf requires us to pass maxNumComponents to the constructor.
// This number is used to decide when to stop adding new components as separate buffers
// and instead merge existing components into a new buffer by performing a memory copy.
// We want to avoid memory copies as much as possible, so we want a limit that will never
// be reached.
// At first glance it seems reasonable to set the limit to byteBufs.size() + 1,
// because that is enough to avoid merging the byteBufs that we pass to the constructor.
// But later this buffer will be written to the socket by Netty,
// and DefaultHttp2ConnectionEncoder uses CoalescingBufferQueue to combine small buffers into
// one.
// Method CoalescingBufferQueue.compose checks whether the current buffer is already a
// CompositeByteBuf,
// and if so it just adds a new component to this buffer.
// In our case, with maxNumComponents=byteBufs.size() + 1, that would happen on the first
// attempt
// to write data to the socket, because the header message is small and Netty will always
// try to combine it with the
// large CompositeByteBuf we're creating here.
// We never want additional memory copies, so set the limit to Integer.MAX_VALUE.
final int maxNumComponents = Integer.MAX_VALUE;
if (tryZeroCopyWrite) {
bb = new ArrowBufRetainingCompositeByteBuf(maxNumComponents, byteBufs, bufs);
} else {
// Don't retain the buffers in the non-zero-copy path since we're copying them
bb =
new CompositeByteBuf(
UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, byteBufs);
}
return new DrainableByteBufInputStream(bb, tryZeroCopyWrite);
} catch (Exception ex) {
throw new RuntimeException("Unexpected IO Exception", ex);
}
}
/**
 * ARROW-11066: enable the zero-copy optimization and protect against use-after-free.
 *
 * <p>Sending a message through gRPC happens in three steps:
 *
 * <ol>
 *   <li>gRPC immediately serializes the message, eventually calling asInputStream above.
 *   <li>gRPC buffers the serialized message for sending.
 *   <li>Later, gRPC actually writes out the message.
 * </ol>
 *
 * <p>With the zero-copy optimization enabled, Flight "serializes" the message by handing gRPC
 * references to Arrow data. The Arrow buffers must therefore stay valid until gRPC actually
 * writes them, or we would read invalid data or segfault. gRPC itself knows nothing about Arrow
 * buffers.
 *
 * <p>This class bridges Arrow and Netty/gRPC: it increments the refcnt of a set of Arrow backing
 * buffers on construction and decrements it once the Netty buffer is freed by gRPC.
 */
private static final class ArrowBufRetainingCompositeByteBuf extends CompositeByteBuf {
  // Arrow buffers that back the Netty ByteBufs here; ByteBufs held by this class are
  // either slices of one of the ArrowBufs or independently allocated.
  final List<ArrowBuf> backingBuffers;
  // Set once the backing buffers have been released, so they are released only once.
  boolean freed;

  ArrowBufRetainingCompositeByteBuf(
      int maxNumComponents, Iterable<ByteBuf> buffers, List<ArrowBuf> backingBuffers) {
    super(UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, buffers);
    this.backingBuffers = backingBuffers;
    this.freed = false;
    // N.B. the Netty superclass avoids enhanced-for to reduce GC pressure, so follow that here
    int index = 0;
    while (index < backingBuffers.size()) {
      backingBuffers.get(index).getReferenceManager().retain();
      index++;
    }
  }

  @Override
  protected void deallocate() {
    super.deallocate();
    if (!freed) {
      freed = true;
      // Give back the references taken in the constructor.
      int index = 0;
      while (index < backingBuffers.size()) {
        backingBuffers.get(index).getReferenceManager().release();
        index++;
      }
    }
  }
}
/**
 * An InputStream over a CompositeByteBuf that can also be drained directly into an OutputStream.
 *
 * <p>Closing the stream releases one reference on the buffer. Closing it again has no effect,
 * as required by the {@link java.io.Closeable} contract.
 */
private static class DrainableByteBufInputStream extends ByteBufInputStream implements Drainable {
  private final CompositeByteBuf buf;
  private final boolean isZeroCopy;
  // Set by the first close() so that later calls do not release the buffer again.
  private boolean released;

  public DrainableByteBufInputStream(CompositeByteBuf buffer, boolean isZeroCopy) {
    super(buffer, buffer.readableBytes(), true);
    this.buf = buffer;
    this.isZeroCopy = isZeroCopy;
  }

  /**
   * Write the whole buffer to {@code target}.
   *
   * @param target The stream to write to.
   * @return the number of readable bytes the buffer held when draining started
   * @throws IOException if writing to the target fails
   */
  @Override
  public int drainTo(OutputStream target) throws IOException {
    int size = buf.readableBytes();
    AddWritableBuffer.add(buf, target, isZeroCopy);
    return size;
  }

  /** Release the underlying buffer; calls after the first one are no-ops. */
  @Override
  public void close() {
    if (released) {
      return;
    }
    released = true;
    buf.release();
  }
}
/**
 * Create a gRPC marshaller that converts between ArrowMessage and its wire form.
 *
 * @param allocator The allocator used for buffers of parsed messages.
 * @return a new marshaller
 */
public static Marshaller<ArrowMessage> createMarshaller(BufferAllocator allocator) {
  final Marshaller<ArrowMessage> marshaller = new ArrowMessageHolderMarshaller(allocator);
  return marshaller;
}
/** gRPC marshaller that serializes and parses ArrowMessage instances. */
private static class ArrowMessageHolderMarshaller implements Marshaller<ArrowMessage> {
  // Allocator passed to frame() for the buffers of every parsed message.
  private final BufferAllocator allocator;

  public ArrowMessageHolderMarshaller(BufferAllocator bufferAllocator) {
    this.allocator = bufferAllocator;
  }

  /** Serialize the message by delegating to {@code asInputStream}. */
  @Override
  public InputStream stream(ArrowMessage value) {
    final InputStream serialized = value.asInputStream();
    return serialized;
  }

  /** Parse a message from the wire by delegating to {@code frame}. */
  @Override
  public ArrowMessage parse(InputStream stream) {
    final ArrowMessage parsed = ArrowMessage.frame(this.allocator, stream);
    return parsed;
  }
}
/**
 * Close the body buffers and the application metadata buffer held by this message.
 *
 * @throws Exception if closing any of the buffers fails
 */
@Override
public void close() throws Exception {
  // appMetadata may be null; it is wrapped in a list so it can be closed with the body buffers.
  final List<ArrowBuf> metadataHolder = Collections.singletonList(appMetadata);
  AutoCloseables.close(Iterables.concat(bufs, metadataHolder));
}
}