fp/prime/
binomial.rs

1use super::{Prime, ValidPrime};
2use crate::{
3    PRIME_TO_INDEX_MAP,
4    constants::{BINOMIAL_TABLE, BINOMIAL4_TABLE, BINOMIAL4_TABLE_SIZE},
5};
6
7/// This uses a lookup table for n choose k when n and k are both less than p.
8/// Lucas's theorem reduces general binomial coefficients to this case.
9///
10/// Calling this function safely requires that `k, n < p`.  These invariants are often known
11/// apriori because k and n are obtained by reducing mod p, so it is better to expose an unsafe
12/// interface that avoids these checks.
13unsafe fn direct_binomial(p: ValidPrime, n: usize, k: usize) -> u32 {
14    unsafe {
15        *BINOMIAL_TABLE
16            .get_unchecked(PRIME_TO_INDEX_MAP[p.as_usize()])
17            .get_unchecked(n)
18            .get_unchecked(k)
19    }
20}
21
22/// A number satisfying the Binomial trait supports computing various binomial coefficients. This
23/// is implemented using a macro, since the implementation for all types is syntactically the same.
24pub trait Binomial: Sized {
25    /// mod 2 multinomial coefficient
26    fn multinomial2(k: &[Self]) -> Self;
27
28    /// mod 2 binomial coefficient n choose k
29    fn binomial2(n: Self, k: Self) -> Self;
30
31    /// Binomial coefficients mod 4. We pre-compute the coefficients for small values of n. For large
32    /// n, we recursively use the fact that if n = 2^k + l, l < 2^k, then
33    ///
34    ///    n choose r = l choose r + 2 (l choose (r - 2^{k - 1})) + (l choose (r - 2^k))
35    ///
36    /// This is easy to verify using the fact that
37    ///
38    ///    (x + y)^{2^k} = x^{2^k} + 2 x^{2^{k - 1}} y^{2^{k - 1}} + y^{2^k}
39    fn binomial4(n: Self, k: Self) -> Self;
40
41    /// Compute binomial coefficients mod 4 using the recursion relation in the documentation of
42    /// [Binomial::binomial4]. This calls into binomial4 instead of binomial4_rec. The main purpose
43    /// of this is to separate out the logic for testing.
44    fn binomial4_rec(n: Self, k: Self) -> Self;
45
46    /// Computes the multinomial coefficient mod p using Lucas' theorem. This modifies the
47    /// underlying list. For p = 2 it is more efficient to use multinomial2
48    fn multinomial_odd(p: ValidPrime, l: &mut [Self]) -> Self;
49
50    /// Compute odd binomial coefficients mod p, where p is odd. For p = 2 it is more efficient to
51    /// use binomial2
52    fn binomial_odd(p: ValidPrime, n: Self, k: Self) -> Self;
53
54    /// Checks whether n choose k is zero mod p. Since we don't have to compute the value, this is
55    /// faster than binomial_odd.
56    fn binomial_odd_is_zero(p: ValidPrime, n: Self, k: Self) -> bool;
57
58    /// Multinomial coefficient of the list l
59    fn multinomial(p: ValidPrime, l: &mut [Self]) -> Self {
60        if p == 2 {
61            Self::multinomial2(l)
62        } else {
63            Self::multinomial_odd(p, l)
64        }
65    }
66
67    /// Binomial coefficient n choose k.
68    fn binomial(p: ValidPrime, n: Self, k: Self) -> Self {
69        if p == 2 {
70            Self::binomial2(n, k)
71        } else {
72            Self::binomial_odd(p, n, k)
73        }
74    }
75}
76
// Implements `Binomial` for the primitive integer type `$T`. The same body is used for signed
// and unsigned types, so checks such as `k < 0` are trivially false for unsigned types; the
// resulting `unused_comparisons` lint is silenced locally where that happens.
macro_rules! impl_binomial {
    ($T:ty) => {
        impl Binomial for $T {
            // The multinomial coefficient is odd exactly when adding the entries produces no
            // binary carries, i.e. when their bitwise OR equals their sum.
            #[inline]
            fn multinomial2(l: &[Self]) -> Self {
                let mut bit_or: Self = 0;
                let mut sum: Self = 0;
                for &e in l {
                    sum += e;
                    bit_or |= e;
                }
                if bit_or == sum { 1 } else { 0 }
            }

            // `n choose k` is odd exactly when `k` and `n - k` share no binary digit.
            #[inline]
            fn binomial2(n: Self, k: Self) -> Self {
                if n < k {
                    0
                } else if (n - k) & k == 0 {
                    1
                } else {
                    0
                }
            }

            // Lucas' theorem, one base-`p` digit at a time. The entries of `l` are consumed
            // (divided by `p` on every round). A negative entry yields 0.
            #[inline]
            fn multinomial_odd(p_: ValidPrime, l: &mut [Self]) -> Self {
                let p = p_.as_u32() as Self;

                // A negative entry would produce a negative digit below, which would be cast to
                // a huge `usize` and handed to the unchecked table lookup. Only signed types can
                // hit this.
                #[allow(unused_comparisons)]
                let has_negative_entry = l.iter().any(|&e| e < 0);
                if has_negative_entry {
                    return 0;
                }

                let mut n: Self = l.iter().sum();
                if n == 0 {
                    // Also covers the empty list, so `l[0]` below is always in bounds.
                    return 1;
                }
                let mut answer = 1;

                while n > 0 {
                    let mut multi: Self = 1;

                    let total_entry = n % p;
                    n /= p;

                    let mut partial_sum: Self = l[0] % p;
                    l[0] /= p;

                    for ll in l.iter_mut().skip(1) {
                        let entry = *ll % p;
                        *ll /= p;

                        partial_sum += entry;
                        if partial_sum > total_entry {
                            // A carry occurred, so the coefficient is divisible by p. This early
                            // return is also necessary because direct_binomial only works when
                            // partial_sum < p.
                            return 0;
                        }
                        // SAFETY: 0 <= entry <= partial_sum <= total_entry < p, which is what
                        // direct_binomial requires.
                        multi *=
                            unsafe { direct_binomial(p_, partial_sum as usize, entry as usize) }
                                as Self;
                        multi %= p;
                    }
                    answer *= multi;
                    answer %= p;
                }
                answer
            }

            // Lucas' theorem: multiply the binomial coefficients of the base-`p` digits.
            #[inline]
            fn binomial_odd(p_: ValidPrime, mut n: Self, mut k: Self) -> Self {
                let p = p_.as_u32() as Self;

                // We have both signed and unsigned types
                #[allow(unused_comparisons)]
                let outside_triangle = n < k || k < 0;
                if outside_triangle {
                    return 0;
                }

                let mut answer = 1;

                while n > 0 {
                    // SAFETY: 0 <= k <= n here, so both digits are non-negative, and anything
                    // mod p is < p.
                    answer *=
                        unsafe { direct_binomial(p_, (n % p) as usize, (k % p) as usize) } as Self;
                    answer %= p;
                    n /= p;
                    k /= p;
                }
                answer
            }

            // By Lucas' theorem the coefficient vanishes mod p exactly when some base-`p` digit
            // of `k` exceeds the corresponding digit of `n`.
            #[inline]
            fn binomial_odd_is_zero(p: ValidPrime, mut n: Self, mut k: Self) -> bool {
                let p = p.as_u32() as Self;

                // Agree with `binomial_odd`: the coefficient is 0 outside of 0 <= k <= n. The
                // digit loop alone misses this, because it stops once `n` runs out of digits.
                #[allow(unused_comparisons)]
                let outside_triangle = n < k || k < 0;
                if outside_triangle {
                    return true;
                }

                while n > 0 {
                    if n % p < k % p {
                        return true;
                    }
                    n /= p;
                    k /= p;
                }
                false
            }

            // Table lookup for small `n`; otherwise the power of 2 dividing the coefficient is
            // read off from the number of binary carries when adding `j` and `n - j`.
            fn binomial4(n: Self, j: Self) -> Self {
                // Outside of 0 <= j <= n the coefficient is 0. Checking first keeps the table
                // index in range and keeps `n - j` from underflowing on unsigned types.
                #[allow(unused_comparisons)]
                let outside_triangle = j < 0 || n < j;
                if outside_triangle {
                    return 0;
                }
                if (n as usize) < BINOMIAL4_TABLE_SIZE {
                    return BINOMIAL4_TABLE[n as usize][j as usize] as Self;
                }
                if (n - j) & j == 0 {
                    // Answer is odd
                    Self::binomial4_rec(n, j)
                } else if (n - j).count_ones() + j.count_ones() - n.count_ones() == 1 {
                    // Exactly one carry: divisible by 2 but not by 4.
                    2
                } else {
                    0
                }
            }

            // Applies the recursion n = 2^k + l from the trait documentation. Requires n >= 2,
            // because the middle term uses 2^{k - 1}.
            #[inline]
            fn binomial4_rec(n: Self, j: Self) -> Self {
                debug_assert!(n >= 2, "binomial4_rec requires n >= 2");
                // k = floor(log2(n)), the position of the highest set bit.
                let k = (std::mem::size_of::<Self>() * 8) as u32 - n.leading_zeros() - 1;
                let l = n - (1 << k);
                let mut ans = 0;
                if j <= l {
                    ans += Self::binomial4(l, j)
                }
                let pow = 1 << (k - 1);
                if pow <= j && j <= l + pow {
                    ans += 2 * Self::binomial2(l, j - pow);
                }
                if j >= (1 << k) {
                    ans += Self::binomial4(l, j - (1 << k));
                }
                ans % 4
            }
        }
    };
}
214
215impl_binomial!(u32);
216impl_binomial!(u16);
217impl_binomial!(i32);