View Javadoc
1   /*
2    * Copyright (C) 2014 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.google.common.collect;
18  
19  import static com.google.common.base.Preconditions.checkArgument;
20  import static com.google.common.base.Preconditions.checkNotNull;
21  
22  import com.google.common.annotations.GwtCompatible;
23  import com.google.common.math.IntMath;
24  import java.math.RoundingMode;
25  import java.util.Arrays;
26  import java.util.Collections;
27  import java.util.Comparator;
28  import java.util.Iterator;
29  import java.util.List;
30  import java.util.stream.Stream;
31  import javax.annotation.Nullable;
32  
33  /**
34   * An accumulator that selects the "top" {@code k} elements added to it, relative to a provided
35   * comparator. "Top" can mean the greatest or the lowest elements, specified in the factory used to
36   * create the {@code TopKSelector} instance.
37   *
38   * <p>If your input data is available as a {@link Stream}, prefer passing {@link
39   * Comparators#least(int)} to {@link Stream#collect(java.util.stream.Collector)}. If it is available
40   * as an {@link Iterable} or {@link Iterator}, prefer {@link Ordering#leastOf(Iterable, int)}.
41   *
42   * <p>This uses the same efficient implementation as {@link Ordering#leastOf(Iterable, int)},
43   * offering expected O(n + k log k) performance (worst case O(n log k)) for n calls to {@link
44   * #offer} and a call to {@link #topK}, with O(k) memory. In comparison, quickselect has the same
45   * asymptotics but requires O(n) memory, and a {@code PriorityQueue} implementation takes O(n log
46   * k). In benchmarks, this implementation performs at least as well as either implementation, and
47   * degrades more gracefully for worst-case input.
48   *
49   * <p>The implementation does not necessarily use a <i>stable</i> sorting algorithm; when multiple
50   * equivalent elements are added to it, it is undefined which will come first in the output.
51   *
52   * @author Louis Wasserman
53   */
54  @GwtCompatible final class TopKSelector<T> {
55  
56    /**
57     * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
58     * relative to the natural ordering of the elements, and returns them via {@link #topK} in
59     * ascending order.
60     *
61     * @throws IllegalArgumentException if {@code k < 0}
62     */
63    public static <T extends Comparable<? super T>> TopKSelector<T> least(int k) {
64      return least(k, Ordering.natural());
65    }
66  
67    /**
68     * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
69     * relative to the natural ordering of the elements, and returns them via {@link #topK} in
70     * descending order.
71     *
72     * @throws IllegalArgumentException if {@code k < 0}
73     */
74    public static <T extends Comparable<? super T>> TopKSelector<T> greatest(int k) {
75      return greatest(k, Ordering.natural());
76    }
77  
78    /**
79     * Returns a {@code TopKSelector} that collects the lowest {@code k} elements added to it,
80     * relative to the specified comparator, and returns them via {@link #topK} in ascending order.
81     *
82     * @throws IllegalArgumentException if {@code k < 0}
83     */
84    public static <T> TopKSelector<T> least(int k, Comparator<? super T> comparator) {
85      return new TopKSelector<T>(comparator, k);
86    }
87  
88    /**
89     * Returns a {@code TopKSelector} that collects the greatest {@code k} elements added to it,
90     * relative to the specified comparator, and returns them via {@link #topK} in descending order.
91     *
92     * @throws IllegalArgumentException if {@code k < 0}
93     */
94    public static <T> TopKSelector<T> greatest(int k, Comparator<? super T> comparator) {
95      return new TopKSelector<T>(Ordering.from(comparator).reverse(), k);
96    }
97  
98    private final int k;
99    private final Comparator<? super T> comparator;
100 
101   /*
102    * We are currently considering the elements in buffer in the range [0, bufferSize) as candidates
103    * for the top k elements. Whenever the buffer is filled, we quickselect the top k elements to the
104    * range [0, k) and ignore the remaining elements.
105    */
106   private final T[] buffer;
107   private int bufferSize;
108 
109   /**
110    * The largest of the lowest k elements we've seen so far relative to this comparator. If
111    * bufferSize ≥ k, then we can ignore any elements greater than this value.
112    */
113   private T threshold;
114 
115   private TopKSelector(Comparator<? super T> comparator, int k) {
116     this.comparator = checkNotNull(comparator, "comparator");
117     this.k = k;
118     checkArgument(k >= 0, "k must be nonnegative, was %s", k);
119     this.buffer = (T[]) new Object[k * 2];
120     this.bufferSize = 0;
121     this.threshold = null;
122   }
123 
124   /**
125    * Adds {@code elem} as a candidate for the top {@code k} elements. This operation takes
126    * amortized O(1) time.
127    */
128   public void offer(@Nullable T elem) {
129     if (k == 0) {
130       return;
131     } else if (bufferSize == 0) {
132       buffer[0] = elem;
133       threshold = elem;
134       bufferSize = 1;
135     } else if (bufferSize < k) {
136       buffer[bufferSize++] = elem;
137       if (comparator.compare(elem, threshold) > 0) {
138         threshold = elem;
139       }
140     } else if (comparator.compare(elem, threshold) < 0) {
141       // Otherwise, we can ignore elem; we've seen k better elements.
142       buffer[bufferSize++] = elem;
143       if (bufferSize == 2 * k) {
144         trim();
145       }
146     }
147   }
148 
149   /**
150    * Quickselects the top k elements from the 2k elements in the buffer.  O(k) expected time,
151    * O(k log k) worst case.
152    */
153   private void trim() {
154     int left = 0;
155     int right = 2 * k - 1;
156 
157     int minThresholdPosition = 0;
158     // The leftmost position at which the greatest of the k lower elements
159     // -- the new value of threshold -- might be found.
160 
161     int iterations = 0;
162     int maxIterations = IntMath.log2(right - left, RoundingMode.CEILING) * 3;
163     while (left < right) {
164       int pivotIndex = (left + right + 1) >>> 1;
165 
166       int pivotNewIndex = partition(left, right, pivotIndex);
167 
168       if (pivotNewIndex > k) {
169         right = pivotNewIndex - 1;
170       } else if (pivotNewIndex < k) {
171         left = Math.max(pivotNewIndex, left + 1);
172         minThresholdPosition = pivotNewIndex;
173       } else {
174         break;
175       }
176       iterations++;
177       if (iterations >= maxIterations) {
178         // We've already taken O(k log k), let's make sure we don't take longer than O(k log k).
179         Arrays.sort(buffer, left, right, comparator);
180         break;
181       }
182     }
183     bufferSize = k;
184 
185     threshold = buffer[minThresholdPosition];
186     for (int i = minThresholdPosition + 1; i < k; i++) {
187       if (comparator.compare(buffer[i], threshold) > 0) {
188         threshold = buffer[i];
189       }
190     }
191   }
192 
193   /**
194    * Partitions the contents of buffer in the range [left, right] around the pivot element
195    * previously stored in buffer[pivotValue]. Returns the new index of the pivot element,
196    * pivotNewIndex, so that everything in [left, pivotNewIndex] is ≤ pivotValue and everything in
197    * (pivotNewIndex, right] is greater than pivotValue.
198    */
199   private int partition(int left, int right, int pivotIndex) {
200     T pivotValue = buffer[pivotIndex];
201     buffer[pivotIndex] = buffer[right];
202 
203     int pivotNewIndex = left;
204     for (int i = left; i < right; i++) {
205       if (comparator.compare(buffer[i], pivotValue) < 0) {
206         swap(pivotNewIndex, i);
207         pivotNewIndex++;
208       }
209     }
210     buffer[right] = buffer[pivotNewIndex];
211     buffer[pivotNewIndex] = pivotValue;
212     return pivotNewIndex;
213   }
214 
215   private void swap(int i, int j) {
216     T tmp = buffer[i];
217     buffer[i] = buffer[j];
218     buffer[j] = tmp;
219   }
220 
221   TopKSelector<T> combine(TopKSelector<T> other) {
222     for (int i = 0; i < other.bufferSize; i++) {
223       this.offer(other.buffer[i]);
224     }
225     return this;
226   }
227 
228   /**
229    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
230    * operation takes amortized linear time in the length of {@code elements}.
231    *
232    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterable},
233    * prefer {@link Ordering#leastOf(Iterable, int)}, which provides a simpler API for that use
234    * case.
235    */
236   public void offerAll(Iterable<? extends T> elements) {
237     offerAll(elements.iterator());
238   }
239 
240   /**
241    * Adds each member of {@code elements} as a candidate for the top {@code k} elements. This
242    * operation takes amortized linear time in the length of {@code elements}. The iterator is
243    * consumed after this operation completes.
244    *
245    * <p>If all input data to this {@code TopKSelector} is in a single {@code Iterator},
246    * prefer {@link Ordering#leastOf(Iterator, int)}, which provides a simpler API for that use
247    * case.
248    */
249   public void offerAll(Iterator<? extends T> elements) {
250     while (elements.hasNext()) {
251       offer(elements.next());
252     }
253   }
254 
255   /**
256    * Returns the top {@code k} elements offered to this {@code TopKSelector}, or all elements if
257    * fewer than {@code k} have been offered, in the order specified by the factory used to create
258    * this {@code TopKSelector}.
259    *
260    * <p>The returned list is an unmodifiable copy and will not be affected by further changes to
261    * this {@code TopKSelector}. This method returns in O(k log k) time.
262    */
263   public List<T> topK() {
264     Arrays.sort(buffer, 0, bufferSize, comparator);
265     if (bufferSize > k) {
266       Arrays.fill(buffer, k, buffer.length, null);
267       bufferSize = k;
268       threshold = buffer[k - 1];
269     }
270     // we have to support null elements, so no ImmutableList for us
271     return Collections.unmodifiableList(Arrays.asList(Arrays.copyOf(buffer, bufferSize)));
272   }
273 }