View Javadoc
1   /*
2    * Copyright (C) 2011 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
5    * in compliance with the License. You may obtain a copy of the License at
6    *
7    * http://www.apache.org/licenses/LICENSE-2.0
8    *
9    * Unless required by applicable law or agreed to in writing, software distributed under the License
10   * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
11   * or implied. See the License for the specific language governing permissions and limitations under
12   * the License.
13   */
14  
15  package com.google.common.math;
16  
17  import static com.google.common.base.Preconditions.checkArgument;
18  import static com.google.common.base.Preconditions.checkNotNull;
19  import static com.google.common.math.MathPreconditions.checkNonNegative;
20  import static com.google.common.math.MathPreconditions.checkPositive;
21  import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
22  import static java.math.RoundingMode.CEILING;
23  import static java.math.RoundingMode.FLOOR;
24  import static java.math.RoundingMode.HALF_EVEN;
25  
26  import com.google.common.annotations.Beta;
27  import com.google.common.annotations.GwtCompatible;
28  import com.google.common.annotations.GwtIncompatible;
29  import com.google.common.annotations.VisibleForTesting;
30  import java.math.BigDecimal;
31  import java.math.BigInteger;
32  import java.math.RoundingMode;
33  import java.util.ArrayList;
34  import java.util.List;
35  
36  /**
37   * A class for arithmetic on values of type {@code BigInteger}.
38   *
39   * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
40   * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
41   *
42   * <p>Similar functionality for {@code int} and for {@code long} can be found in {@link IntMath} and
43   * {@link LongMath} respectively.
44   *
45   * @author Louis Wasserman
46   * @since 11.0
47   */
48  @GwtCompatible(emulated = true)
49  public final class BigIntegerMath {
50    /**
51     * Returns the smallest power of two greater than or equal to {@code x}.  This is equivalent to
52     * {@code BigInteger.valueOf(2).pow(log2(x, CEILING))}.
53     *
54     * @throws IllegalArgumentException if {@code x <= 0}
55     * @since 20.0
56     */
57    @Beta
58    public static BigInteger ceilingPowerOfTwo(BigInteger x) {
59      return BigInteger.ZERO.setBit(log2(x, RoundingMode.CEILING));
60    }
61  
62    /**
63     * Returns the largest power of two less than or equal to {@code x}.  This is equivalent to
64     * {@code BigInteger.valueOf(2).pow(log2(x, FLOOR))}.
65     *
66     * @throws IllegalArgumentException if {@code x <= 0}
67     * @since 20.0
68     */
69    @Beta
70    public static BigInteger floorPowerOfTwo(BigInteger x) {
71      return BigInteger.ZERO.setBit(log2(x, RoundingMode.FLOOR));
72    }
73  
74    /**
75     * Returns {@code true} if {@code x} represents a power of two.
76     */
77    public static boolean isPowerOfTwo(BigInteger x) {
78      checkNotNull(x);
79      return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
80    }
81  
82    /**
83     * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
84     *
85     * @throws IllegalArgumentException if {@code x <= 0}
86     * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
87     *     is not a power of two
88     */
89    @SuppressWarnings("fallthrough")
90    // TODO(kevinb): remove after this warning is disabled globally
91    public static int log2(BigInteger x, RoundingMode mode) {
92      checkPositive("x", checkNotNull(x));
93      int logFloor = x.bitLength() - 1;
94      switch (mode) {
95        case UNNECESSARY:
96          checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
97        case DOWN:
98        case FLOOR:
99          return logFloor;
100 
101       case UP:
102       case CEILING:
103         return isPowerOfTwo(x) ? logFloor : logFloor + 1;
104 
105       case HALF_DOWN:
106       case HALF_UP:
107       case HALF_EVEN:
108         if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
109           BigInteger halfPower =
110               SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
111           if (x.compareTo(halfPower) <= 0) {
112             return logFloor;
113           } else {
114             return logFloor + 1;
115           }
116         }
117         // Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
118         //
119         // To determine which side of logFloor.5 the logarithm is,
120         // we compare x^2 to 2^(2 * logFloor + 1).
121         BigInteger x2 = x.pow(2);
122         int logX2Floor = x2.bitLength() - 1;
123         return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
124 
125       default:
126         throw new AssertionError();
127     }
128   }
129 
130   /*
131    * The maximum number of bits in a square root for which we'll precompute an explicit half power
132    * of two. This can be any value, but higher values incur more class load time and linearly
133    * increasing memory consumption.
134    */
135   @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
136 
137   @VisibleForTesting
138   static final BigInteger SQRT2_PRECOMPUTED_BITS =
139       new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
140 
141   /**
142    * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
143    *
144    * @throws IllegalArgumentException if {@code x <= 0}
145    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
146    *     is not a power of ten
147    */
148   @GwtIncompatible // TODO
149   @SuppressWarnings("fallthrough")
150   public static int log10(BigInteger x, RoundingMode mode) {
151     checkPositive("x", x);
152     if (fitsInLong(x)) {
153       return LongMath.log10(x.longValue(), mode);
154     }
155 
156     int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
157     BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
158     int approxCmp = approxPow.compareTo(x);
159 
160     /*
161      * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
162      * 10^floor(log10(x)).
163      */
164 
165     if (approxCmp > 0) {
166       /*
167        * The code is written so that even completely incorrect approximations will still yield the
168        * correct answer eventually, but in practice this branch should almost never be entered, and
169        * even then the loop should not run more than once.
170        */
171       do {
172         approxLog10--;
173         approxPow = approxPow.divide(BigInteger.TEN);
174         approxCmp = approxPow.compareTo(x);
175       } while (approxCmp > 0);
176     } else {
177       BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
178       int nextCmp = nextPow.compareTo(x);
179       while (nextCmp <= 0) {
180         approxLog10++;
181         approxPow = nextPow;
182         approxCmp = nextCmp;
183         nextPow = BigInteger.TEN.multiply(approxPow);
184         nextCmp = nextPow.compareTo(x);
185       }
186     }
187 
188     int floorLog = approxLog10;
189     BigInteger floorPow = approxPow;
190     int floorCmp = approxCmp;
191 
192     switch (mode) {
193       case UNNECESSARY:
194         checkRoundingUnnecessary(floorCmp == 0);
195         // fall through
196       case FLOOR:
197       case DOWN:
198         return floorLog;
199 
200       case CEILING:
201       case UP:
202         return floorPow.equals(x) ? floorLog : floorLog + 1;
203 
204       case HALF_DOWN:
205       case HALF_UP:
206       case HALF_EVEN:
207         // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
208         BigInteger x2 = x.pow(2);
209         BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
210         return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
211       default:
212         throw new AssertionError();
213     }
214   }
215 
216   private static final double LN_10 = Math.log(10);
217   private static final double LN_2 = Math.log(2);
218 
219   /**
220    * Returns the square root of {@code x}, rounded with the specified rounding mode.
221    *
222    * @throws IllegalArgumentException if {@code x < 0}
223    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
224    *     {@code sqrt(x)} is not an integer
225    */
226   @GwtIncompatible // TODO
227   @SuppressWarnings("fallthrough")
228   public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
229     checkNonNegative("x", x);
230     if (fitsInLong(x)) {
231       return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
232     }
233     BigInteger sqrtFloor = sqrtFloor(x);
234     switch (mode) {
235       case UNNECESSARY:
236         checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
237       case FLOOR:
238       case DOWN:
239         return sqrtFloor;
240       case CEILING:
241       case UP:
242         int sqrtFloorInt = sqrtFloor.intValue();
243         boolean sqrtFloorIsExact =
244             (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
245                 && sqrtFloor.pow(2).equals(x); // slow exact check
246         return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
247       case HALF_DOWN:
248       case HALF_UP:
249       case HALF_EVEN:
250         BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
251         /*
252          * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both x
253          * and halfSquare are integers, this is equivalent to testing whether or not x <=
254          * halfSquare.
255          */
256         return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
257       default:
258         throw new AssertionError();
259     }
260   }
261 
262   @GwtIncompatible // TODO
263   private static BigInteger sqrtFloor(BigInteger x) {
264     /*
265      * Adapted from Hacker's Delight, Figure 11-1.
266      *
267      * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
268      * then we can get a double approximation of the square root. Then, we iteratively improve this
269      * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
270      * This iteration has the following two properties:
271      *
272      * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
273      * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
274      * and the arithmetic mean is always higher than the geometric mean.
275      *
276      * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
277      * with each iteration, so this algorithm takes O(log(digits)) iterations.
278      *
279      * We start out with a double-precision approximation, which may be higher or lower than the
280      * true value. Therefore, we perform at least one Newton iteration to get a guess that's
281      * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
282      */
283     BigInteger sqrt0;
284     int log2 = log2(x, FLOOR);
285     if (log2 < Double.MAX_EXPONENT) {
286       sqrt0 = sqrtApproxWithDoubles(x);
287     } else {
288       int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
289       /*
290        * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
291        * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
292        */
293       sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
294     }
295     BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
296     if (sqrt0.equals(sqrt1)) {
297       return sqrt0;
298     }
299     do {
300       sqrt0 = sqrt1;
301       sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
302     } while (sqrt1.compareTo(sqrt0) < 0);
303     return sqrt0;
304   }
305 
306   @GwtIncompatible // TODO
307   private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
308     return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
309   }
310 
311   /**
312    * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
313    * {@code RoundingMode}.
314    *
315    * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
316    *     is not an integer multiple of {@code b}
317    */
318   @GwtIncompatible // TODO
319   public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
320     BigDecimal pDec = new BigDecimal(p);
321     BigDecimal qDec = new BigDecimal(q);
322     return pDec.divide(qDec, 0, mode).toBigIntegerExact();
323   }
324 
325   /**
326    * Returns {@code n!}, that is, the product of the first {@code n} positive integers, or {@code 1}
327    * if {@code n == 0}.
328    *
329    * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
330    *
331    * <p>This uses an efficient binary recursive algorithm to compute the factorial with balanced
332    * multiplies. It also removes all the 2s from the intermediate products (shifting them back in at
333    * the end).
334    *
335    * @throws IllegalArgumentException if {@code n < 0}
336    */
337   public static BigInteger factorial(int n) {
338     checkNonNegative("n", n);
339 
340     // If the factorial is small enough, just use LongMath to do it.
341     if (n < LongMath.factorials.length) {
342       return BigInteger.valueOf(LongMath.factorials[n]);
343     }
344 
345     // Pre-allocate space for our list of intermediate BigIntegers.
346     int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
347     ArrayList<BigInteger> bignums = new ArrayList<>(approxSize);
348 
349     // Start from the pre-computed maximum long factorial.
350     int startingNumber = LongMath.factorials.length;
351     long product = LongMath.factorials[startingNumber - 1];
352     // Strip off 2s from this value.
353     int shift = Long.numberOfTrailingZeros(product);
354     product >>= shift;
355 
356     // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
357     int productBits = LongMath.log2(product, FLOOR) + 1;
358     int bits = LongMath.log2(startingNumber, FLOOR) + 1;
359     // Check for the next power of two boundary, to save us a CLZ operation.
360     int nextPowerOfTwo = 1 << (bits - 1);
361 
362     // Iteratively multiply the longs as big as they can go.
363     for (long num = startingNumber; num <= n; num++) {
364       // Check to see if the floor(log2(num)) + 1 has changed.
365       if ((num & nextPowerOfTwo) != 0) {
366         nextPowerOfTwo <<= 1;
367         bits++;
368       }
369       // Get rid of the 2s in num.
370       int tz = Long.numberOfTrailingZeros(num);
371       long normalizedNum = num >> tz;
372       shift += tz;
373       // Adjust floor(log2(num)) + 1.
374       int normalizedBits = bits - tz;
375       // If it won't fit in a long, then we store off the intermediate product.
376       if (normalizedBits + productBits >= Long.SIZE) {
377         bignums.add(BigInteger.valueOf(product));
378         product = 1;
379         productBits = 0;
380       }
381       product *= normalizedNum;
382       productBits = LongMath.log2(product, FLOOR) + 1;
383     }
384     // Check for leftovers.
385     if (product > 1) {
386       bignums.add(BigInteger.valueOf(product));
387     }
388     // Efficiently multiply all the intermediate products together.
389     return listProduct(bignums).shiftLeft(shift);
390   }
391 
392   static BigInteger listProduct(List<BigInteger> nums) {
393     return listProduct(nums, 0, nums.size());
394   }
395 
396   static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
397     switch (end - start) {
398       case 0:
399         return BigInteger.ONE;
400       case 1:
401         return nums.get(start);
402       case 2:
403         return nums.get(start).multiply(nums.get(start + 1));
404       case 3:
405         return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
406       default:
407         // Otherwise, split the list in half and recursively do this.
408         int m = (end + start) >>> 1;
409         return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
410     }
411   }
412 
413   /**
414    * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
415    * {@code k}, that is, {@code n! / (k! (n - k)!)}.
416    *
417    * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
418    *
419    * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
420    */
421   public static BigInteger binomial(int n, int k) {
422     checkNonNegative("n", n);
423     checkNonNegative("k", k);
424     checkArgument(k <= n, "k (%s) > n (%s)", k, n);
425     if (k > (n >> 1)) {
426       k = n - k;
427     }
428     if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
429       return BigInteger.valueOf(LongMath.binomial(n, k));
430     }
431 
432     BigInteger accum = BigInteger.ONE;
433 
434     long numeratorAccum = n;
435     long denominatorAccum = 1;
436 
437     int bits = LongMath.log2(n, RoundingMode.CEILING);
438 
439     int numeratorBits = bits;
440 
441     for (int i = 1; i < k; i++) {
442       int p = n - i;
443       int q = i + 1;
444 
445       // log2(p) >= bits - 1, because p >= n/2
446 
447       if (numeratorBits + bits >= Long.SIZE - 1) {
448         // The numerator is as big as it can get without risking overflow.
449         // Multiply numeratorAccum / denominatorAccum into accum.
450         accum =
451             accum
452                 .multiply(BigInteger.valueOf(numeratorAccum))
453                 .divide(BigInteger.valueOf(denominatorAccum));
454         numeratorAccum = p;
455         denominatorAccum = q;
456         numeratorBits = bits;
457       } else {
458         // We can definitely multiply into the long accumulators without overflowing them.
459         numeratorAccum *= p;
460         denominatorAccum *= q;
461         numeratorBits += bits;
462       }
463     }
464     return accum
465         .multiply(BigInteger.valueOf(numeratorAccum))
466         .divide(BigInteger.valueOf(denominatorAccum));
467   }
468 
469   // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
470   @GwtIncompatible // TODO
471   static boolean fitsInLong(BigInteger x) {
472     return x.bitLength() <= Long.SIZE - 1;
473   }
474 
475   private BigIntegerMath() {}
476 }