IndexSorter.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.algorithm.sort;
import java.util.stream.IntStream;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.ValueVector;
/**
* Sorter for the indices of a vector.
*
* @param <V> vector type.
*/
public class IndexSorter<V extends ValueVector> {
  /**
   * Cut-over point between the two algorithms: a range whose bounds satisfy {@code high - low <
   * CHANGE_ALGORITHM_THRESHOLD} is handed to {@link InsertionSorter} instead of being partitioned
   * further.
   */
  public static final int CHANGE_ALGORITHM_THRESHOLD = 15;

  /** Comparator used to order two vector element indices. */
  private VectorValueComparator<V> comparator;

  /** The index vector currently being sorted. */
  private IntVector indices;

  /**
   * Sorts the indices of a vector with quick-sort, leaving the vector itself untouched. Denoting
   * the vector by v, after this method returns the following holds: {@code v(indices[0]) <=
   * v(indices[1]) <= ...}
   *
   * <p>Positions {@code 0 .. vector.getValueCount() - 1} of {@code indices} are first overwritten
   * with the identity mapping. The range that is subsequently sorted is determined by {@code
   * indices.getValueCount()}, so callers are presumably expected to have set the value count of
   * {@code indices} to that of {@code vector} beforehand (NOTE(review): confirm against callers).
   *
   * @param vector the vector whose indices need to be sorted.
   * @param indices the vector for storing the sorted indices.
   * @param comparator the comparator to sort indices.
   */
  public void sort(V vector, IntVector indices, VectorValueComparator<V> comparator) {
    comparator.attachVector(vector);
    this.indices = indices;

    // start from the identity permutation: position p refers to element p
    final int elementCount = vector.getValueCount();
    IntStream.range(0, elementCount).forEach(position -> indices.set(position, position));

    this.comparator = comparator;
    quickSort();
  }

  /**
   * Iterative quick-sort over {@link #indices}. Pending ranges are kept on an explicit {@link
   * OffHeapIntStack} as (low, high) pairs rather than on the call stack.
   */
  private void quickSort() {
    try (OffHeapIntStack pending = new OffHeapIntStack(indices.getAllocator())) {
      pushRange(pending, 0, indices.getValueCount() - 1);

      while (!pending.isEmpty()) {
        // bounds were pushed as (low, high), so they come back in reverse order
        final int last = pending.pop();
        final int first = pending.pop();

        // ranges with fewer than two positions need no work
        if (first >= last) {
          continue;
        }

        // short ranges are finished off by insertion sort
        if (last - first < CHANGE_ALGORITHM_THRESHOLD) {
          InsertionSorter.insertionSort(indices, first, last, comparator);
          continue;
        }

        final int split = partition(first, last, indices, comparator);

        // The larger side goes onto the stack first so that the smaller side is
        // popped and processed first, which reduces the required stack size.
        if (last - split < split - first) {
          pushRange(pending, first, split - 1);
          pushRange(pending, split + 1, last);
        } else {
          pushRange(pending, split + 1, last);
          pushRange(pending, first, split - 1);
        }
      }
    }
  }

  /** Pushes the inclusive range [low, high] onto the stack, lower bound first. */
  private static void pushRange(OffHeapIntStack stack, int low, int high) {
    stack.push(low);
    stack.push(high);
  }

  /**
   * Selects the pivot as the median of the entries at the low, middle and high positions, and
   * moves it to the low position. Ranges spanning fewer positions than {@code
   * FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD} skip the sampling and simply use
   * the entry at the low position.
   *
   * @return the pivot, i.e. the vector element index now stored at position {@code low}.
   */
  static <T extends ValueVector> int choosePivot(
      int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
    // not enough positions to sample from
    if (high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD) {
      return indices.get(low);
    }

    final int mid = low + (high - low) / 2;

    final int lowEntry = indices.get(low);
    final int midEntry = indices.get(mid);
    final int highEntry = indices.get(high);

    // locate the median with at most 3 comparisons
    final int medianPosition;
    if (comparator.compare(lowEntry, midEntry) < 0) {
      if (comparator.compare(midEntry, highEntry) < 0) {
        medianPosition = mid;
      } else if (comparator.compare(lowEntry, highEntry) < 0) {
        medianPosition = high;
      } else {
        medianPosition = low;
      }
    } else if (comparator.compare(midEntry, highEntry) > 0) {
      medianPosition = mid;
    } else if (comparator.compare(lowEntry, highEntry) < 0) {
      medianPosition = low;
    } else {
      medianPosition = high;
    }

    // the median already sits at the low position
    if (medianPosition == low) {
      return lowEntry;
    }

    // otherwise swap it into the low position
    final int pivot = indices.get(medianPosition);
    indices.set(medianPosition, lowEntry);
    indices.set(low, pivot);
    return pivot;
  }

  /**
   * Partitions the positions {@code low .. high} of {@code indices} around a pivot picked by
   * {@link #choosePivot}. Only the entries of {@code indices} are rearranged, so the underlying
   * vector is not modified. The scan from the right skips entries that compare greater than or
   * equal to the pivot, the scan from the left skips entries that compare less than or equal to
   * it, and the pivot is finally stored where the two scans meet.
   *
   * @param low the lower bound of the range.
   * @param high the upper bound of the range.
   * @param indices vector element indices.
   * @param comparator criteria for comparison.
   * @param <T> the vector type.
   * @return the index of the split point, i.e. the position at which the pivot was stored.
   */
  public static <T extends ValueVector> int partition(
      int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
    // the pivot is a vector element index; choosePivot has placed it at position low
    final int pivot = choosePivot(low, high, indices, comparator);

    int left = low;
    int right = high;
    while (left < right) {
      // walk down from the right until an entry smaller than the pivot is found
      while (left < right && comparator.compare(indices.get(right), pivot) >= 0) {
        right--;
      }
      indices.set(left, indices.get(right));

      // walk up from the left until an entry greater than the pivot is found
      while (left < right && comparator.compare(indices.get(left), pivot) <= 0) {
        left++;
      }
      indices.set(right, indices.get(left));
    }

    // both cursors have met: this is where the pivot belongs
    indices.set(left, pivot);
    return left;
  }
}