fp/prime/binomial.rs
1use super::{Prime, ValidPrime};
2use crate::{
3 PRIME_TO_INDEX_MAP,
4 constants::{BINOMIAL_TABLE, BINOMIAL4_TABLE, BINOMIAL4_TABLE_SIZE},
5};
6
7/// This uses a lookup table for n choose k when n and k are both less than p.
8/// Lucas's theorem reduces general binomial coefficients to this case.
9///
10/// Calling this function safely requires that `k, n < p`. These invariants are often known
11/// apriori because k and n are obtained by reducing mod p, so it is better to expose an unsafe
12/// interface that avoids these checks.
13unsafe fn direct_binomial(p: ValidPrime, n: usize, k: usize) -> u32 {
14 unsafe {
15 *BINOMIAL_TABLE
16 .get_unchecked(PRIME_TO_INDEX_MAP[p.as_usize()])
17 .get_unchecked(n)
18 .get_unchecked(k)
19 }
20}
21
22/// A number satisfying the Binomial trait supports computing various binomial coefficients. This
23/// is implemented using a macro, since the implementation for all types is syntactically the same.
24pub trait Binomial: Sized {
25 /// mod 2 multinomial coefficient
26 fn multinomial2(k: &[Self]) -> Self;
27
28 /// mod 2 binomial coefficient n choose k
29 fn binomial2(n: Self, k: Self) -> Self;
30
31 /// Binomial coefficients mod 4. We pre-compute the coefficients for small values of n. For large
32 /// n, we recursively use the fact that if n = 2^k + l, l < 2^k, then
33 ///
34 /// n choose r = l choose r + 2 (l choose (r - 2^{k - 1})) + (l choose (r - 2^k))
35 ///
36 /// This is easy to verify using the fact that
37 ///
38 /// (x + y)^{2^k} = x^{2^k} + 2 x^{2^{k - 1}} y^{2^{k - 1}} + y^{2^k}
39 fn binomial4(n: Self, k: Self) -> Self;
40
41 /// Compute binomial coefficients mod 4 using the recursion relation in the documentation of
42 /// [Binomial::binomial4]. This calls into binomial4 instead of binomial4_rec. The main purpose
43 /// of this is to separate out the logic for testing.
44 fn binomial4_rec(n: Self, k: Self) -> Self;
45
46 /// Computes the multinomial coefficient mod p using Lucas' theorem. This modifies the
47 /// underlying list. For p = 2 it is more efficient to use multinomial2
48 fn multinomial_odd(p: ValidPrime, l: &mut [Self]) -> Self;
49
50 /// Compute odd binomial coefficients mod p, where p is odd. For p = 2 it is more efficient to
51 /// use binomial2
52 fn binomial_odd(p: ValidPrime, n: Self, k: Self) -> Self;
53
54 /// Checks whether n choose k is zero mod p. Since we don't have to compute the value, this is
55 /// faster than binomial_odd.
56 fn binomial_odd_is_zero(p: ValidPrime, n: Self, k: Self) -> bool;
57
58 /// Multinomial coefficient of the list l
59 fn multinomial(p: ValidPrime, l: &mut [Self]) -> Self {
60 if p == 2 {
61 Self::multinomial2(l)
62 } else {
63 Self::multinomial_odd(p, l)
64 }
65 }
66
67 /// Binomial coefficient n choose k.
68 fn binomial(p: ValidPrime, n: Self, k: Self) -> Self {
69 if p == 2 {
70 Self::binomial2(n, k)
71 } else {
72 Self::binomial_odd(p, n, k)
73 }
74 }
75}
76
/// Implements [`Binomial`] for the integer type `$T`. The method bodies are identical for every
/// integer type, which is why they are generated here rather than written out per type.
macro_rules! impl_binomial {
    ($T:ty) => {
        impl Binomial for $T {
            #[inline]
            fn multinomial2(l: &[Self]) -> Self {
                // The multinomial is odd iff adding the entries in binary produces no carries,
                // i.e. iff their sum equals their bitwise OR.
                let mut bit_or: Self = 0;
                let mut sum: Self = 0;
                for &e in l {
                    sum += e;
                    bit_or |= e;
                }
                if bit_or == sum { 1 } else { 0 }
            }

            #[inline]
            fn binomial2(n: Self, k: Self) -> Self {
                // Out-of-range `k` (negative, or larger than `n`) gives zero, matching
                // `binomial_odd`. The `k < 0` test is only meaningful for signed types.
                #[allow(unused_comparisons)]
                if n < k || k < 0 {
                    return 0;
                }
                // Lucas' theorem at p = 2: n choose k is odd iff k and n - k share no bits.
                if (n - k) & k == 0 { 1 } else { 0 }
            }

            #[inline]
            fn multinomial_odd(p_: ValidPrime, l: &mut [Self]) -> Self {
                let p = p_.as_u32() as Self;

                let mut n: Self = l.iter().sum();
                // Also guarantees `l` is non-empty before `l[0]` is read below (assuming
                // non-negative entries).
                if n == 0 {
                    return 1;
                }
                let mut answer = 1;

                // Process one base-p digit per iteration, lowest digit first.
                while n > 0 {
                    let mut multi: Self = 1;

                    let total_entry = n % p;
                    n /= p;

                    let mut partial_sum: Self = l[0] % p;
                    l[0] /= p;

                    for ll in l.iter_mut().skip(1) {
                        let entry = *ll % p;
                        *ll /= p;

                        partial_sum += entry;
                        if partial_sum > total_entry {
                            // The digits of the entries overflow this digit of n, so a carry
                            // occurs and p divides the multinomial. This early return is also
                            // what keeps partial_sum < p for direct_binomial below.
                            return 0;
                        }
                        // SAFETY: the check above gives partial_sum <= total_entry < p, and
                        // entry <= partial_sum, so both arguments are less than p.
                        multi *=
                            unsafe { direct_binomial(p_, partial_sum as usize, entry as usize) }
                                as Self;
                        multi %= p;
                    }
                    answer *= multi;
                    answer %= p;
                }
                answer
            }

            #[inline]
            fn binomial_odd(p_: ValidPrime, mut n: Self, mut k: Self) -> Self {
                let p = p_.as_u32() as Self;

                // We have both signed and unsigned types
                #[allow(unused_comparisons)]
                if n < k || k < 0 {
                    return 0;
                }

                let mut answer = 1;

                // Lucas' theorem: multiply the binomials of corresponding base-p digits.
                while n > 0 {
                    // SAFETY: n >= k >= 0 here, so `n % p` and `k % p` both lie in 0..p.
                    answer *=
                        unsafe { direct_binomial(p_, (n % p) as usize, (k % p) as usize) } as Self;
                    answer %= p;
                    n /= p;
                    k /= p;
                }
                answer
            }

            #[inline]
            fn binomial_odd_is_zero(p: ValidPrime, mut n: Self, mut k: Self) -> bool {
                // Without this guard, k > n could be misreported: the loop stops once n's digits
                // run out even if k still has nonzero digits (e.g. 0 choose 1). This also matches
                // `binomial_odd`, which returns 0 for these inputs.
                #[allow(unused_comparisons)]
                if n < k || k < 0 {
                    return true;
                }
                let p = p.as_u32() as Self;

                // Lucas' theorem: the coefficient vanishes iff some base-p digit of k exceeds
                // the corresponding digit of n.
                while n > 0 {
                    if n % p < k % p {
                        return true;
                    }
                    n /= p;
                    k /= p;
                }
                false
            }

            fn binomial4(n: Self, j: Self) -> Self {
                // Out-of-range `j` gives zero. Without this guard `j > n` would either index past
                // the table row or underflow in `n - j` below. It also excludes negative `n`.
                #[allow(unused_comparisons)]
                if j < 0 || j > n {
                    return 0;
                }
                if (n as usize) < BINOMIAL4_TABLE_SIZE {
                    return BINOMIAL4_TABLE[n as usize][j as usize] as Self;
                }
                // Kummer: the 2-adic valuation of n choose j is the number of carries when adding
                // j and n - j in binary, i.e. popcount(n - j) + popcount(j) - popcount(n).
                if (n - j) & j == 0 {
                    // No carries: the answer is odd, so use the recursion.
                    Self::binomial4_rec(n, j)
                } else if (n - j).count_ones() + j.count_ones() - n.count_ones() == 1 {
                    // Exactly one carry: the coefficient is 2 times an odd number.
                    2
                } else {
                    0
                }
            }

            #[inline]
            fn binomial4_rec(n: Self, j: Self) -> Self {
                #[allow(unused_comparisons)]
                if j < 0 || j > n {
                    return 0;
                }
                if n == 0 {
                    // The guard above leaves only 0 choose 0; there is no top bit to split off.
                    return 1;
                }
                // Write n = 2^k + l with 0 <= l < 2^k.
                let k = (std::mem::size_of::<Self>() * 8) as u32 - n.leading_zeros() - 1;
                let l = n - (1 << k);
                let mut ans = 0;
                if j <= l {
                    ans += Self::binomial4(l, j);
                }
                // The middle term 2 (l choose (j - 2^{k-1})) only exists for k >= 1: for n = 1
                // the expansion of (x + y)^1 has no such term, and `1 << (k - 1)` would underflow.
                // Only its parity matters mod 4, so binomial2 suffices.
                if k >= 1 {
                    let pow = 1 << (k - 1);
                    if pow <= j && j <= l + pow {
                        ans += 2 * Self::binomial2(l, j - pow);
                    }
                }
                if j >= (1 << k) {
                    ans += Self::binomial4(l, j - (1 << k));
                }
                ans % 4
            }
        }
    };
}
214
215impl_binomial!(u32);
216impl_binomial!(u16);
217impl_binomial!(i32);