MessageSerializerTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.vector.ipc;

import static java.util.Arrays.asList;
import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.Channels;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowMessage;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.MetadataVersion;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link MessageSerializer}: int/byte conversion, message framing and alignment, schema
 * message round-trips, and record batch round-trips for metadata versions V4 and V5.
 */
public class MessageSerializerTest {

  /**
   * Allocates a buffer from {@code alloc} and writes {@code bytes} into it.
   *
   * <p>The caller owns the returned buffer and is responsible for closing it.
   *
   * @param alloc allocator the buffer is drawn from
   * @param bytes contents to write
   * @return a buffer whose readable bytes are {@code bytes}
   */
  public static ArrowBuf buf(BufferAllocator alloc, byte[] bytes) {
    ArrowBuf buffer = alloc.buffer(bytes.length);
    buffer.writeBytes(bytes);
    return buffer;
  }

  /**
   * Copies all currently readable bytes of {@code buf} into a new array.
   *
   * <p>Note that this consumes the bytes: the buffer's reader index is advanced by the number of
   * bytes read.
   *
   * @param buf buffer to read from
   * @return the readable bytes of {@code buf}
   */
  public static byte[] array(ArrowBuf buf) {
    byte[] bytes = new byte[checkedCastToInt(buf.readableBytes())];
    buf.readBytes(bytes);
    return bytes;
  }

  /** Encodes {@code v} into {@code bytes} and decodes it back. */
  private int intToByteRoundtrip(int v, byte[] bytes) {
    MessageSerializer.intToBytes(v, bytes);
    return MessageSerializer.bytesToInt(bytes);
  }

  @Test
  public void testIntToBytes() {
    byte[] bytes = new byte[4];
    // Include zero and negative values so the sign bit and all-ones patterns are covered too.
    int[] values =
        new int[] {0, 1, 15, 1 << 8, 1 << 16, Integer.MAX_VALUE, -1, Integer.MIN_VALUE};
    for (int v : values) {
      // JUnit's assertEquals takes (expected, actual).
      assertEquals(v, intToByteRoundtrip(v, bytes));
    }
  }

  @Test
  public void testWriteMessageBufferAligned() throws IOException {
    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
    WriteChannel out = new WriteChannel(Channels.newChannel(outputStream));

    // This is not a valid Arrow Message, only to test writing and alignment
    ByteBuffer buffer = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
    buffer.putInt(1);
    buffer.putInt(2);
    buffer.flip();

    // 4-byte continuation token + 4-byte length + 8 bytes of message.
    int bytesWritten = MessageSerializer.writeMessageBuffer(out, 8, buffer);
    assertEquals(16, bytesWritten);

    // Reuse the buffer for a 4-byte message; it is expected to be padded to 8 bytes.
    buffer.rewind();
    buffer.putInt(3);
    buffer.flip();
    bytesWritten = MessageSerializer.writeMessageBuffer(out, 4, buffer);
    assertEquals(16, bytesWritten);
    // Nothing beyond the two framed messages should have been written.
    assertEquals(32, outputStream.size());

    ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
    ReadChannel in = new ReadChannel(Channels.newChannel(inputStream));
    ByteBuffer result = ByteBuffer.allocate(32).order(ByteOrder.nativeOrder());
    in.readFully(result);
    result.rewind();

    // First message continuation, size, and 2 int values
    assertEquals(MessageSerializer.IPC_CONTINUATION_TOKEN, result.getInt());
    // message length is represented in little endian
    result.order(ByteOrder.LITTLE_ENDIAN);
    assertEquals(8, result.getInt());
    result.order(ByteOrder.nativeOrder());
    assertEquals(1, result.getInt());
    assertEquals(2, result.getInt());

    // Second message continuation, size, 1 int value and 4 bytes padding
    assertEquals(MessageSerializer.IPC_CONTINUATION_TOKEN, result.getInt());
    // message length is represented in little endian; it includes the padding
    result.order(ByteOrder.LITTLE_ENDIAN);
    assertEquals(8, result.getInt());
    result.order(ByteOrder.nativeOrder());
    assertEquals(3, result.getInt());
    assertEquals(0, result.getInt());
  }

  @Test
  public void testSchemaMessageSerialization() throws IOException {
    Schema schema = testSchema();
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    long size = MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema);
    // The reported size must match what was actually written.
    assertEquals(size, out.toByteArray().length);

    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    Schema deserialized =
        MessageSerializer.deserializeSchema(new ReadChannel(Channels.newChannel(in)));
    assertEquals(schema, deserialized);
    assertEquals(1, deserialized.getFields().size());
  }

  @Test
  public void testSchemaDictionaryMessageSerialization() throws IOException {
    DictionaryEncoding dictionary = new DictionaryEncoding(9L, false, new ArrowType.Int(8, true));
    Field field =
        new Field("test", new FieldType(true, ArrowType.Utf8.INSTANCE, dictionary, null), null);
    Schema schema = new Schema(Collections.singletonList(field));
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    long size = MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema);
    // The reported size must match what was actually written.
    assertEquals(size, out.toByteArray().length);

    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    Schema deserialized =
        MessageSerializer.deserializeSchema(new ReadChannel(Channels.newChannel(in)));
    assertEquals(schema, deserialized);
  }

  @Test
  public void testSerializeRecordBatchV4() throws Exception {
    // The validity bitmap marks the first 8 of 16 slots valid and the last 8 null.
    byte[] validity = new byte[] {(byte) 255, 0};
    // second half is "undefined" (those slots are null)
    byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    // avoid writing legacy ipc format by default
    IpcOption option = new IpcOption(false, MetadataVersion.V4);

    // Closing the allocator last also verifies that no buffer was leaked.
    try (BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE)) {
      byte[] serialized = serializeTestBatch(alloc, validity, values, option);
      deserializeAndVerify(alloc, serialized, validity, values);
    }
  }

  @Test
  public void testSerializeRecordBatchV5() throws Exception {
    // The validity bitmap marks the first 8 of 16 slots valid and the last 8 null.
    byte[] validity = new byte[] {(byte) 255, 0};
    // second half is "undefined" (those slots are null)
    byte[] values = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    // avoid writing legacy ipc format by default
    IpcOption option = new IpcOption(false, MetadataVersion.V5);

    // Closing the allocator last also verifies that no buffer was leaked, including by the
    // failed read of the truncated message below.
    try (BufferAllocator alloc = new RootAllocator(Long.MAX_VALUE)) {
      byte[] validBytes = serializeTestBatch(alloc, validity, values, option);
      deserializeAndVerify(alloc, validBytes, validity, values);

      // Dropping the final byte of the message must make deserialization fail with IOException.
      byte[] missingBytes = Arrays.copyOfRange(validBytes, /*from=*/ 0, validBytes.length - 1);
      ReadChannel channel =
          new ReadChannel(Channels.newChannel(new ByteArrayInputStream(missingBytes)));
      assertThrows(
          IOException.class, () -> MessageSerializer.deserializeMessageBatch(channel, alloc));
    }
  }

  /**
   * Builds a single-node, 16-row record batch from {@code validity} and {@code values},
   * serializes it with {@code option}, and releases every buffer it allocated.
   *
   * @return the serialized message bytes
   */
  private static byte[] serializeTestBatch(
      BufferAllocator alloc, byte[] validity, byte[] values, IpcOption option)
      throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    // The batch takes its own reference on each buffer, so both the source buffers and the
    // batch must be closed for the memory to be released.
    try (ArrowBuf validityb = buf(alloc, validity);
        ArrowBuf valuesb = buf(alloc, values);
        ArrowRecordBatch batch =
            new ArrowRecordBatch(
                16, asList(new ArrowFieldNode(16, 8)), asList(validityb, valuesb))) {
      MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), batch, option);
    }
    return out.toByteArray();
  }

  /**
   * Deserializes {@code serialized} with {@code alloc}, checks that it is a record batch that
   * matches {@code validity} and {@code values}, and closes the deserialized message.
   */
  private static void deserializeAndVerify(
      BufferAllocator alloc, byte[] serialized, byte[] validity, byte[] values)
      throws Exception {
    ReadChannel channel =
        new ReadChannel(Channels.newChannel(new ByteArrayInputStream(serialized)));
    try (ArrowMessage deserialized = MessageSerializer.deserializeMessageBatch(channel, alloc)) {
      assertEquals(ArrowRecordBatch.class, deserialized.getClass());
      verifyBatch((ArrowRecordBatch) deserialized, validity, values);
    }
  }

  /**
   * Returns a schema with a single nullable signed 8-bit integer field named "testField".
   *
   * @return the test schema
   */
  public static Schema testSchema() {
    return new Schema(
        asList(
            new Field(
                "testField",
                FieldType.nullable(new ArrowType.Int(8, true)),
                Collections.<Field>emptyList())));
  }

  /**
   * Verifies that {@code batch} has one field node of length 16 with 8 nulls and two buffers
   * whose contents equal {@code validity} and {@code values}.
   *
   * <p>Reading the buffer contents advances their reader indices.
   *
   * @param batch batch to check
   * @param validity expected contents of the first (validity) buffer
   * @param values expected contents of the second (values) buffer
   */
  public static void verifyBatch(ArrowRecordBatch batch, byte[] validity, byte[] values) {
    assertTrue(batch != null);
    List<ArrowFieldNode> nodes = batch.getNodes();
    assertEquals(1, nodes.size());
    ArrowFieldNode node = nodes.get(0);
    assertEquals(16, node.getLength());
    assertEquals(8, node.getNullCount());
    List<ArrowBuf> buffers = batch.getBuffers();
    assertEquals(2, buffers.size());
    assertArrayEquals(validity, MessageSerializerTest.array(buffers.get(0)));
    assertArrayEquals(values, MessageSerializerTest.array(buffers.get(1)));
  }
}