ArrowWriter.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.vector.ipc;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Abstract base class for implementing Arrow writers for IPC over a WriteChannel. */
public abstract class ArrowWriter implements AutoCloseable {

  protected static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class);

  // schema with fields in message format, not memory format
  protected final Schema schema;
  protected final WriteChannel out;

  private final VectorUnloader unloader;
  private final DictionaryProvider dictionaryProvider;
  private final Set<Long> dictionaryIdsUsed = new HashSet<>();

  private final CompressionCodec.Factory compressionFactory;
  private final CompressionUtil.CodecType codecType;
  private final Optional<Integer> compressionLevel;
  private boolean started = false;
  private boolean ended = false;

  private final CompressionCodec codec;

  protected IpcOption option;

  protected ArrowWriter(
      VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
    this(root, provider, out, IpcOption.DEFAULT);
  }

  protected ArrowWriter(
      VectorSchemaRoot root,
      DictionaryProvider provider,
      WritableByteChannel out,
      IpcOption option) {
    this(
        root,
        provider,
        out,
        option,
        NoCompressionCodec.Factory.INSTANCE,
        CompressionUtil.CodecType.NO_COMPRESSION,
        Optional.empty());
  }

  /**
   * Note: fields are not closed when the writer is closed.
   *
   * @param root the vectors to write to the output
   * @param provider where to find the dictionaries
   * @param out the output where to write
   * @param option IPC write options
   * @param compressionFactory Compression codec factory
   * @param codecType Compression codec
   * @param compressionLevel Compression level
   */
  protected ArrowWriter(
      VectorSchemaRoot root,
      DictionaryProvider provider,
      WritableByteChannel out,
      IpcOption option,
      CompressionCodec.Factory compressionFactory,
      CompressionUtil.CodecType codecType,
      Optional<Integer> compressionLevel) {
    this.out = new WriteChannel(out);
    this.option = option;
    this.dictionaryProvider = provider;

    this.compressionFactory = compressionFactory;
    this.codecType = codecType;
    this.compressionLevel = compressionLevel;
    this.codec =
        this.compressionLevel.isPresent()
            ? this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get())
            : this.compressionFactory.createCodec(this.codecType);
    this.unloader =
        new VectorUnloader(root, /*includeNullCount*/ true, codec, /*alignBuffers*/ true);

    List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());

    MetadataV4UnionChecker.checkForUnion(
        root.getSchema().getFields().iterator(), option.metadataVersion);
    // Convert fields with dictionaries to have dictionary type
    for (Field field : root.getSchema().getFields()) {
      fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
    }

    this.schema = new Schema(fields, root.getSchema().getCustomMetadata());
  }

  public void start() throws IOException {
    ensureStarted();
  }

  /** Writes the record batch currently loaded in this instance's VectorSchemaRoot. */
  public void writeBatch() throws IOException {
    ensureStarted();
    ensureDictionariesWritten(dictionaryProvider, dictionaryIdsUsed);
    try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
      writeRecordBatch(batch);
    }
  }

  protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
    FieldVector vector = dictionary.getVector();
    long id = dictionary.getEncoding().getId();
    int count = vector.getValueCount();
    VectorSchemaRoot dictRoot =
        new VectorSchemaRoot(
            Collections.singletonList(vector.getField()), Collections.singletonList(vector), count);
    VectorUnloader unloader =
        new VectorUnloader(dictRoot, /*includeNullCount*/ true, this.codec, /*alignBuffers*/ true);
    ArrowRecordBatch batch = unloader.getRecordBatch();
    ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
    try {
      writeDictionaryBatch(dictionaryBatch);
    } finally {
      try {
        dictionaryBatch.close();
      } catch (Exception e) {
        throw new RuntimeException("Error occurred while closing dictionary.", e);
      }
    }
  }

  protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
    ArrowBlock block = MessageSerializer.serialize(out, batch, option);
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug(
          "DictionaryRecordBatch at {}, metadata: {}, body: {}",
          block.getOffset(),
          block.getMetadataLength(),
          block.getBodyLength());
    }
    return block;
  }

  protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
    ArrowBlock block = MessageSerializer.serialize(out, batch, option);
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug(
          "RecordBatch at {}, metadata: {}, body: {}",
          block.getOffset(),
          block.getMetadataLength(),
          block.getBodyLength());
    }
    return block;
  }

  public void end() throws IOException {
    ensureStarted();
    ensureEnded();
  }

  public long bytesWritten() {
    return out.getCurrentPosition();
  }

  private void ensureStarted() throws IOException {
    if (!started) {
      started = true;
      startInternal(out);
      // write the schema - for file formats this is duplicated in the footer, but matches
      // the streaming format
      MessageSerializer.serialize(out, schema, option);
    }
  }

  /**
   * Write dictionaries after schema and before recordBatches, dictionaries won't be written if
   * empty stream (only has schema data in IPC).
   */
  protected abstract void ensureDictionariesWritten(
      DictionaryProvider provider, Set<Long> dictionaryIdsUsed) throws IOException;

  private void ensureEnded() throws IOException {
    if (!ended) {
      ended = true;
      endInternal(out);
    }
  }

  protected void startInternal(WriteChannel out) throws IOException {}

  protected void endInternal(WriteChannel out) throws IOException {}

  @Override
  public void close() {
    try {
      end();
      out.close();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }
}