TestUIntDictionaryRoundTrip.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.vector.ipc;
import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.function.ToIntBiFunction;
import java.util.stream.Stream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
/**
 * Tests the IPC round trip (stream and file formats) of dictionary-encoded vectors whose indices
 * are unsigned integers of 8, 16, 32 and 64 bits.
 */
public class TestUIntDictionaryRoundTrip {

  private BufferAllocator allocator;

  private DictionaryProvider.MapDictionaryProvider dictionaryProvider;

  @BeforeEach
  public void init() {
    allocator = new RootAllocator(Long.MAX_VALUE);
    dictionaryProvider = new DictionaryProvider.MapDictionaryProvider();
  }

  @AfterEach
  public void terminate() throws Exception {
    allocator.close();
  }

  /**
   * Serializes a single-column batch holding {@code encodedVector}, together with the dictionaries
   * registered in {@link #dictionaryProvider}.
   *
   * @param streamMode {@code true} to use the IPC stream format, {@code false} for the file format
   * @param encodedVector the dictionary-encoded (index) vector to write
   * @return the serialized bytes
   * @throws IOException if writing fails
   */
  private byte[] writeData(boolean streamMode, FieldVector encodedVector) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    // The root only wraps encodedVector, which the calling test owns and closes via
    // try-with-resources; the root is intentionally not closed here so the caller's vector is not
    // released out from under it.
    VectorSchemaRoot root =
        new VectorSchemaRoot(
            Arrays.asList(encodedVector.getField()),
            Arrays.asList(encodedVector),
            encodedVector.getValueCount());
    try (ArrowWriter writer =
        streamMode
            ? new ArrowStreamWriter(root, dictionaryProvider, out)
            : new ArrowFileWriter(root, dictionaryProvider, Channels.newChannel(out))) {
      writer.start();
      writer.writeBatch();
      writer.end();
      return out.toByteArray();
    }
  }

  /**
   * Deserializes {@code data} and verifies the schema, the encoded indices and the dictionary.
   *
   * @param streamMode {@code true} if {@code data} is in the IPC stream format, {@code false} for
   *     the file format
   * @param data the serialized bytes produced by {@link #writeData}
   * @param expectedField the field expected to be the only field of the read schema
   * @param valGetter reads the index stored at a position of the read index vector as an int
   * @param dictionaryId the id under which the dictionary is expected to be registered
   * @param expectedIndices the expected index values, in order
   * @param expectedDictItems the expected dictionary entries, in order
   * @throws IOException if reading fails
   */
  private void readData(
      boolean streamMode,
      byte[] data,
      Field expectedField,
      ToIntBiFunction<ValueVector, Integer> valGetter,
      long dictionaryId,
      int[] expectedIndices,
      String[] expectedDictItems)
      throws IOException {
    try (ArrowReader reader =
        streamMode
            ? new ArrowStreamReader(new ByteArrayInputStream(data), allocator)
            : new ArrowFileReader(
                new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(data)),
                allocator)) {
      // verify schema
      Schema readSchema = reader.getVectorSchemaRoot().getSchema();
      assertEquals(1, readSchema.getFields().size());
      assertEquals(expectedField, readSchema.getFields().get(0));

      // verify vector schema root
      assertTrue(reader.loadNextBatch());
      VectorSchemaRoot root = reader.getVectorSchemaRoot();
      assertEquals(1, root.getFieldVectors().size());
      ValueVector encodedVector = root.getVector(0);
      assertEquals(expectedIndices.length, encodedVector.getValueCount());
      for (int i = 0; i < expectedIndices.length; i++) {
        assertEquals(expectedIndices[i], valGetter.applyAsInt(encodedVector, i));
      }

      // verify dictionary
      Map<Long, Dictionary> dictVectors = reader.getDictionaryVectors();
      assertEquals(1, dictVectors.size());
      Dictionary dictionary = dictVectors.get(dictionaryId);
      assertNotNull(dictionary);
      assertTrue(dictionary.getVector() instanceof VarCharVector);
      VarCharVector dictVector = (VarCharVector) dictionary.getVector();
      assertEquals(expectedDictItems.length, dictVector.getValueCount());
      for (int i = 0; i < dictVector.getValueCount(); i++) {
        assertArrayEquals(expectedDictItems[i].getBytes(StandardCharsets.UTF_8), dictVector.get(i));
      }

      // Every index must reference an existing dictionary entry; otherwise the encoded column
      // could not be decoded even though the raw round trip succeeded.
      for (int index : expectedIndices) {
        assertTrue(
            index >= 0 && index < dictVector.getValueCount(),
            "dictionary index out of range: " + index);
      }
    }
  }

  /**
   * Registers {@code dictionaryVector} in {@link #dictionaryProvider} and creates an empty index
   * vector whose unsigned integer index type is {@code bitWidth} bits wide.
   *
   * <p>The bit width is also used as the dictionary id, which is why each test looks its
   * dictionary up by the width it passed here.
   *
   * @param bitWidth width of the unsigned index type (8, 16, 32 or 64)
   * @param dictionaryVector the dictionary values
   * @return a new, empty index vector allocated from {@link #allocator}
   */
  private ValueVector createEncodedVector(int bitWidth, VarCharVector dictionaryVector) {
    final DictionaryEncoding dictionaryEncoding =
        new DictionaryEncoding(bitWidth, false, new ArrowType.Int(bitWidth, false));
    Dictionary dictionary = new Dictionary(dictionaryVector, dictionaryEncoding);
    dictionaryProvider.put(dictionary);
    final FieldType type =
        new FieldType(true, dictionaryEncoding.getIndexType(), dictionaryEncoding, null);
    final Field field = new Field("encoded", type, null);
    return field.createVector(allocator);
  }

  /**
   * Builds dictionary entries where the entry at position {@code i} is the decimal string of
   * {@code i}.
   *
   * @param count number of entries to create
   * @return the entries, in index order
   */
  private static String[] createDictionaryItems(int count) {
    String[] items = new String[count];
    for (int i = 0; i < count; i++) {
      items[i] = String.valueOf(i);
    }
    return items;
  }

  @ParameterizedTest(name = "stream mode = {0}")
  @MethodSource("getRepeat")
  public void testUInt1RoundTrip(boolean streamMode) throws IOException {
    // Cover every 8-bit unsigned index, 0 through MAX_UINT1 (255) inclusive.
    final int vectorLength = (UInt1Vector.MAX_UINT1 & UInt1Vector.PROMOTION_MASK) + 1;
    try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
        UInt1Vector encodedVector1 = (UInt1Vector) createEncodedVector(8, dictionaryVector)) {
      int[] indices = new int[vectorLength];
      for (int i = 0; i < vectorLength; i++) {
        // (byte) i is negative for i >= 128; the read side must interpret it as unsigned.
        encodedVector1.setSafe(i, (byte) i);
        indices[i] = i;
      }
      encodedVector1.setValueCount(vectorLength);
      String[] dictionaryItems = createDictionaryItems(vectorLength);
      setVector(dictionaryVector, dictionaryItems);

      byte[] data = writeData(streamMode, encodedVector1);
      readData(
          streamMode,
          data,
          encodedVector1.getField(),
          (vector, index) -> (int) ((UInt1Vector) vector).getValueAsLong(index),
          8L,
          indices,
          dictionaryItems);
    }
  }

  @ParameterizedTest(name = "stream mode = {0}")
  @MethodSource("getRepeat")
  public void testUInt2RoundTrip(boolean streamMode) throws IOException {
    try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
        UInt2Vector encodedVector2 = (UInt2Vector) createEncodedVector(16, dictionaryVector)) {
      int[] indices = new int[] {1, 3, 5, 7, 9, UInt2Vector.MAX_UINT2};
      // MAX_UINT2 itself is written as an index, so the dictionary needs MAX_UINT2 + 1 entries.
      String[] dictItems = createDictionaryItems(UInt2Vector.MAX_UINT2 + 1);

      setVector(
          encodedVector2, (char) 1, (char) 3, (char) 5, (char) 7, (char) 9, UInt2Vector.MAX_UINT2);
      setVector(dictionaryVector, dictItems);

      byte[] data = writeData(streamMode, encodedVector2);
      readData(
          streamMode,
          data,
          encodedVector2.getField(),
          (vector, index) -> (int) ((UInt2Vector) vector).getValueAsLong(index),
          16L,
          indices,
          dictItems);
    }
  }

  @ParameterizedTest(name = "stream mode = {0}")
  @MethodSource("getRepeat")
  public void testUInt4RoundTrip(boolean streamMode) throws IOException {
    final int dictLength = 10;
    try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
        UInt4Vector encodedVector4 = (UInt4Vector) createEncodedVector(32, dictionaryVector)) {
      int[] indices = new int[] {1, 3, 5, 7, 9};
      String[] dictItems = createDictionaryItems(dictLength);

      setVector(encodedVector4, 1, 3, 5, 7, 9);
      setVector(dictionaryVector, dictItems);

      byte[] data = writeData(streamMode, encodedVector4);
      readData(
          streamMode,
          data,
          encodedVector4.getField(),
          (vector, index) -> (int) ((UInt4Vector) vector).getValueAsLong(index),
          32L,
          indices,
          dictItems);
    }
  }

  @ParameterizedTest(name = "stream mode = {0}")
  @MethodSource("getRepeat")
  public void testUInt8RoundTrip(boolean streamMode) throws IOException {
    final int dictLength = 10;
    try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
        UInt8Vector encodedVector8 = (UInt8Vector) createEncodedVector(64, dictionaryVector)) {
      int[] indices = new int[] {1, 3, 5, 7, 9};
      String[] dictItems = createDictionaryItems(dictLength);

      setVector(encodedVector8, 1L, 3L, 5L, 7L, 9L);
      setVector(dictionaryVector, dictItems);

      byte[] data = writeData(streamMode, encodedVector8);
      readData(
          streamMode,
          data,
          encodedVector8.getField(),
          (vector, index) -> (int) ((UInt8Vector) vector).getValueAsLong(index),
          64L,
          indices,
          dictItems);
    }
  }

  /** Supplies both serialization modes: stream ({@code true}) and file ({@code false}). */
  static Stream<Arguments> getRepeat() {
    return Stream.of(Arguments.of(true), Arguments.of(false));
  }
}