/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.vector.ipc;

import static java.nio.channels.Channels.newChannel;
import static java.util.Arrays.asList;
import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt;
import static org.apache.arrow.vector.TestUtils.newVarCharVector;
import static org.apache.arrow.vector.TestUtils.newVector;
import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.apache.arrow.flatbuf.FieldNode;
import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.flatbuf.RecordBatch;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Collections2;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.TestUtils;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compare.Range;
import org.apache.arrow.vector.compare.RangeEqualsVisitor;
import org.apache.arrow.vector.compare.TypeEqualsVisitor;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryEncoder;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.util.TransferPair;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

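/**
 * Tests round-tripping Arrow data through the file and stream IPC formats, with an emphasis on
 * dictionary-encoded vectors (replacement and delta dictionaries), custom file metadata, and
 * low-level channel reads.
 */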
public class TestArrowReaderWriter {

  private BufferAllocator allocator;

  private VarCharVector dictionaryVector1;
  private VarCharVector dictionaryVector2;
  private VarCharVector dictionaryVector3;
  private StructVector dictionaryVector4;

  private Dictionary dictionary1;
  private Dictionary dictionary2;
  private Dictionary dictionary3;
  private Dictionary dictionary4;

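  // both schemas below are initialized as a side effect of createRecordBatches()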
  private Schema schema;
  private Schema encodedSchema;

  @BeforeEach
  public void init() {
    allocator = new RootAllocator(Long.MAX_VALUE);

    dictionaryVector1 = newVarCharVector("D1", allocator);
    setVector(
        dictionaryVector1,
        "foo".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8),
        "baz".getBytes(StandardCharsets.UTF_8));

    dictionaryVector2 = newVarCharVector("D2", allocator);
    setVector(
        dictionaryVector2,
        "aa".getBytes(StandardCharsets.UTF_8),
        "bb".getBytes(StandardCharsets.UTF_8),
        "cc".getBytes(StandardCharsets.UTF_8));

    dictionaryVector3 = newVarCharVector("D3", allocator);
    setVector(
        dictionaryVector3,
        "foo".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8),
        "baz".getBytes(StandardCharsets.UTF_8),
        "aa".getBytes(StandardCharsets.UTF_8),
        "bb".getBytes(StandardCharsets.UTF_8),
        "cc".getBytes(StandardCharsets.UTF_8));

    dictionaryVector4 = newVector(StructVector.class, "D4", MinorType.STRUCT, allocator);
    final Map<String, List<Integer>> dictionaryValues4 = new HashMap<>();
    dictionaryValues4.put("a", Arrays.asList(1, 2, 3));
    dictionaryValues4.put("b", Arrays.asList(4, 5, 6));
    setVector(dictionaryVector4, dictionaryValues4);

    dictionary1 =
        new Dictionary(
            dictionaryVector1,
            new DictionaryEncoding(/*id=*/ 1L, /*ordered=*/ false, /*indexType=*/ null));
    dictionary2 =
        new Dictionary(
            dictionaryVector2,
            new DictionaryEncoding(/*id=*/ 2L, /*ordered=*/ false, /*indexType=*/ null));
    dictionary3 =
        new Dictionary(
            dictionaryVector3,
            new DictionaryEncoding(/*id=*/ 1L, /*ordered=*/ false, /*indexType=*/ null));
    dictionary4 =
        new Dictionary(
            dictionaryVector4,
            new DictionaryEncoding(/*id=*/ 3L, /*ordered=*/ false, /*indexType=*/ null));
  }

  @AfterEach
  public void terminate() throws Exception {
    dictionaryVector1.close();
    dictionaryVector2.close();
    dictionaryVector3.close();
    dictionaryVector4.close();
    allocator.close();
  }

  ArrowBuf buf(byte[] bytes) {
    ArrowBuf buffer = allocator.buffer(bytes.length);
    buffer.writeBytes(bytes);
    return buffer;
  }

  byte[] array(ArrowBuf buf) {
    byte[] bytes = new byte[checkedCastToInt(buf.readableBytes())];
    buf.readBytes(bytes);
    return bytes;
  }

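  /**
   * Round-trips one record batch of nullable 8-bit ints through the file format, then verifies
   * both the deserialized batch and the raw flatbuffer message header.
   */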
  @Test
  public void test() throws IOException {
    Schema schema =
        new Schema(
            asList(
                new Field(
                    "testField",
                    FieldType.nullable(new ArrowType.Int(8, true)),
                    Collections.<Field>emptyList())));
    ArrowType type = schema.getFields().get(0).getType();
    FieldVector vector = TestUtils.newVector(FieldVector.class, "testField", type, allocator);
    vector.initializeChildrenFromFields(schema.getFields().get(0).getChildren());

    byte[] validity = new byte[] {(byte) 255, 0};
    // values in the second half are "undefined" since their validity bits are unset
    byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try (VectorSchemaRoot root = new VectorSchemaRoot(schema.getFields(), asList(vector), 16);
        ArrowFileWriter writer = new ArrowFileWriter(root, null, newChannel(out))) {
      ArrowBuf validityb = buf(validity);
      ArrowBuf valuesb = buf(values);
      ArrowRecordBatch batch =
          new ArrowRecordBatch(16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb));
      VectorLoader loader = new VectorLoader(root);
      loader.load(batch);
      writer.writeBatch();

      validityb.close();
      valuesb.close();
      batch.close();
    }

    byte[] byteArray = out.toByteArray();

    try (SeekableReadChannel channel =
            new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(byteArray));
        ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
      Schema readSchema = reader.getVectorSchemaRoot().getSchema();
      assertEquals(schema, readSchema);
      // TODO: dictionaries
      List<ArrowBlock> recordBatches = reader.getRecordBlocks();
      assertEquals(1, recordBatches.size());
      reader.loadNextBatch();
      VectorUnloader unloader = new VectorUnloader(reader.getVectorSchemaRoot());
      ArrowRecordBatch recordBatch = unloader.getRecordBatch();
      List<ArrowFieldNode> nodes = recordBatch.getNodes();
      assertEquals(1, nodes.size());
      ArrowFieldNode node = nodes.get(0);
      assertEquals(16, node.getLength());
      assertEquals(8, node.getNullCount());
      List<ArrowBuf> buffers = recordBatch.getBuffers();
      assertEquals(2, buffers.size());
      assertArrayEquals(validity, array(buffers.get(0)));
      assertArrayEquals(values, array(buffers.get(1)));

      // Read just the header. This demonstrates that the message header can be read without
      // needing to deserialize the body buffers.
      ByteBuffer headerBuffer = ByteBuffer.allocate(recordBatches.get(0).getMetadataLength());
      headerBuffer.put(byteArray, (int) recordBatches.get(0).getOffset(), headerBuffer.capacity());
      // In the current IPC format the message prefix is 8 bytes:
      // a 4-byte continuation marker followed by the 4-byte metadata length
      headerBuffer.position(8);
      Message messageFB = Message.getRootAsMessage(headerBuffer);
      RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch());
      assertEquals(2, recordBatchFB.buffersLength());
      assertEquals(1, recordBatchFB.nodesLength());
      FieldNode nodeFB = recordBatchFB.nodes(0);
      assertEquals(16, nodeFB.length());
      assertEquals(8, nodeFB.nullCount());

      recordBatch.close();
    }
  }

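  /** A NullVector has no buffers, so its record batch should round-trip with an empty body. */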
  @Test
  public void testWriteReadNullVector() throws IOException {

    int valueCount = 3;

    NullVector nullVector = new NullVector("vector");
    nullVector.setValueCount(valueCount);

    Schema schema = new Schema(asList(nullVector.getField()));

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try (VectorSchemaRoot root =
            new VectorSchemaRoot(schema.getFields(), asList(nullVector), valueCount);
        ArrowFileWriter writer = new ArrowFileWriter(root, null, newChannel(out))) {
      ArrowRecordBatch batch =
          new ArrowRecordBatch(
              valueCount, asList(new ArrowFieldNode(valueCount, 0)), Collections.emptyList());
      VectorLoader loader = new VectorLoader(root);
      loader.load(batch);
      writer.writeBatch();
    }

    byte[] byteArray = out.toByteArray();

    try (SeekableReadChannel channel =
            new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(byteArray));
        ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
      Schema readSchema = reader.getVectorSchemaRoot().getSchema();
      assertEquals(schema, readSchema);
      List<ArrowBlock> recordBatches = reader.getRecordBlocks();
      assertEquals(1, recordBatches.size());

      assertTrue(reader.loadNextBatch());
      assertEquals(1, reader.getVectorSchemaRoot().getFieldVectors().size());

      NullVector readNullVector =
          (NullVector) reader.getVectorSchemaRoot().getFieldVectors().get(0);
      assertEquals(valueCount, readNullVector.getValueCount());
    }
  }

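  /** Writes two vectors encoded against the same dictionary (id=1) and reads them back. */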
  @Test
  public void testWriteReadWithDictionaries() throws IOException {
    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);

    VarCharVector vector1 = newVarCharVector("varchar1", allocator);
    vector1.allocateNewSafe();
    vector1.set(0, "foo".getBytes(StandardCharsets.UTF_8));
    vector1.set(1, "bar".getBytes(StandardCharsets.UTF_8));
    vector1.set(3, "baz".getBytes(StandardCharsets.UTF_8));
    vector1.set(4, "bar".getBytes(StandardCharsets.UTF_8));
    vector1.set(5, "baz".getBytes(StandardCharsets.UTF_8));
    vector1.setValueCount(6);
    FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector1, dictionary1);
    vector1.close();

    VarCharVector vector2 = newVarCharVector("varchar2", allocator);
    vector2.allocateNewSafe();
    vector2.set(0, "bar".getBytes(StandardCharsets.UTF_8));
    vector2.set(1, "baz".getBytes(StandardCharsets.UTF_8));
    vector2.set(2, "foo".getBytes(StandardCharsets.UTF_8));
    vector2.set(3, "foo".getBytes(StandardCharsets.UTF_8));
    vector2.set(4, "foo".getBytes(StandardCharsets.UTF_8));
    vector2.set(5, "bar".getBytes(StandardCharsets.UTF_8));
    vector2.setValueCount(6);
    FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector2, dictionary1);
    vector2.close();

    List<Field> fields = Arrays.asList(encodedVector1.getField(), encodedVector2.getField());
    List<FieldVector> vectors = Collections2.asImmutableList(encodedVector1, encodedVector2);
    try (VectorSchemaRoot root =
            new VectorSchemaRoot(fields, vectors, encodedVector1.getValueCount());
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out))) {

      writer.start();
      writer.writeBatch();
      writer.end();

      try (SeekableReadChannel channel =
              new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(out.toByteArray()));
          ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
        Schema readSchema = reader.getVectorSchemaRoot().getSchema();
        assertEquals(root.getSchema(), readSchema);
        assertEquals(1, reader.getDictionaryBlocks().size());
        assertEquals(1, reader.getRecordBlocks().size());

        reader.loadNextBatch();
        assertEquals(2, reader.getVectorSchemaRoot().getFieldVectors().size());
      }
    }
  }

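  /**
   * Round-trips a vector encoded against a struct dictionary, then checks the encoded vector, the
   * dictionary itself, and the decoded values.
   */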
  @Test
  public void testWriteReadWithStructDictionaries() throws IOException {
    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary4);

    try (final StructVector vector =
        newVector(StructVector.class, "D4", MinorType.STRUCT, allocator)) {
      final Map<String, List<Integer>> values = new HashMap<>();
      // Index: 0, 2, 1, 2, 1, 0, 0
      values.put("a", Arrays.asList(1, 3, 2, 3, 2, 1, 1));
      values.put("b", Arrays.asList(4, 6, 5, 6, 5, 4, 4));
      setVector(vector, values);
      FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary4);

      List<Field> fields = Arrays.asList(encodedVector.getField());
      List<FieldVector> vectors = Collections2.asImmutableList(encodedVector);
      try (VectorSchemaRoot root =
              new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount());
          ByteArrayOutputStream out = new ByteArrayOutputStream();
          ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out))) {

        writer.start();
        writer.writeBatch();
        writer.end();

        try (SeekableReadChannel channel =
                new SeekableReadChannel(
                    new ByteArrayReadableSeekableByteChannel(out.toByteArray()));
            ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
          final VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
          final Schema readSchema = readRoot.getSchema();
          assertEquals(root.getSchema(), readSchema);
          assertEquals(1, reader.getDictionaryBlocks().size());
          assertEquals(1, reader.getRecordBlocks().size());

          reader.loadNextBatch();
          assertEquals(1, readRoot.getFieldVectors().size());
          assertEquals(1, reader.getDictionaryVectors().size());

          // Read the encoded vector and check it
          final FieldVector readEncoded = readRoot.getVector(0);
          assertEquals(encodedVector.getValueCount(), readEncoded.getValueCount());
          assertTrue(
              new RangeEqualsVisitor(encodedVector, readEncoded)
                  .rangeEquals(new Range(0, 0, encodedVector.getValueCount())));

          // Read the dictionary
          final Map<Long, Dictionary> readDictionaryMap = reader.getDictionaryVectors();
          final Dictionary readDictionary =
              readDictionaryMap.get(readEncoded.getField().getDictionary().getId());
          assertNotNull(readDictionary);

          // Assert the dictionary vector is correct
          final FieldVector readDictionaryVector = readDictionary.getVector();
          assertEquals(dictionaryVector4.getValueCount(), readDictionaryVector.getValueCount());
          final BiFunction<ValueVector, ValueVector, Boolean> typeComparatorIgnoreName =
              (v1, v2) -> new TypeEqualsVisitor(v1, false, true).equals(v2);
          assertTrue(
              new RangeEqualsVisitor(
                      dictionaryVector4, readDictionaryVector, typeComparatorIgnoreName)
                  .rangeEquals(new Range(0, 0, dictionaryVector4.getValueCount())),
              "Dictionary vectors are not equal");

          // Assert the decoded vector is correct
          try (final ValueVector readVector =
              DictionaryEncoder.decode(readEncoded, readDictionary)) {
            assertEquals(vector.getValueCount(), readVector.getValueCount());
            assertTrue(
                new RangeEqualsVisitor(vector, readVector, typeComparatorIgnoreName)
                    .rangeEquals(new Range(0, 0, vector.getValueCount())),
                "Decoded vectors are not equal");
          }
        }
      }
    }
  }

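  /**
   * A file written with a dictionary-encoded schema but no batches should still expose the
   * dictionary vector while containing zero dictionary and record blocks.
   */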
  @Test
  public void testEmptyStreamInFileIPC() throws IOException {

    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);

    VarCharVector vector = newVarCharVector("varchar", allocator);
    vector.allocateNewSafe();
    vector.set(0, "foo".getBytes(StandardCharsets.UTF_8));
    vector.set(1, "bar".getBytes(StandardCharsets.UTF_8));
    vector.set(3, "baz".getBytes(StandardCharsets.UTF_8));
    vector.set(4, "bar".getBytes(StandardCharsets.UTF_8));
    vector.set(5, "baz".getBytes(StandardCharsets.UTF_8));
    vector.setValueCount(6);

    FieldVector encodedVector1A = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);
    vector.close();

    List<Field> fields = Arrays.asList(encodedVector1A.getField());
    List<FieldVector> vectors = Collections2.asImmutableList(encodedVector1A);

    try (VectorSchemaRoot root =
            new VectorSchemaRoot(fields, vectors, encodedVector1A.getValueCount());
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out))) {

      writer.start();
      writer.end();

      try (SeekableReadChannel channel =
              new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(out.toByteArray()));
          ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
        Schema readSchema = reader.getVectorSchemaRoot().getSchema();
        assertEquals(root.getSchema(), readSchema);
        assertEquals(1, reader.getDictionaryVectors().size());
        assertEquals(0, reader.getDictionaryBlocks().size());
        assertEquals(0, reader.getRecordBlocks().size());
      }
    }
  }

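  /** Same as above, but for the streaming format: only the schema and dictionary are present. */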
  @Test
  public void testEmptyStreamInStreamingIPC() throws IOException {

    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);

    VarCharVector vector = newVarCharVector("varchar", allocator);
    vector.allocateNewSafe();
    vector.set(0, "foo".getBytes(StandardCharsets.UTF_8));
    vector.set(1, "bar".getBytes(StandardCharsets.UTF_8));
    vector.set(3, "baz".getBytes(StandardCharsets.UTF_8));
    vector.set(4, "bar".getBytes(StandardCharsets.UTF_8));
    vector.set(5, "baz".getBytes(StandardCharsets.UTF_8));
    vector.setValueCount(6);

    FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);
    vector.close();

    List<Field> fields = Arrays.asList(encodedVector.getField());
    try (VectorSchemaRoot root =
            new VectorSchemaRoot(
                fields, Arrays.asList(encodedVector), encodedVector.getValueCount());
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, newChannel(out))) {

      writer.start();
      writer.end();

      try (ArrowStreamReader reader =
          new ArrowStreamReader(
              new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator)) {
        Schema readSchema = reader.getVectorSchemaRoot().getSchema();
        assertEquals(root.getSchema(), readSchema);
        assertEquals(1, reader.getDictionaryVectors().size());
        assertFalse(reader.loadNextBatch());
      }
    }
  }

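  /**
   * Writes two non-delta dictionary batches with the same id; the reader should keep the second
   * (replacement) dictionary.
   */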
  @Test
  public void testDictionaryReplacement() throws Exception {
    VarCharVector vector1 = newVarCharVector("varchar1", allocator);
    setVector(
        vector1,
        "foo".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8),
        "baz".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8));

    FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector1, dictionary1);

    VarCharVector vector2 = newVarCharVector("varchar2", allocator);
    setVector(
        vector2,
        "foo".getBytes(StandardCharsets.UTF_8),
        "foo".getBytes(StandardCharsets.UTF_8),
        "foo".getBytes(StandardCharsets.UTF_8),
        "foo".getBytes(StandardCharsets.UTF_8));

    FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector2, dictionary1);

    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);
    List<Field> schemaFields = new ArrayList<>();
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVector1.getField(), provider, new HashSet<>()));
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVector2.getField(), provider, new HashSet<>()));
    Schema schema = new Schema(schemaFields);

    ByteArrayOutputStream outStream = new ByteArrayOutputStream();
    WriteChannel out = new WriteChannel(newChannel(outStream));

    // write schema
    MessageSerializer.serialize(out, schema);

    List<AutoCloseable> closeableList = new ArrayList<>();

    // write the first non-delta dictionary with id=1
    serializeDictionaryBatch(out, dictionary3, false, closeableList);

    // write a second non-delta dictionary with the same id=1, replacing the first
    serializeDictionaryBatch(out, dictionary1, false, closeableList);

    // write the record batch
    serializeRecordBatch(out, Arrays.asList(encodedVector1, encodedVector2), closeableList);

    // write eos
    out.writeIntLittleEndian(0);

    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(outStream.toByteArray()), allocator)) {
      assertEquals(1, reader.getDictionaryVectors().size());
      assertTrue(reader.loadNextBatch());
      FieldVector dictionaryVector = reader.getDictionaryVectors().get(1L).getVector();
      // make sure the replacement (second) dictionary took effect.
      assertTrue(VectorEqualsVisitor.vectorEquals(dictionaryVector, dictionaryVector1, null));
      assertFalse(reader.loadNextBatch());
    }

    vector1.close();
    vector2.close();
    AutoCloseables.close(closeableList);
  }

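  /**
   * Writes a non-delta dictionary followed by a delta dictionary with the same id; the reader
   * should concatenate the two.
   */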
  @Test
  public void testDeltaDictionary() throws Exception {
    VarCharVector vector1 = newVarCharVector("varchar1", allocator);
    setVector(
        vector1,
        "foo".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8),
        "baz".getBytes(StandardCharsets.UTF_8),
        "bar".getBytes(StandardCharsets.UTF_8));

    FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector1, dictionary1);

    VarCharVector vector2 = newVarCharVector("varchar2", allocator);
    setVector(
        vector2,
        "foo".getBytes(StandardCharsets.UTF_8),
        "aa".getBytes(StandardCharsets.UTF_8),
        "bb".getBytes(StandardCharsets.UTF_8),
        "cc".getBytes(StandardCharsets.UTF_8));

    FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector2, dictionary3);

    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);
    provider.put(dictionary3);
    List<Field> schemaFields = new ArrayList<>();
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVector1.getField(), provider, new HashSet<>()));
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVector2.getField(), provider, new HashSet<>()));
    Schema schema = new Schema(schemaFields);

    ByteArrayOutputStream outStream = new ByteArrayOutputStream();
    WriteChannel out = new WriteChannel(newChannel(outStream));

    // write schema
    MessageSerializer.serialize(out, schema);

    List<AutoCloseable> closeableList = new ArrayList<>();

    // write non-delta dictionary with id=1
    serializeDictionaryBatch(out, dictionary1, false, closeableList);

    // write delta dictionary with id=1
    Dictionary deltaDictionary =
        new Dictionary(dictionaryVector2, new DictionaryEncoding(1L, false, null));
    serializeDictionaryBatch(out, deltaDictionary, true, closeableList);
    deltaDictionary.getVector().close();

    // write the record batch
    serializeRecordBatch(out, Arrays.asList(encodedVector1, encodedVector2), closeableList);

    // write eos
    out.writeIntLittleEndian(0);

    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(outStream.toByteArray()), allocator)) {
      assertEquals(1, reader.getDictionaryVectors().size());
      assertTrue(reader.loadNextBatch());
      FieldVector dictionaryVector = reader.getDictionaryVectors().get(1L).getVector();
      // make sure the delta dictionary is concatenated.
      assertTrue(VectorEqualsVisitor.vectorEquals(dictionaryVector, dictionaryVector3, null));
      assertFalse(reader.loadNextBatch());
    }

    vector1.close();
    vector2.close();
    AutoCloseables.close(closeableList);
  }

  /** Tests that the ArrowStreamWriter re-emits dictionaries when they change. */
  @Test
  public void testWriteReadStreamWithDictionaryReplacement() throws Exception {
    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);

    String[] batch0 = {"foo", "bar", "baz", "bar", "baz"};
    String[] batch1 = {"foo", "aa", "bar", "bb", "baz", "cc"};

    VarCharVector vector = newVarCharVector("varchar", allocator);
    vector.allocateNewSafe();
    for (int i = 0; i < batch0.length; ++i) {
      vector.set(i, batch0[i].getBytes(StandardCharsets.UTF_8));
    }
    vector.setValueCount(batch0.length);
    FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);

    List<Field> fields = Arrays.asList(encodedVector1.getField());
    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
      try (VectorSchemaRoot root =
              new VectorSchemaRoot(
                  fields, Arrays.asList(encodedVector1), encodedVector1.getValueCount());
          ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, newChannel(out))) {
        writer.start();

        // Write batch with initial data and dictionary
        writer.writeBatch();

        // Create data for the next batch, using an extended dictionary with the same id
        vector.reset();
        for (int i = 0; i < batch1.length; ++i) {
          vector.set(i, batch1[i].getBytes(StandardCharsets.UTF_8));
        }
        vector.setValueCount(batch1.length);

        // Re-encode and move encoded data into the vector schema root
        provider.put(dictionary3);
        FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector, dictionary3);
        TransferPair transferPair = encodedVector2.makeTransferPair(root.getVector(0));
        transferPair.transfer();

        // Write second batch
        root.setRowCount(batch1.length);
        writer.writeBatch();

        writer.end();
      }

      try (ArrowStreamReader reader =
          new ArrowStreamReader(
              new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator)) {
        VectorSchemaRoot root = reader.getVectorSchemaRoot();

        // Read and verify first batch
        assertTrue(reader.loadNextBatch());
        assertEquals(batch0.length, root.getRowCount());
        FieldVector readEncoded1 = root.getVector(0);
        long dictionaryId = readEncoded1.getField().getDictionary().getId();
        try (VarCharVector decodedValues =
            (VarCharVector) DictionaryEncoder.decode(readEncoded1, reader.lookup(dictionaryId))) {
          for (int i = 0; i < batch0.length; ++i) {
            assertEquals(batch0[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
          }
        }

        // Read and verify second batch
        assertTrue(reader.loadNextBatch());
        assertEquals(batch1.length, root.getRowCount());
        FieldVector readEncoded2 = root.getVector(0);
        dictionaryId = readEncoded2.getField().getDictionary().getId();
        try (VarCharVector decodedValues =
            (VarCharVector) DictionaryEncoder.decode(readEncoded2, reader.lookup(dictionaryId))) {
          for (int i = 0; i < batch1.length; ++i) {
            assertEquals(batch1[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
          }
        }

        assertFalse(reader.loadNextBatch());
      }
    }

    vector.close();
  }

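  /**
   * Serializes {@code dictionary} as a (possibly delta) dictionary batch, deferring cleanup of
   * the created batch and root to {@code closeables}.
   */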
  private void serializeDictionaryBatch(
      WriteChannel out, Dictionary dictionary, boolean isDelta, List<AutoCloseable> closeables)
      throws IOException {

    FieldVector dictVector = dictionary.getVector();
    VectorSchemaRoot root =
        new VectorSchemaRoot(
            Collections.singletonList(dictVector.getField()),
            Collections.singletonList(dictVector),
            dictVector.getValueCount());
    ArrowDictionaryBatch batch =
        new ArrowDictionaryBatch(
            dictionary.getEncoding().getId(), new VectorUnloader(root).getRecordBatch(), isDelta);
    MessageSerializer.serialize(out, batch);
    closeables.add(batch);
    closeables.add(root);
  }

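  /**
   * Serializes the given vectors as a single record batch, deferring cleanup of the created batch
   * and root to {@code closeables}.
   */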
  private void serializeRecordBatch(
      WriteChannel out, List<FieldVector> vectors, List<AutoCloseable> closeables)
      throws IOException {

    List<Field> fields = vectors.stream().map(v -> v.getField()).collect(Collectors.toList());
    VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, vectors.get(0).getValueCount());
    VectorUnloader unloader = new VectorUnloader(root);
    ArrowRecordBatch batch = unloader.getRecordBatch();
    MessageSerializer.serialize(out, batch);
    closeables.add(batch);
    closeables.add(root);
  }

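  /**
   * A stream with dictionary batches interleaved between record batches (schema, dictionary1,
   * batch1, dictionary2, batch2) should still be readable.
   */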
  @Test
  public void testReadInterleavedData() throws IOException {
    List<ArrowRecordBatch> batches = createRecordBatches();

    ByteArrayOutputStream outStream = new ByteArrayOutputStream();
    WriteChannel out = new WriteChannel(newChannel(outStream));

    // write schema
    MessageSerializer.serialize(out, schema);

    // write dictionary1
    FieldVector dictVector1 = dictionary1.getVector();
    VectorSchemaRoot dictRoot1 =
        new VectorSchemaRoot(
            Collections.singletonList(dictVector1.getField()),
            Collections.singletonList(dictVector1),
            dictVector1.getValueCount());
    ArrowDictionaryBatch dictionaryBatch1 =
        new ArrowDictionaryBatch(1, new VectorUnloader(dictRoot1).getRecordBatch());
    MessageSerializer.serialize(out, dictionaryBatch1);
    dictionaryBatch1.close();
    dictRoot1.close();

    // write recordBatch1
    MessageSerializer.serialize(out, batches.get(0));

    // write dictionary2
    FieldVector dictVector2 = dictionary2.getVector();
    VectorSchemaRoot dictRoot2 =
        new VectorSchemaRoot(
            Collections.singletonList(dictVector2.getField()),
            Collections.singletonList(dictVector2),
            dictVector2.getValueCount());
    ArrowDictionaryBatch dictionaryBatch2 =
        new ArrowDictionaryBatch(2, new VectorUnloader(dictRoot2).getRecordBatch());
    MessageSerializer.serialize(out, dictionaryBatch2);
    dictionaryBatch2.close();
    dictRoot2.close();

    // write recordBatch2
    MessageSerializer.serialize(out, batches.get(1));

    // write eos
    out.writeIntLittleEndian(0);

    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(outStream.toByteArray()), allocator)) {
      Schema readSchema = reader.getVectorSchemaRoot().getSchema();
      assertEquals(encodedSchema, readSchema);
      assertEquals(2, reader.getDictionaryVectors().size());
      assertTrue(reader.loadNextBatch());
      assertTrue(reader.loadNextBatch());
      assertFalse(reader.loadNextBatch());
    }

    batches.forEach(batch -> batch.close());
  }

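  /**
   * Builds two record batches that reference dictionaries 1 and 2, initializing {@link #schema}
   * and {@link #encodedSchema} as a side effect.
   */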
  private List<ArrowRecordBatch> createRecordBatches() {
    List<ArrowRecordBatch> batches = new ArrayList<>();

    DictionaryProvider.MapDictionaryProvider provider =
        new DictionaryProvider.MapDictionaryProvider();
    provider.put(dictionary1);
    provider.put(dictionary2);

    VarCharVector vectorA1 = newVarCharVector("varcharA1", allocator);
    vectorA1.allocateNewSafe();
    vectorA1.set(0, "foo".getBytes(StandardCharsets.UTF_8));
    vectorA1.set(1, "bar".getBytes(StandardCharsets.UTF_8));
    vectorA1.set(3, "baz".getBytes(StandardCharsets.UTF_8));
    vectorA1.set(4, "bar".getBytes(StandardCharsets.UTF_8));
    vectorA1.set(5, "baz".getBytes(StandardCharsets.UTF_8));
    vectorA1.setValueCount(6);

    VarCharVector vectorA2 = newVarCharVector("varcharA2", allocator);
    vectorA2.setValueCount(6);
    FieldVector encodedVectorA1 = (FieldVector) DictionaryEncoder.encode(vectorA1, dictionary1);
    vectorA1.close();
    FieldVector encodedVectorA2 = (FieldVector) DictionaryEncoder.encode(vectorA2, dictionary2);
    vectorA2.close();

    List<Field> fields = Arrays.asList(encodedVectorA1.getField(), encodedVectorA2.getField());
    List<FieldVector> vectors = Collections2.asImmutableList(encodedVectorA1, encodedVectorA2);
    VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, encodedVectorA1.getValueCount());
    VectorUnloader unloader = new VectorUnloader(root);
    batches.add(unloader.getRecordBatch());
    root.close();

    VarCharVector vectorB1 = newVarCharVector("varcharB1", allocator);
    vectorB1.setValueCount(6);

    VarCharVector vectorB2 = newVarCharVector("varcharB2", allocator);
    vectorB2.allocateNew();
    vectorB2.setValueCount(6);
    vectorB2.set(0, "aa".getBytes(StandardCharsets.UTF_8));
    vectorB2.set(1, "aa".getBytes(StandardCharsets.UTF_8));
    vectorB2.set(3, "bb".getBytes(StandardCharsets.UTF_8));
    vectorB2.set(4, "bb".getBytes(StandardCharsets.UTF_8));
    vectorB2.set(5, "cc".getBytes(StandardCharsets.UTF_8));
    vectorB2.setValueCount(6);
    FieldVector encodedVectorB1 = (FieldVector) DictionaryEncoder.encode(vectorB1, dictionary1);
    vectorB1.close();
    FieldVector encodedVectorB2 = (FieldVector) DictionaryEncoder.encode(vectorB2, dictionary2);
    vectorB2.close();

    List<Field> fieldsB = Arrays.asList(encodedVectorB1.getField(), encodedVectorB2.getField());
    List<FieldVector> vectorsB = Collections2.asImmutableList(encodedVectorB1, encodedVectorB2);
    VectorSchemaRoot rootB = new VectorSchemaRoot(fieldsB, vectorsB, 6);
    VectorUnloader unloaderB = new VectorUnloader(rootB);
    batches.add(unloaderB.getRecordBatch());
    rootB.close();

    List<Field> schemaFields = new ArrayList<>();
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVectorA1.getField(), provider, new HashSet<>()));
    schemaFields.add(
        DictionaryUtility.toMessageFormat(encodedVectorA2.getField(), provider, new HashSet<>()));
    schema = new Schema(schemaFields);

    encodedSchema =
        new Schema(Arrays.asList(encodedVectorA1.getField(), encodedVectorA2.getField()));

    return batches;
  }

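  /**
   * Verifies that a stream written in the legacy IPC format (no continuation marker) and one
   * written with the continuation marker both deserialize to the same schema and batch.
   */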
  @Test
  public void testLegacyIpcBackwardsCompatibility() throws Exception {
    Schema schema = new Schema(asList(Field.nullable("field", new ArrowType.Int(32, true))));
    IntVector vector = new IntVector("vector", allocator);
    final int valueCount = 2;
    vector.setValueCount(valueCount);
    vector.setSafe(0, 1);
    vector.setSafe(1, 2);
    ArrowRecordBatch batch =
        new ArrowRecordBatch(
            valueCount,
            asList(new ArrowFieldNode(valueCount, 0)),
            asList(vector.getValidityBuffer(), vector.getDataBuffer()));

    ByteArrayOutputStream outStream = new ByteArrayOutputStream();
    WriteChannel out = new WriteChannel(newChannel(outStream));

    // write legacy ipc format
    IpcOption option = new IpcOption(true, MetadataVersion.DEFAULT);
    MessageSerializer.serialize(out, schema, option);
    MessageSerializer.serialize(out, batch);

    ReadChannel in = new ReadChannel(newChannel(new ByteArrayInputStream(outStream.toByteArray())));
    Schema readSchema = MessageSerializer.deserializeSchema(in);
    assertEquals(schema, readSchema);
    ArrowRecordBatch readBatch = MessageSerializer.deserializeRecordBatch(in, allocator);
    assertEquals(batch.getLength(), readBatch.getLength());
    assertEquals(batch.computeBodyLength(), readBatch.computeBodyLength());
    readBatch.close();

    // write ipc format with continuation
    option = IpcOption.DEFAULT;
    MessageSerializer.serialize(out, schema, option);
    MessageSerializer.serialize(out, batch);

    ReadChannel in2 =
        new ReadChannel(newChannel(new ByteArrayInputStream(outStream.toByteArray())));
    Schema readSchema2 = MessageSerializer.deserializeSchema(in2);
    assertEquals(schema, readSchema2);
    ArrowRecordBatch readBatch2 = MessageSerializer.deserializeRecordBatch(in2, allocator);
    assertEquals(batch.getLength(), readBatch2.getLength());
    assertEquals(batch.computeBodyLength(), readBatch2.computeBodyLength());
    readBatch2.close();

    batch.close();
    vector.close();
  }

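  /** readFully into an ArrowBuf should append at the buffer's writer index and advance it. */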
  @Test
  public void testChannelReadFully() throws IOException {
    final ByteBuffer buf = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder());
    buf.putInt(200);
    buf.rewind();

    try (ReadChannel channel =
            new ReadChannel(Channels.newChannel(new ByteArrayInputStream(buf.array())));
        ArrowBuf arrBuf = allocator.buffer(8)) {
      arrBuf.setInt(0, 100);
      arrBuf.writerIndex(4);
      assertEquals(4, arrBuf.writerIndex());

      long n = channel.readFully(arrBuf, 4);
      assertEquals(4, n);
      assertEquals(8, arrBuf.writerIndex());

      assertEquals(100, arrBuf.getInt(0));
      assertEquals(200, arrBuf.getInt(4));
    }
  }

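  /**
   * readFully with a ByteBuffer larger than the remaining input should stop at end-of-stream and
   * report only the bytes actually read.
   */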
  @Test
  public void testChannelReadFullyEos() throws IOException {
    final ByteBuffer buf = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder());
    buf.putInt(10);
    buf.rewind();

    try (ReadChannel channel =
            new ReadChannel(Channels.newChannel(new ByteArrayInputStream(buf.array())));
        ArrowBuf arrBuf = allocator.buffer(8)) {
      int n = channel.readFully(arrBuf.nioBuffer(0, 8));
      assertEquals(4, n);

      // the input has only 4 bytes, so the number of bytes read should be 4
      assertEquals(4, channel.bytesRead());

      // the first 4 bytes have been read successfully.
      assertEquals(10, arrBuf.getInt(0));
    }
  }

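  /** Custom metadata passed to the file writer should round-trip through the file footer. */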
  @Test
  public void testCustomMetaData() throws IOException {

    VarCharVector vector = newVarCharVector("varchar1", allocator);

    List<Field> fields = Arrays.asList(vector.getField());
    List<FieldVector> vectors = Collections2.asImmutableList(vector);
    Map<String, String> metadata = new HashMap<>();
    metadata.put("key1", "value1");
    metadata.put("key2", "value2");
    try (VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, vector.getValueCount());
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ArrowFileWriter writer = new ArrowFileWriter(root, null, newChannel(out), metadata)) {

      writer.start();
      writer.end();

      try (SeekableReadChannel channel =
              new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(out.toByteArray()));
          ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
        reader.getVectorSchemaRoot();

        Map<String, String> readMeta = reader.getMetaData();
        assertEquals(2, readMeta.size());
        assertEquals("value1", readMeta.get("key1"));
        assertEquals("value2", readMeta.get("key2"));
      }
    }
  }

  /**
   * This test case covers the case where the footer size is extremely large (much larger than the
   * file size). Due to integer overflow, an unguarded implementation would fail to detect the
   * problem, leading to an extremely large memory allocation and eventually an OutOfMemoryError.
   */
  @Test
  public void testFileFooterSizeOverflow() {
    // copy of org.apache.arrow.vector.ipc.ArrowMagic#MAGIC
    final byte[] magicBytes = "ARROW1".getBytes(StandardCharsets.UTF_8);

    // prepare input data
    byte[] data = new byte[30];
    System.arraycopy(magicBytes, 0, data, 0, ArrowMagic.MAGIC_LENGTH);
    int footerLength = Integer.MAX_VALUE;
    byte[] footerLengthBytes =
        ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(footerLength).array();
    int footerOffset = data.length - ArrowMagic.MAGIC_LENGTH - 4;
    System.arraycopy(footerLengthBytes, 0, data, footerOffset, 4);
    System.arraycopy(magicBytes, 0, data, footerOffset + 4, ArrowMagic.MAGIC_LENGTH);

    // test file reader
    InvalidArrowFileException e =
        assertThrows(
            InvalidArrowFileException.class,
            () -> {
              try (SeekableReadChannel channel =
                      new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(data));
                  ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
                reader.getVectorSchemaRoot().getSchema();
              }
            });

    assertEquals("invalid footer length: " + footerLength, e.getMessage());
  }
}