ApproxEqualsVisitor.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.vector.compare;
import java.util.function.BiFunction;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.compare.util.ValueEpsilonEqualizers;
/** Visitor to compare floating point vectors approximately. */
public class ApproxEqualsVisitor extends RangeEqualsVisitor {

  /** Tolerance used for {@link Float4Vector} values when no explicit epsilon is supplied. */
  public static final float DEFAULT_FLOAT_EPSILON = 1.0E-6f;

  /** Tolerance used for {@link Float8Vector} values when no explicit epsilon is supplied. */
  public static final double DEFAULT_DOUBLE_EPSILON = 1.0E-6;

  /** Equalizer consulted for each pair of {@link Float4Vector} elements being compared. */
  private final VectorValueEqualizer<Float4Vector> floatDiffFunction;

  /** Equalizer consulted for each pair of {@link Float8Vector} elements being compared. */
  private final VectorValueEqualizer<Float8Vector> doubleDiffFunction;

  /**
   * Creates a visitor that uses {@link #DEFAULT_FLOAT_EPSILON} and {@link #DEFAULT_DOUBLE_EPSILON}
   * as its tolerances.
   *
   * @param left the left vector
   * @param right the right vector
   */
  public ApproxEqualsVisitor(ValueVector left, ValueVector right) {
    this(left, right, DEFAULT_FLOAT_EPSILON, DEFAULT_DOUBLE_EPSILON);
  }

  /**
   * Creates a visitor whose float and double comparisons use epsilon equalizers built from the
   * given tolerances, together with the default type comparator.
   *
   * @param left the left vector
   * @param right the right vector
   * @param floatEpsilon tolerance handed to the {@link Float4Vector} equalizer
   * @param doubleEpsilon tolerance handed to the {@link Float8Vector} equalizer
   */
  public ApproxEqualsVisitor(
      ValueVector left, ValueVector right, float floatEpsilon, double doubleEpsilon) {
    this(
        left,
        right,
        new ValueEpsilonEqualizers.Float4EpsilonEqualizer(floatEpsilon),
        new ValueEpsilonEqualizers.Float8EpsilonEqualizer(doubleEpsilon),
        DEFAULT_TYPE_COMPARATOR);
  }

  /**
   * Creates a visitor with caller-supplied equalizers and the default type comparator.
   *
   * @param left the left vector
   * @param right the right vector
   * @param floatDiffFunction equalizer applied to {@link Float4Vector} elements
   * @param doubleDiffFunction equalizer applied to {@link Float8Vector} elements
   */
  public ApproxEqualsVisitor(
      ValueVector left,
      ValueVector right,
      VectorValueEqualizer<Float4Vector> floatDiffFunction,
      VectorValueEqualizer<Float8Vector> doubleDiffFunction) {
    this(left, right, floatDiffFunction, doubleDiffFunction, DEFAULT_TYPE_COMPARATOR);
  }

  /**
   * Creates a visitor with caller-supplied equalizers and type comparator.
   *
   * @param left the left vector
   * @param right the right vector
   * @param floatDiffFunction equalizer applied to {@link Float4Vector} elements
   * @param doubleDiffFunction equalizer applied to {@link Float8Vector} elements
   * @param typeComparator function deciding whether the two vectors' types are compatible
   */
  public ApproxEqualsVisitor(
      ValueVector left,
      ValueVector right,
      VectorValueEqualizer<Float4Vector> floatDiffFunction,
      VectorValueEqualizer<Float8Vector> doubleDiffFunction,
      BiFunction<ValueVector, ValueVector, Boolean> typeComparator) {
    super(left, right, typeComparator);
    this.floatDiffFunction = floatDiffFunction;
    this.doubleDiffFunction = doubleDiffFunction;
  }

  /**
   * Compares {@code range} element by element through the matching equalizer when {@code left}
   * is a {@link Float4Vector} or {@link Float8Vector}. Every other fixed-width vector is handed to
   * {@link RangeEqualsVisitor}.
   *
   * @param left the left vector being visited
   * @param range the pair of index windows to compare
   * @return {@code false} if validation fails or any element pair mismatches, otherwise
   *     {@code true} (for float vectors), or the parent's result for other vectors
   */
  @Override
  public Boolean visit(BaseFixedWidthVector left, Range range) {
    if (left instanceof Float4Vector) {
      // Casts happen only after validate() succeeds, matching the original ordering.
      return validate(left)
          && rangeApproxEquals(
              floatDiffFunction, (Float4Vector) getLeft(), (Float4Vector) getRight(), range);
    }
    if (left instanceof Float8Vector) {
      return validate(left)
          && rangeApproxEquals(
              doubleDiffFunction, (Float8Vector) getLeft(), (Float8Vector) getRight(), range);
    }
    return super.visit(left, range);
  }

  /**
   * Builds the visitor used for nested vectors, handing it clones of this visitor's equalizers
   * rather than the same instances.
   */
  @Override
  protected ApproxEqualsVisitor createInnerVisitor(
      ValueVector left,
      ValueVector right,
      BiFunction<ValueVector, ValueVector, Boolean> typeComparator) {
    VectorValueEqualizer<Float4Vector> floatCopy = floatDiffFunction.clone();
    VectorValueEqualizer<Float8Vector> doubleCopy = doubleDiffFunction.clone();
    return new ApproxEqualsVisitor(left, right, floatCopy, doubleCopy, typeComparator);
  }

  /**
   * Walks {@code range} pairwise over the two vectors and asks {@code equalizer} about each pair.
   *
   * @param equalizer the per-element comparison to apply
   * @param leftVector vector indexed from {@code range.getLeftStart()}
   * @param rightVector vector indexed from {@code range.getRightStart()}
   * @param range the windows to compare
   * @return {@code true} if every pair is reported equal (trivially so for a non-positive length),
   *     {@code false} at the first pair that is not
   */
  private static <V extends ValueVector> boolean rangeApproxEquals(
      VectorValueEqualizer<V> equalizer, V leftVector, V rightVector, Range range) {
    int leftIndex = range.getLeftStart();
    int rightIndex = range.getRightStart();
    for (int remaining = range.getLength(); remaining > 0; remaining--) {
      if (!equalizer.valuesEqual(leftVector, leftIndex++, rightVector, rightIndex++)) {
        return false;
      }
    }
    return true;
  }
}