PartialSumUtils.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.algorithm.misc;

import org.apache.arrow.vector.BaseIntVector;

/**
 * Partial sum related utilities.
 *
 * <p>A partial sum vector {@code b} built from a delta vector {@code a} and a base value satisfies
 * {@code b(0) = base} and {@code b(i + 1) = b(i) + a(i)}. The methods in this class convert between
 * the two representations and search a partial sum vector for the interval containing a value.
 *
 * <p>This class only has static methods and cannot be instantiated.
 */
public class PartialSumUtils {

  /**
   * Converts an input vector to a partial sum vector. This is an inverse operation of {@link
   * PartialSumUtils#toDeltaVector(BaseIntVector, BaseIntVector)}. Suppose we have input vector a
   * and output vector b. Then we have b(0) = sumBase; b(i + 1) = b(i) + a(i) (i = 0, 1, 2, ...).
   *
   * <p>The running sum is accumulated as a {@code long} and every element is written through
   * {@link BaseIntVector#setWithPossibleTruncate(int, long)}. On return the value count of the
   * output vector is the value count of the input vector plus one, so an empty input produces an
   * output that holds only {@code sumBase}.
   *
   * @param deltaVector the input vector.
   * @param partialSumVector the output vector.
   * @param sumBase the base of the partial sums.
   */
  public static void toPartialSumVector(
      BaseIntVector deltaVector, BaseIntVector partialSumVector, long sumBase) {
    // Read the count once; it does not change while the loop runs.
    final int deltaCount = deltaVector.getValueCount();

    long sum = sumBase;
    partialSumVector.setWithPossibleTruncate(0, sumBase);

    for (int i = 0; i < deltaCount; i++) {
      sum += deltaVector.getValueAsLong(i);
      // Output index is shifted by one because slot 0 holds the base.
      partialSumVector.setWithPossibleTruncate(i + 1, sum);
    }
    partialSumVector.setValueCount(deltaCount + 1);
  }

  /**
   * Converts an input vector to the delta vector. This is an inverse operation of {@link
   * PartialSumUtils#toPartialSumVector(BaseIntVector, BaseIntVector, long)}. Suppose we have input
   * vector a and output vector b. Then we have b(i) = a(i + 1) - a(i) (i = 0, 1, 2, ...).
   *
   * <p>On return the value count of the output vector is the value count of the input vector minus
   * one. An input with zero or one element has no adjacent pair, so the output value count is set
   * to zero in both cases.
   *
   * @param partialSumVector the input vector.
   * @param deltaVector the output vector.
   */
  public static void toDeltaVector(BaseIntVector partialSumVector, BaseIntVector deltaVector) {
    // Clamp at zero so an empty input never results in a negative value count on the output.
    final int deltaCount = Math.max(0, partialSumVector.getValueCount() - 1);

    for (int i = 0; i < deltaCount; i++) {
      long delta = partialSumVector.getValueAsLong(i + 1) - partialSumVector.getValueAsLong(i);
      deltaVector.setWithPossibleTruncate(i, delta);
    }
    deltaVector.setValueCount(deltaCount);
  }

  /**
   * Given a value and a partial sum vector, finds its position in the partial sum vector. In
   * particular, given an integer value a and partial sum vector v, we try to find a position i, so
   * that {@code v(i) <= a < v(i + 1)}. The algorithm is based on binary search, so it takes
   * O(log(n)) time, where n is the length of the partial sum vector.
   *
   * <p>The binary search relies on the vector being in non-decreasing order; the result for a
   * vector that is not ordered this way is unspecified.
   *
   * @param partialSumVector the input partial sum vector.
   * @param value the value to search.
   * @return the position in the partial sum vector, if any, or -1, if none is found. In particular
   *     -1 is returned when the vector is empty, when {@code value} is smaller than the first
   *     element, or when {@code value} is not smaller than the last element.
   * @throws IllegalStateException if the search range is exhausted without locating a position,
   *     which is not expected for a non-decreasing vector.
   */
  public static int findPositionInPartialSumVector(BaseIntVector partialSumVector, long value) {
    final int valueCount = partialSumVector.getValueCount();
    if (valueCount == 0) {
      // Nothing to search; also avoids reading indices 0 and -1 of an empty vector.
      return -1;
    }

    final int lastIndex = valueCount - 1;
    if (value < partialSumVector.getValueAsLong(0)
        || value >= partialSumVector.getValueAsLong(lastIndex)) {
      // The value lies outside [v(0), v(last)), so no interval can contain it.
      return -1;
    }

    int low = 0;
    int high = lastIndex;
    while (low <= high) {
      int mid = low + (high - low) / 2;
      long midValue = partialSumVector.getValueAsLong(mid);

      if (midValue <= value) {
        if (mid == lastIndex) {
          // the mid is the last element, we have found it
          return mid;
        }
        long nextMidValue = partialSumVector.getValueAsLong(mid + 1);
        if (value < nextMidValue) {
          // midValue <= value < nextMidValue
          // this is exactly what we want.
          return mid;
        } else {
          // value >= nextMidValue
          // continue to search from the next value on the right
          low = mid + 1;
        }
      } else {
        // midValue > value
        // mid cannot be 0 here: the range check above established v(0) <= value,
        // so reading mid - 1 stays within bounds.
        long prevMidValue = partialSumVector.getValueAsLong(mid - 1);
        if (prevMidValue <= value) {
          // prevMidValue <= value < midValue
          // this is exactly what we want
          return mid - 1;
        } else {
          // prevMidValue > value
          // continue to search from the previous value on the left
          high = mid - 1;
        }
      }
    }
    throw new IllegalStateException("Should never get here");
  }

  private PartialSumUtils() {}
}