// TestRoundTrip.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.vector.ipc;
import static org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Collections2;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowBuffer;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TestRoundTrip extends BaseFileTest {
private static final Logger LOGGER = LoggerFactory.getLogger(TestRoundTrip.class);
private static BufferAllocator allocator;
/**
 * Argument source for the parameterized tests in this class.
 *
 * @return a stream of {@code {displayName, IpcOption}} pairs, in the order {@code "V4Legacy"},
 *     {@code "V4"}, {@code "V5"}; the last one carries {@link IpcOption#DEFAULT}
 */
static Stream<Object[]> getWriteOption() {
  final List<Object[]> arguments = new ArrayList<>();
  arguments.add(new Object[] {"V4Legacy", new IpcOption(true, MetadataVersion.V4)});
  arguments.add(new Object[] {"V4", new IpcOption(false, MetadataVersion.V4)});
  arguments.add(new Object[] {"V5", IpcOption.DEFAULT});
  return arguments.stream();
}
/**
 * Creates the root allocator shared by every test in this class, with a limit of
 * {@code Integer.MAX_VALUE}. Runs once before all tests.
 */
@BeforeAll
public static void setUpClass() {
allocator = new RootAllocator(Integer.MAX_VALUE);
}
/** Closes the shared root allocator once all tests in this class have run. */
@AfterAll
public static void tearDownClass() {
allocator.close();
}
/**
 * Fills a struct vector through the {@code writeData} helper, round-trips its {@code "root"}
 * child as a single batch through both the file and the stream format, and checks every batch
 * read back with {@code validateContent}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testStruct(String name, IpcOption writeOption) throws Exception {
  final int[] expectedCounts = {COUNT};
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    writeData(COUNT, parent);
    final VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root"));
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, this::validateContent),
        validateStreamBatches(expectedCounts, this::validateContent));
  }
}
/**
 * Same shape as the plain struct round trip, but the vector is filled by
 * {@code writeComplexData} and each batch read back is checked by
 * {@code validateComplexContent}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testComplex(String name, IpcOption writeOption) throws Exception {
  final int[] expectedCounts = {COUNT};
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    writeComplexData(COUNT, parent);
    final VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root"));
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, this::validateComplexContent),
        validateStreamBatches(expectedCounts, this::validateComplexContent));
  }
}
/**
 * Writes two record batches of different sizes (10 rows, then 5 rows) with each writer and
 * checks that both come back, in order, with the expected row counts and content.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testMultipleRecordBatches(String name, IpcOption writeOption) throws Exception {
  final int[] counts = {10, 5};
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    writeData(counts[0], parent);
    // The batches deliberately differ in size: if we write the same data we don't catch that
    // the metadata is stored in the wrong order.
    final CheckedBiConsumer<VectorSchemaRoot, ArrowWriter> writeAllBatches =
        (root, writer) -> {
          writer.start();
          for (int rowCount : counts) {
            parent.allocateNew();
            writeData(rowCount, parent);
            root.setRowCount(rowCount);
            writer.writeBatch();
          }
          writer.end();
        };
    roundTrip(
        name,
        writeOption,
        new VectorSchemaRoot(parent.getChild("root")),
        /* dictionaryProvider */ null,
        writeAllBatches,
        validateFileBatches(counts, this::validateContent),
        validateStreamBatches(counts, this::validateContent));
  }
}
/**
 * Checks that neither writer can be constructed for a schema containing a union when the
 * metadata version is V4. Skipped (via assumption) for any other metadata version.
 *
 * <p>The file writer and the stream writer are asserted separately so that each constructor is
 * individually required to throw; a combined lambda would pass as soon as either one threw.
 *
 * @param name display name of the write option, used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testUnionV4(String name, IpcOption writeOption) throws Exception {
  assumeTrue(writeOption.metadataVersion == MetadataVersion.V4);
  final File temp = File.createTempFile("arrow-test-" + name + "-", ".arrow");
  temp.deleteOnExit();
  final ByteArrayOutputStream memoryStream = new ByteArrayOutputStream();

  try (final BufferAllocator originalVectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", originalVectorAllocator)) {
    writeUnionData(COUNT, parent);
    final VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root"));

    // File writer: construction alone must be rejected.
    IllegalArgumentException e =
        assertThrows(
            IllegalArgumentException.class,
            () -> {
              try (final FileOutputStream fileStream = new FileOutputStream(temp)) {
                new ArrowFileWriter(root, null, fileStream.getChannel(), writeOption);
              }
            });
    assertTrue(e.getMessage().contains("Cannot write union with V4 metadata"), e.getMessage());

    // Stream writer: same expectation, verified on its own.
    e =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                new ArrowStreamWriter(
                    root, null, Channels.newChannel(memoryStream), writeOption));
    assertTrue(e.getMessage().contains("Cannot write union with V4 metadata"), e.getMessage());
  }
}
/**
 * Round-trips union data through both formats when the metadata version is V5. Skipped (via
 * assumption) for any other metadata version. The source root is validated once before
 * writing, and every batch read back is validated with the same {@code validateUnionData}
 * helper.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testUnionV5(String name, IpcOption writeOption) throws Exception {
  assumeTrue(writeOption.metadataVersion == MetadataVersion.V5);
  final int[] expectedCounts = {COUNT};
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    writeUnionData(COUNT, parent);
    final VectorSchemaRoot unionRoot = new VectorSchemaRoot(parent.getChild("root"));
    validateUnionData(COUNT, unionRoot);
    roundTrip(
        name,
        writeOption,
        unionRoot,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, this::validateUnionData),
        validateStreamBatches(expectedCounts, this::validateUnionData));
  }
}
/**
 * Round-trips a 16-row tiny-int column in which rows 0-7 are set to {@code row + 1} and rows
 * 8-15 are written with the "is set" flag cleared, then validates the result with
 * {@link #validateTinyData}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testTiny(String name, IpcOption writeOption) throws Exception {
  try (final VectorSchemaRoot root =
      VectorSchemaRoot.create(MessageSerializerTest.testSchema(), allocator)) {
    final int count = 16;
    final int[] expectedCounts = {count};
    final TinyIntVector tinyInts = (TinyIntVector) root.getFieldVectors().get(0);
    tinyInts.allocateNew();
    for (int row = 0; row < count; row++) {
      final int isSet = row < 8 ? 1 : 0;
      tinyInts.set(row, isSet, (byte) (row + 1));
    }
    tinyInts.setValueCount(count);
    root.setRowCount(count);
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, this::validateTinyData),
        validateStreamBatches(expectedCounts, this::validateTinyData));
  }
}
/**
 * Asserts that {@code root} has {@code count} rows and that its first column holds
 * {@code row + 1} for rows below 8 and null for every row from 8 onwards.
 *
 * @param count expected row count
 * @param root root whose first field vector is a {@link TinyIntVector}
 */
private void validateTinyData(int count, VectorSchemaRoot root) {
  assertEquals(count, root.getRowCount());
  final TinyIntVector tinyInts = (TinyIntVector) root.getFieldVectors().get(0);
  for (int row = 0; row < count; row++) {
    if (row >= 8) {
      assertTrue(tinyInts.isNull(row));
      continue;
    }
    assertEquals((byte) (row + 1), tinyInts.get(row));
  }
}
/**
 * Verifies that custom metadata survives a round trip at every level: on the schema, on the
 * top-level struct field, and on each of its four children.
 *
 * <p>The batch written is empty (zero rows); only the schema is of interest here.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testMetadata(String name, IpcOption writeOption) throws Exception {
  // Children carry metadata(1) .. metadata(4); the enclosing struct carries metadata(0).
  final Field varcharChild =
      new Field(
          "varchar-child",
          new FieldType(true, ArrowType.Utf8.INSTANCE, null, metadata(1)),
          null);
  final Field floatChild =
      new Field(
          "float-child",
          new FieldType(
              true,
              new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
              null,
              metadata(2)),
          null);
  final Field intChild =
      new Field(
          "int-child",
          new FieldType(false, new ArrowType.Int(32, true), null, metadata(3)),
          null);
  final Field listChild =
      new Field(
          "list-child",
          new FieldType(true, ArrowType.List.INSTANCE, null, metadata(4)),
          Collections2.asImmutableList(
              new Field("l1", FieldType.nullable(new ArrowType.Int(16, true)), null)));

  final List<Field> childFields = new ArrayList<>();
  childFields.add(varcharChild);
  childFields.add(floatChild);
  childFields.add(intChild);
  childFields.add(listChild);

  final Field field =
      new Field(
          "meta", new FieldType(true, ArrowType.Struct.INSTANCE, null, metadata(0)), childFields);

  final Map<String, String> schemaMetadata = new HashMap<>();
  schemaMetadata.put("s1", "v1");
  schemaMetadata.put("s2", "v2");
  final Schema originalSchema = new Schema(Collections2.asImmutableList(field), schemaMetadata);
  assertEquals(schemaMetadata, originalSchema.getCustomMetadata());

  // Applied to every batch read back; the row count argument is irrelevant for this check.
  final BiConsumer<Integer, VectorSchemaRoot> validate =
      (count, readRoot) -> {
        final Schema readSchema = readRoot.getSchema();
        assertEquals(originalSchema, readSchema);
        assertEquals(originalSchema.getCustomMetadata(), readSchema.getCustomMetadata());
        final Field top = readSchema.getFields().get(0);
        assertEquals(metadata(0), top.getMetadata());
        for (int child = 1; child <= 4; child++) {
          assertEquals(metadata(child), top.getChildren().get(child - 1).getMetadata());
        }
      };

  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector vector = (StructVector) field.createVector(vectorAllocator)) {
    vector.allocateNewSafe();
    vector.setValueCount(0);
    final List<FieldVector> vectors = Collections2.asImmutableList(vector);
    final VectorSchemaRoot root = new VectorSchemaRoot(originalSchema, vectors, 0);
    final int[] expectedCounts = {0};
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, validate),
        validateStreamBatches(expectedCounts, validate));
  }
}
/**
 * Builds a two-entry metadata map whose keys and values are derived from {@code i}.
 *
 * @param i number embedded in every key and value
 * @return an unmodifiable map {@code {"k_<i>": "v_<i>", "k2_<i>": "v2_<i>"}}
 */
private Map<String, String> metadata(int i) {
  final Map<String, String> entries = new HashMap<>();
  entries.put("k_" + i, "v_" + i);
  entries.put("k2_" + i, "v2_" + i);
  return Collections.unmodifiableMap(entries);
}
/**
 * Round-trips the root produced by {@code writeFlatDictionaryData} together with its dictionary
 * provider, validates the batch read from each format with {@code validateFlatDictionary}, and
 * checks that the file reader reports as many dictionary blocks as the file writer wrote.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testFlatDictionary(String name, IpcOption writeOption) throws Exception {
  final AtomicInteger numDictionaryBlocksWritten = new AtomicInteger();
  final MapDictionaryProvider provider = new MapDictionaryProvider();

  try (final BufferAllocator originalVectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final VectorSchemaRoot root = writeFlatDictionaryData(originalVectorAllocator, provider)) {
    try {
      roundTrip(
          name,
          writeOption,
          root,
          provider,
          (ignored, writer) -> {
            writer.start();
            writer.writeBatch();
            writer.end();
            // Only the file writer exposes its dictionary blocks; remember the count so the
            // file validator below can compare against it.
            if (writer instanceof ArrowFileWriter) {
              numDictionaryBlocksWritten.set(
                  ((ArrowFileWriter) writer).getDictionaryBlocks().size());
            }
          },
          (fileReader) -> {
            VectorSchemaRoot readRoot = fileReader.getVectorSchemaRoot();
            LOGGER.debug("reading schema: {}", readRoot.getSchema());
            assertTrue(fileReader.loadNextBatch());
            validateFlatDictionary(readRoot, fileReader);
            assertEquals(
                numDictionaryBlocksWritten.get(), fileReader.getDictionaryBlocks().size());
          },
          (streamReader) -> {
            VectorSchemaRoot readRoot = streamReader.getVectorSchemaRoot();
            LOGGER.debug("reading schema: {}", readRoot.getSchema());
            assertTrue(streamReader.loadNextBatch());
            validateFlatDictionary(readRoot, streamReader);
          });
    } finally {
      // The dictionary vectors are held by the provider, not the root, so they need an explicit
      // close. Doing it in finally keeps a failed round trip from additionally leaving them
      // open when the enclosing resources are closed.
      for (long id : provider.getDictionaryIds()) {
        provider.lookup(id).getVector().close();
      }
    }
  }
}
/**
 * Round-trips the root produced by {@code writeNestedDictionaryData} together with its
 * dictionary provider and validates the batch read from each format with
 * {@code validateNestedDictionary}. For the file format it additionally checks that the reader
 * reports as many dictionary blocks as the file writer wrote.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testNestedDictionary(String name, IpcOption writeOption) throws Exception {
  final AtomicInteger numDictionaryBlocksWritten = new AtomicInteger();
  final MapDictionaryProvider provider = new MapDictionaryProvider();

  // data being written:
  // [['foo', 'bar'], ['foo'], ['bar']] -> [[0, 1], [0], [1]]
  try (final BufferAllocator originalVectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final VectorSchemaRoot root =
          writeNestedDictionaryData(originalVectorAllocator, provider)) {
    // Shared by both formats: load the single batch and validate it.
    final CheckedConsumer<ArrowReader> validateDictionary =
        (reader) -> {
          VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
          LOGGER.debug("reading schema: {}", readRoot.getSchema());
          assertTrue(reader.loadNextBatch());
          validateNestedDictionary(readRoot, reader);
        };
    try {
      roundTrip(
          name,
          writeOption,
          root,
          provider,
          (ignored, writer) -> {
            writer.start();
            writer.writeBatch();
            writer.end();
            // Only the file writer exposes its dictionary blocks; remember the count so the
            // file validator below can compare against it.
            if (writer instanceof ArrowFileWriter) {
              numDictionaryBlocksWritten.set(
                  ((ArrowFileWriter) writer).getDictionaryBlocks().size());
            }
          },
          (fileReader) -> {
            validateDictionary.accept(fileReader);
            assertEquals(
                numDictionaryBlocksWritten.get(), fileReader.getDictionaryBlocks().size());
          },
          validateDictionary);
    } finally {
      // The dictionary vectors are held by the provider, not the root, so they need an explicit
      // close. Doing it in finally keeps a failed round trip from additionally leaving them
      // open when the enclosing resources are closed.
      for (long id : provider.getDictionaryIds()) {
        provider.lookup(id).getVector().close();
      }
    }
  }
}
/**
 * Round-trips a fixed-size-binary column of 10 rows, each 11 bytes wide, where every byte of
 * row {@code r} equals {@code (byte) r}, and checks the bytes read back row by row.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testFixedSizeBinary(String name, IpcOption writeOption) throws Exception {
  final int count = 10;
  final int typeWidth = 11;
  final int[] expectedCounts = {count};

  // Expected payload: row r is typeWidth copies of (byte) r.
  final byte[][] byteValues = new byte[count][typeWidth];
  for (int row = 0; row < count; row++) {
    final byte fill = (byte) row;
    for (int col = 0; col < typeWidth; col++) {
      byteValues[row][col] = fill;
    }
  }

  final BiConsumer<Integer, VectorSchemaRoot> validator =
      (expectedCount, root) -> {
        for (int row = 0; row < expectedCount; row++) {
          final byte[] expected = byteValues[row];
          final byte[] actual = (byte[]) root.getVector("fixed-binary").getObject(row);
          assertArrayEquals(expected, actual);
        }
      };

  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    final FixedSizeBinaryVector binaryColumn =
        parent.addOrGet(
            "fixed-binary",
            FieldType.nullable(new ArrowType.FixedSizeBinary(typeWidth)),
            FixedSizeBinaryVector.class);
    parent.allocateNew();
    for (int row = 0; row < count; row++) {
      binaryColumn.set(row, byteValues[row]);
    }
    parent.setValueCount(count);
    roundTrip(
        name,
        writeOption,
        new VectorSchemaRoot(parent),
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, validator),
        validateStreamBatches(expectedCounts, validator));
  }
}
/**
 * Round-trips a struct with two columns: {@code "float-pairs"}, a fixed-size list of two
 * floats holding {@code (row + 0.1f, row + 10.1f)}, and {@code "ints"}, holding {@code row}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testFixedSizeList(String name, IpcOption writeOption) throws Exception {
  final int[] expectedCounts = {COUNT};
  final BiConsumer<Integer, VectorSchemaRoot> validator =
      (expectedCount, root) -> {
        for (int row = 0; row < expectedCount; row++) {
          assertEquals(
              Collections2.asImmutableList(row + 0.1f, row + 10.1f),
              root.getVector("float-pairs").getObject(row));
          assertEquals(row, root.getVector("ints").getObject(row));
        }
      };

  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    final FixedSizeListVector pairs =
        parent.addOrGet(
            "float-pairs",
            FieldType.nullable(new ArrowType.FixedSizeList(2)),
            FixedSizeListVector.class);
    final Float4Vector pairElements =
        (Float4Vector)
            pairs.addOrGetVector(FieldType.nullable(Types.MinorType.FLOAT4.getType())).getVector();
    final IntVector ints =
        parent.addOrGet("ints", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class);
    parent.allocateNew();

    for (int row = 0; row < COUNT; row++) {
      // Each list slot occupies two consecutive entries in the element vector.
      final int firstElement = row * 2;
      pairs.setNotNull(row);
      pairElements.set(firstElement, row + 0.1f);
      pairElements.set(firstElement + 1, row + 10.1f);
      ints.set(row, row);
    }
    parent.setValueCount(COUNT);

    roundTrip(
        name,
        writeOption,
        new VectorSchemaRoot(parent),
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, validator),
        validateStreamBatches(expectedCounts, validator));
  }
}
/**
 * Round-trips the data produced by {@code writeVarBinaryData}. The source root is validated
 * once before writing, and every batch read back is validated with the same
 * {@code validateVarBinary} helper.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testVarBinary(String name, IpcOption writeOption) throws Exception {
  final int[] expectedCounts = {COUNT};
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final StructVector parent = StructVector.empty("parent", vectorAllocator)) {
    writeVarBinaryData(COUNT, parent);
    final VectorSchemaRoot binaryRoot = new VectorSchemaRoot(parent.getChild("root"));
    validateVarBinary(COUNT, binaryRoot);
    roundTrip(
        name,
        writeOption,
        binaryRoot,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, this::validateVarBinary),
        validateStreamBatches(expectedCounts, this::validateVarBinary));
  }
}
/**
 * Writes batches to a file with {@code writeBatchData}, reads the file back, validates it with
 * {@code validateBatchData}, and checks that the reader sees as many record blocks as the
 * writer produced.
 *
 * <p>The data goes to a fresh temp file per invocation (named after the write option) instead
 * of a fixed relative path, so the test neither depends on the working directory layout nor
 * shares a file between parameterizations.
 *
 * @param name display name of the write option, used in the temp file name
 * @param writeOption IPC option handed to the file writer
 * @throws IOException if the temp file cannot be created, written or read
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testReadWriteMultipleBatches(String name, IpcOption writeOption) throws IOException {
  final File file = File.createTempFile("arrow-test-" + name + "-", ".arrow");
  file.deleteOnExit();
  int numBlocksWritten = 0;

  try (IntVector vector = new IntVector("foo", allocator)) {
    Schema schema = new Schema(Collections.singletonList(vector.getField()));
    try (FileOutputStream fileOutputStream = new FileOutputStream(file);
        VectorSchemaRoot root =
            new VectorSchemaRoot(
                schema, Collections.singletonList((FieldVector) vector), vector.getValueCount());
        ArrowFileWriter writer =
            new ArrowFileWriter(root, null, fileOutputStream.getChannel(), writeOption)) {
      writeBatchData(writer, vector, root);
      numBlocksWritten = writer.getRecordBlocks().size();
    }
  }

  try (FileInputStream fileInputStream = new FileInputStream(file);
      ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator)) {
    IntVector vector = (IntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0);
    validateBatchData(reader, vector);
    assertEquals(numBlocksWritten, reader.getRecordBlocks().size());
  }
}
/**
 * Round-trips the root produced by {@code writeMapData} as a single batch and validates each
 * batch read back with {@code validateMapData}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testMap(String name, IpcOption writeOption) throws Exception {
  // The validation helper only needs the root; the row count argument is ignored.
  final BiConsumer<Integer, VectorSchemaRoot> validator =
      (count, readRoot) -> validateMapData(readRoot);
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final VectorSchemaRoot root = writeMapData(vectorAllocator)) {
    final int[] expectedCounts = {root.getRowCount()};
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, validator),
        validateStreamBatches(expectedCounts, validator));
  }
}
/**
 * Round-trips the root produced by {@code writeListAsMapData} as a single batch and validates
 * each batch read back with {@code validateListAsMapData}.
 *
 * @param name display name of the write option, also used in the temp file name
 * @param writeOption IPC option handed to both writers
 */
@ParameterizedTest(name = "options = {0}")
@MethodSource("getWriteOption")
public void testListAsMap(String name, IpcOption writeOption) throws Exception {
  // The validation helper only needs the root; the row count argument is ignored.
  final BiConsumer<Integer, VectorSchemaRoot> validator =
      (count, readRoot) -> validateListAsMapData(readRoot);
  try (final BufferAllocator vectorAllocator =
          allocator.newChildAllocator("original vectors", 0, allocator.getLimit());
      final VectorSchemaRoot root = writeListAsMapData(vectorAllocator)) {
    final int[] expectedCounts = {root.getRowCount()};
    roundTrip(
        name,
        writeOption,
        root,
        /* dictionaryProvider */ null,
        TestRoundTrip::writeSingleBatch,
        validateFileBatches(expectedCounts, validator),
        validateStreamBatches(expectedCounts, validator));
  }
}
// Generic test helpers
/**
 * Drives {@code writer} through one complete write: start, a single batch, end.
 *
 * <p>{@code root} is not used in the body; the parameter exists so that this method can be
 * passed as a method reference where a
 * {@code CheckedBiConsumer<VectorSchemaRoot, ArrowWriter>} is expected.
 *
 * @param root unused
 * @param writer writer to start, write one batch with, and end
 * @throws IOException if any of the writer calls fails
 */
private static void writeSingleBatch(VectorSchemaRoot root, ArrowWriter writer)
throws IOException {
writer.start();
writer.writeBatch();
writer.end();
}
/**
 * Builds a validator for a file reader. When invoked, the validator checks that the file holds
 * exactly {@code counts.length} record blocks at strictly increasing offsets, loads each block,
 * asserts its row count, hands it to {@code validator}, and finally asserts that every buffer
 * offset in the batch unloaded from the root is a multiple of 8.
 *
 * @param counts expected row count of each record batch, in file order
 * @param validator content check applied to each loaded batch; receives the expected row count
 *     and the reader's root
 * @return the validator to run against an {@link ArrowFileReader}
 */
private CheckedConsumer<ArrowFileReader> validateFileBatches(
    int[] counts, BiConsumer<Integer, VectorSchemaRoot> validator) {
  return (arrowReader) -> {
    VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
    VectorUnloader unloader = new VectorUnloader(root);
    LOGGER.debug("reading schema: {}", root.getSchema());

    List<ArrowBlock> recordBatches = arrowReader.getRecordBlocks();
    // Checked up front so that indexing counts[] by block position below is always in range.
    assertEquals(counts.length, recordBatches.size());

    long previousOffset = 0;
    for (int i = 0; i < counts.length; i++) {
      ArrowBlock rbBlock = recordBatches.get(i);
      assertTrue(
          rbBlock.getOffset() > previousOffset, rbBlock.getOffset() + " > " + previousOffset);
      previousOffset = rbBlock.getOffset();

      arrowReader.loadRecordBatch(rbBlock);
      assertEquals(counts[i], root.getRowCount(), "RB #" + i);
      validator.accept(counts[i], root);

      try (final ArrowRecordBatch batch = unloader.getRecordBatch()) {
        for (ArrowBuffer arrowBuffer : batch.getBuffersLayout()) {
          assertEquals(
              0, arrowBuffer.getOffset() % 8, "buffer offset not 8-byte aligned in RB #" + i);
        }
      }
    }
  };
}
/**
 * Builds a validator for a stream reader. When invoked, the validator loads
 * {@code counts.length} batches in sequence; for each it asserts the row count, hands the root
 * to {@code validator}, and asserts that every buffer offset in the batch unloaded from the
 * root is a multiple of 8. It then asserts that no further batch is available.
 *
 * @param counts expected row count of each record batch, in stream order
 * @param validator content check applied to each loaded batch; receives the expected row count
 *     and the reader's root
 * @return the validator to run against an {@link ArrowStreamReader}
 */
private CheckedConsumer<ArrowStreamReader> validateStreamBatches(
    int[] counts, BiConsumer<Integer, VectorSchemaRoot> validator) {
  return (arrowReader) -> {
    VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
    VectorUnloader unloader = new VectorUnloader(root);
    LOGGER.debug("reading schema: {}", root.getSchema());

    for (int i = 0; i < counts.length; i++) {
      assertTrue(arrowReader.loadNextBatch());
      assertEquals(counts[i], root.getRowCount(), "RB #" + i);
      validator.accept(counts[i], root);

      try (final ArrowRecordBatch batch = unloader.getRecordBatch()) {
        for (ArrowBuffer arrowBuffer : batch.getBuffersLayout()) {
          assertEquals(
              0, arrowBuffer.getOffset() % 8, "buffer offset not 8-byte aligned in RB #" + i);
        }
      }
    }

    // The stream must contain exactly the expected batches and nothing after them.
    assertFalse(arrowReader.loadNextBatch());
  };
}
/**
 * A single-argument operation that returns nothing and is allowed to throw a checked
 * exception.
 *
 * @param <T> type of the argument
 */
@FunctionalInterface
interface CheckedConsumer<T> {
  /**
   * Performs the operation on the given argument.
   *
   * @param value the argument
   * @throws Exception if the operation fails
   */
  void accept(T value) throws Exception;
}
/**
 * A two-argument operation that returns nothing and is allowed to throw a checked exception.
 *
 * @param <T> type of the first argument
 * @param <U> type of the second argument
 */
@FunctionalInterface
interface CheckedBiConsumer<T, U> {
  /**
   * Performs the operation on the given arguments.
   *
   * @param first the first argument
   * @param second the second argument
   * @throws Exception if the operation fails
   */
  void accept(T first, U second) throws Exception;
}
/**
 * Writes {@code root} in both IPC formats and reads both back.
 *
 * <p>The file format goes to a temp file (with custom metadata {@code {"foo": "bar"}}); the
 * stream format goes to an in-memory buffer. After writing, the first message of the stream is
 * checked against the requested metadata version. Both outputs are then opened with readers
 * backed by a dedicated child allocator, the supplied validators run, and finally the file
 * footer's metadata version and custom metadata are compared with what was written.
 *
 * @param name label embedded in the temp file name
 * @param writeOption IPC option handed to both writers
 * @param root data to write
 * @param provider dictionary provider handed to both writers; callers pass {@code null} when
 *     there are no dictionaries
 * @param writer callback that performs the actual writes; invoked once with the file writer and
 *     then once with the stream writer
 * @param fileValidator checks run against the file reader
 * @param streamValidator checks run against the stream reader
 * @throws Exception if writing, reading or any callback fails
 */
private void roundTrip(
    String name,
    IpcOption writeOption,
    VectorSchemaRoot root,
    DictionaryProvider provider,
    CheckedBiConsumer<VectorSchemaRoot, ArrowWriter> writer,
    CheckedConsumer<? super ArrowFileReader> fileValidator,
    CheckedConsumer<? super ArrowStreamReader> streamValidator)
    throws Exception {
  final File temp = File.createTempFile("arrow-test-" + name + "-", ".arrow");
  temp.deleteOnExit();
  final ByteArrayOutputStream memoryStream = new ByteArrayOutputStream();
  final Map<String, String> metadata = new HashMap<>();
  metadata.put("foo", "bar");

  // Write phase: same callback, first against the file writer, then the stream writer.
  try (final FileOutputStream fileStream = new FileOutputStream(temp);
      final ArrowFileWriter fileWriter =
          new ArrowFileWriter(root, provider, fileStream.getChannel(), metadata, writeOption);
      final ArrowStreamWriter streamWriter =
          new ArrowStreamWriter(root, provider, Channels.newChannel(memoryStream), writeOption)) {
    writer.accept(root, fileWriter);
    writer.accept(root, streamWriter);
  }

  // Nothing writes to memoryStream past this point, so one snapshot serves both reads below.
  final byte[] streamBytes = memoryStream.toByteArray();

  final ReadChannel firstMessageChannel =
      new ReadChannel(Channels.newChannel(new ByteArrayInputStream(streamBytes)));
  final MessageMetadataResult metadataResult = MessageSerializer.readMessage(firstMessageChannel);
  assertNotNull(metadataResult);
  assertEquals(writeOption.metadataVersion.toFlatbufID(), metadataResult.getMessage().version());

  // Read phase.
  try (BufferAllocator readerAllocator =
          allocator.newChildAllocator("reader", 0, allocator.getLimit());
      FileInputStream fileInputStream = new FileInputStream(temp);
      ByteArrayInputStream inputStream = new ByteArrayInputStream(streamBytes);
      ArrowFileReader fileReader =
          new ArrowFileReader(fileInputStream.getChannel(), readerAllocator);
      ArrowStreamReader streamReader = new ArrowStreamReader(inputStream, readerAllocator)) {
    fileValidator.accept(fileReader);
    streamValidator.accept(streamReader);
    assertEquals(writeOption.metadataVersion, fileReader.getFooter().getMetadataVersion());
    assertEquals(metadata, fileReader.getMetaData());
  }
}
}