TestVectorSchemaRoot.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.vector;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link VectorSchemaRoot}: row-count resets, adding and removing vectors, slicing,
 * (approximate) equality, and schema synchronization.
 *
 * <p>Each test allocates from a fresh {@link RootAllocator} created in {@link #init()} and closed
 * in {@link #terminate()}.
 */
public class TestVectorSchemaRoot {
  private BufferAllocator allocator;

  /** Creates a fresh allocator for each test. */
  @BeforeEach
  public void init() {
    allocator = new RootAllocator(Long.MAX_VALUE);
  }

  /** Closes the per-test allocator. */
  @AfterEach
  public void terminate() {
    allocator.close();
  }

  /** Verifies that {@code allocateNew()} and {@code clear()} reset the row and value counts. */
  @Test
  public void testResetRowCount() {
    final int size = 20;
    try (final BitVector vec1 = new BitVector("bit", allocator);
        final IntVector vec2 = new IntVector("int", allocator)) {
      VectorSchemaRoot vsr = VectorSchemaRoot.of(vec1, vec2);
      vsr.allocateNew();
      assertEquals(0, vsr.getRowCount());

      fillBitAndInt(vec1, vec2, size);
      vsr.setRowCount(size);
      checkCount(vec1, vec2, vsr, size);

      // Re-allocating is expected to drop all previously written rows.
      vsr.allocateNew();
      checkCount(vec1, vec2, vsr, 0);

      fillBitAndInt(vec1, vec2, size);
      vsr.setRowCount(size);
      checkCount(vec1, vec2, vsr, size);

      // Clearing is expected to drop all rows as well.
      vsr.clear();
      checkCount(vec1, vec2, vsr, 0);
    }
  }

  /**
   * Writes {@code count} values into both vectors: {@code i % 2} into {@code bits} and {@code i}
   * into {@code ints}, for each index {@code i} in {@code [0, count)}.
   */
  private static void fillBitAndInt(BitVector bits, IntVector ints, int count) {
    for (int i = 0; i < count; i++) {
      bits.setSafe(i, i % 2);
      ints.setSafe(i, i);
    }
  }

  /**
   * Asserts that both vectors report {@code count} values and the root reports {@code count} rows.
   */
  private void checkCount(BitVector vec1, IntVector vec2, VectorSchemaRoot vsr, int count) {
    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are
    // accurate.
    assertEquals(count, vec1.getValueCount());
    assertEquals(count, vec2.getValueCount());
    assertEquals(count, vsr.getRowCount());
  }

  /**
   * Sets the value count of both vectors to {@code numRows}, then writes {@code i} into
   * {@code ints} and {@code i + 0.1f} into {@code floats} for each index {@code i}.
   */
  private static void fillIntAndFloat(IntVector ints, Float4Vector floats, int numRows) {
    ints.setValueCount(numRows);
    floats.setValueCount(numRows);
    for (int i = 0; i < numRows; i++) {
      ints.setSafe(i, i);
      floats.setSafe(i, i + 0.1f);
    }
  }

  /**
   * Builds a root whose declared schema is a single list column {@code listCol} with a varchar
   * child {@code varCharCol}. It then writes non-varchar data into that column so the vector
   * structure no longer matches the declared schema (see the comment inside). Used by
   * {@link #testSchemaSync()}.
   *
   * <p>The caller owns the returned root and must close it.
   */
  private VectorSchemaRoot createBatch() {
    FieldType varCharType = new FieldType(true, new ArrowType.Utf8(), /*dictionary=*/ null);
    FieldType listType = new FieldType(true, new ArrowType.List(), /*dictionary=*/ null);

    // create the schema
    List<Field> childFields = new ArrayList<>();
    childFields.add(new Field("varCharCol", varCharType, null));
    List<Field> schemaFields = new ArrayList<>();
    schemaFields.add(new Field("listCol", listType, childFields));
    Schema schema = new Schema(schemaFields);

    VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, allocator);
    try {
      // get and allocate the vector
      ListVector vector = (ListVector) schemaRoot.getVector("listCol");
      vector.allocateNew();

      // write data to the vector
      UnionListWriter writer = vector.getWriter();
      writer.setPosition(0);

      // write data vector(0)
      writer.startList();

      // write data vector(0)(0)
      writer.list().startList();

      // According to the schema above, the list element should have varchar type.
      // When we write a big int, the original writer cannot handle this, so the writer will
      // be promoted, and the vector structure will be different from the schema.
      writer.list().bigInt().writeBigInt(0);
      writer.list().bigInt().writeBigInt(1);
      writer.list().endList();

      // write data vector(0)(1)
      writer.list().startList();
      writer.list().float8().writeFloat8(3.0D);
      writer.list().float8().writeFloat8(7.0D);
      writer.list().endList();

      // finish data vector(0)
      writer.endList();

      writer.setPosition(1);

      // write data vector(1)
      writer.startList();

      // write data vector(1)(0)
      writer.list().startList();
      writer.list().integer().writeInt(3);
      writer.list().integer().writeInt(2);
      writer.list().endList();

      // finish data vector(1)
      writer.endList();

      vector.setValueCount(2);
      return schemaRoot;
    } catch (RuntimeException e) {
      // Release the partially built root so its buffers are not left outstanding in the
      // per-test allocator on top of the original failure.
      schemaRoot.close();
      throw e;
    }
  }

  /** Verifies that {@code addVector} inserts the new vector at the requested position. */
  @Test
  public void testAddVector() {
    try (final IntVector intVector1 = new IntVector("intVector1", allocator);
        final IntVector intVector2 = new IntVector("intVector2", allocator);
        final IntVector intVector3 = new IntVector("intVector3", allocator);
        VectorSchemaRoot original = new VectorSchemaRoot(Arrays.asList(intVector1, intVector2))) {
      assertEquals(2, original.getFieldVectors().size());

      try (VectorSchemaRoot newRecordBatch = original.addVector(1, intVector3)) {
        assertEquals(3, newRecordBatch.getFieldVectors().size());
        assertEquals(intVector3, newRecordBatch.getFieldVectors().get(1));
      }
    }
  }

  /** Verifies that {@code removeVector} drops the vector at the given index. */
  @Test
  public void testRemoveVector() {
    try (final IntVector intVector1 = new IntVector("intVector1", allocator);
        final IntVector intVector2 = new IntVector("intVector2", allocator);
        final IntVector intVector3 = new IntVector("intVector3", allocator);
        VectorSchemaRoot original =
            new VectorSchemaRoot(Arrays.asList(intVector1, intVector2, intVector3))) {
      assertEquals(3, original.getFieldVectors().size());

      try (VectorSchemaRoot newRecordBatch = original.removeVector(0)) {
        assertEquals(2, newRecordBatch.getFieldVectors().size());
        assertEquals(intVector2, newRecordBatch.getFieldVectors().get(0));
        assertEquals(intVector3, newRecordBatch.getFieldVectors().get(1));
      }
    }
  }

  /** Slices a 10-row root at every valid (offset, length) pair and checks the sliced values. */
  @Test
  public void testSlice() {
    final int numRows = 10;
    try (final IntVector intVector = new IntVector("intVector", allocator);
        final Float4Vector float4Vector = new Float4Vector("float4Vector", allocator)) {
      fillIntAndFloat(intVector, float4Vector, numRows);

      try (VectorSchemaRoot original =
          new VectorSchemaRoot(Arrays.asList(intVector, float4Vector))) {
        for (int sliceIndex = 0; sliceIndex < numRows; sliceIndex++) {
          for (int sliceLength = 0; sliceIndex + sliceLength <= numRows; sliceLength++) {
            try (VectorSchemaRoot slice = original.slice(sliceIndex, sliceLength)) {
              assertEquals(sliceLength, slice.getRowCount());

              // validate data
              final IntVector childIntVector = (IntVector) slice.getFieldVectors().get(0);
              final Float4Vector childFloatVector = (Float4Vector) slice.getFieldVectors().get(1);
              for (int i = 0; i < sliceLength; i++) {
                final int originalIndex = i + sliceIndex;
                assertEquals(originalIndex, childIntVector.get(i));
                assertEquals(originalIndex + 0.1f, childFloatVector.get(i), 0);
              }
            }
          }
        }
      }
    }
  }

  /** Verifies that slicing past the end of a 10-row root is rejected. */
  @Test
  public void testSliceWithInvalidParam() {
    try (final IntVector intVector = new IntVector("intVector", allocator);
        final Float4Vector float4Vector = new Float4Vector("float4Vector", allocator)) {
      fillIntAndFloat(intVector, float4Vector, 10);

      try (VectorSchemaRoot original =
          new VectorSchemaRoot(Arrays.asList(intVector, float4Vector))) {
        // Only the slice call is expected to throw; an exception raised during setup must not
        // satisfy the assertion.
        assertThrows(IllegalArgumentException.class, () -> original.slice(0, 20));
      }
    }
  }

  /** Verifies {@code equals} for roots with different and identical vector lists. */
  @Test
  public void testEquals() {
    try (final IntVector intVector1 = new IntVector("intVector1", allocator);
        final IntVector intVector2 = new IntVector("intVector2", allocator);
        final IntVector intVector3 = new IntVector("intVector3", allocator)) {
      intVector1.setValueCount(5);
      for (int i = 0; i < 5; i++) {
        intVector1.set(i, i);
      }

      try (VectorSchemaRoot root1 =
              new VectorSchemaRoot(Arrays.asList(intVector1, intVector2, intVector3));
          VectorSchemaRoot root2 = new VectorSchemaRoot(Arrays.asList(intVector1, intVector2));
          VectorSchemaRoot root3 =
              new VectorSchemaRoot(Arrays.asList(intVector1, intVector2, intVector3))) {
        assertFalse(root1.equals(root2));
        assertTrue(root1.equals(root3));
      }
    }
  }

  /** Verifies that {@code approxEquals} tolerates small float differences but not larger ones. */
  @Test
  public void testApproxEquals() {
    try (final Float4Vector float4Vector1 = new Float4Vector("floatVector", allocator);
        final Float4Vector float4Vector2 = new Float4Vector("floatVector", allocator);
        final Float4Vector float4Vector3 = new Float4Vector("floatVector", allocator)) {
      float4Vector1.setValueCount(5);
      float4Vector2.setValueCount(5);
      float4Vector3.setValueCount(5);

      // Vector 2 differs from vector 1 by 2 * epsilon per element and vector 3 by epsilon / 2.
      // NOTE(review): this presumes approxEquals' default tolerance lies between the two
      // offsets; confirm against VectorSchemaRoot.approxEquals.
      final float epsilon = 1.0E-6f;
      for (int i = 0; i < 5; i++) {
        float4Vector1.set(i, i);
        float4Vector2.set(i, i + epsilon * 2);
        float4Vector3.set(i, i + epsilon / 2);
      }

      try (VectorSchemaRoot root1 = new VectorSchemaRoot(Arrays.asList(float4Vector1));
          VectorSchemaRoot root2 = new VectorSchemaRoot(Arrays.asList(float4Vector2));
          VectorSchemaRoot root3 = new VectorSchemaRoot(Arrays.asList(float4Vector3))) {
        assertFalse(root1.approxEquals(root2));
        assertTrue(root1.approxEquals(root3));
      }
    }
  }

  /**
   * Verifies that {@code syncSchema()} replaces a stale schema with one built from the current
   * field vectors, and reports no change on a second call.
   */
  @Test
  public void testSchemaSync() {
    // create vector schema root
    try (VectorSchemaRoot schemaRoot = createBatch()) {
      Schema newSchema =
          new Schema(
              schemaRoot.getFieldVectors().stream()
                  .map(vec -> vec.getField())
                  .collect(Collectors.toList()));

      assertNotEquals(newSchema, schemaRoot.getSchema());

      assertTrue(schemaRoot.syncSchema());
      assertEquals(newSchema, schemaRoot.getSchema());

      // no schema update this time.
      assertFalse(schemaRoot.syncSchema());
    }
  }
}