TestArrowStreamPipe.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.vector.ipc;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.IOException;
import java.nio.channels.Pipe;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;

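/**
 * Streams record batches from a writer thread to a reader thread over a
 * {@link java.nio.channels.Pipe}, verifying the contents of each batch as well as the total
 * batch and byte counts.
 */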
public class TestArrowStreamPipe {
  private final Schema schema = MessageSerializerTest.testSchema();
  private final BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE);

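  /**
   * Producer side: writes {@code numBatches} 16-row batches of a single {@link TinyIntVector} to
   * the pipe's sink channel, tagging row 0 of each batch with the batch index.
   */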
  private final class WriterThread extends Thread {

    private final int numBatches;
    private final ArrowStreamWriter writer;
    private final VectorSchemaRoot root;

    public WriterThread(int numBatches, WritableByteChannel sinkChannel) throws IOException {
      this.numBatches = numBatches;
      BufferAllocator allocator = alloc.newChildAllocator("writer thread", 0, Integer.MAX_VALUE);
      root = VectorSchemaRoot.create(schema, allocator);
      writer = new ArrowStreamWriter(root, null, sinkChannel);
    }

    @Override
    public void run() {
      try {
        writer.start();
        for (int j = 0; j < numBatches; j++) {
          TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0);
          vector.allocateNew();
          // Send a changing batch id first
          vector.set(0, j);
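          // The three-argument set(index, isSet, value): slots 1-7 are valid and hold (i + 1),
          // slots 8-15 are left null (isSet == 0).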
          for (int i = 1; i < 16; i++) {
            vector.set(i, i < 8 ? 1 : 0, (byte) (i + 1));
          }
          vector.setValueCount(16);
          root.setRowCount(16);

          writer.writeBatch();
        }
        writer.close();
        root.close();
      } catch (IOException e) {
        e.printStackTrace();
        fail(e.toString()); // have to explicitly fail since we're in a separate thread
      }
    }

    public long bytesWritten() {
      return writer.bytesWritten();
    }
  }

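  /**
   * Consumer side: reads batches from the pipe's source channel using its own root allocator and
   * validates each batch as it is loaded.
   */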
  private final class ReaderThread extends Thread {
    private int batchesRead = 0;
    private final ArrowStreamReader reader;
    private final BufferAllocator readerAllocator = new RootAllocator(Long.MAX_VALUE);
    private boolean done = false;

    public ReaderThread(ReadableByteChannel sourceChannel) throws IOException {
      reader =
          new ArrowStreamReader(sourceChannel, readerAllocator) {

            @Override
            public boolean loadNextBatch() throws IOException {
              if (super.loadNextBatch()) {
                batchesRead++;
              } else {
                done = true;
                return false;
              }
              VectorSchemaRoot root = getVectorSchemaRoot();
              assertEquals(16, root.getRowCount());
              TinyIntVector vector = (TinyIntVector) root.getFieldVectors().get(0);
              assertEquals((byte) (batchesRead - 1), vector.get(0));
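              // Mirror the writer: slots 1-7 hold (i + 1), slots 8-15 are null.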
              for (int i = 1; i < 16; i++) {
                if (i < 8) {
                  assertEquals((byte) (i + 1), vector.get(i));
                } else {
                  assertTrue(vector.isNull(i));
                }
              }

              return true;
            }
          };
    }

    @Override
    public void run() {
      try {
        assertEquals(schema, reader.getVectorSchemaRoot().getSchema());
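        // loadNextBatch() returns false only at end of stream, which also flips done, so the
        // return value and the flag can never be equal.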
        while (!done) {
          assertTrue(reader.loadNextBatch() != done);
        }
        reader.close();
      } catch (IOException e) {
        e.printStackTrace();
        fail(e.toString()); // have to explicitly fail since we're in a separate thread
      }
    }

    public int getBatchesRead() {
      return batchesRead;
    }

    public long bytesRead() {
      return reader.bytesRead();
    }
  }

  // Starts a producer thread and a consumer thread, streams NUM_BATCHES record batches through a
  // java.nio Pipe, and checks that every batch and every byte made it across.
  @Test
  public void pipeTest() throws IOException, InterruptedException {
    final int NUM_BATCHES = 10;
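    // Pipe.open() supplies a one-way channel pair: the writer pushes batches into the sink and
    // the reader pulls them from the source.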
    Pipe pipe = Pipe.open();
    WriterThread writer = new WriterThread(NUM_BATCHES, pipe.sink());
    ReaderThread reader = new ReaderThread(pipe.source());

    writer.start();
    reader.start();
    reader.join();
    writer.join();

    assertEquals(NUM_BATCHES, reader.getBatchesRead());
    assertEquals(writer.bytesWritten(), reader.bytesRead());
  }
}