fp/prime/
mod.rs

1use std::{
2    fmt::{Debug, Display},
3    hash::Hash,
4    num::NonZeroU32,
5    ops::{
6        Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr,
7        ShrAssign, Sub, SubAssign,
8    },
9};
10
11use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error};
12
13pub mod binomial;
14pub mod iter;
15
16#[cfg(not(feature = "odd-primes"))]
17pub mod primes_2;
18#[cfg(feature = "odd-primes")]
19pub mod primes_generic;
20
21pub use binomial::Binomial;
22#[cfg(not(feature = "odd-primes"))]
23pub use primes_2::*;
24#[cfg(feature = "odd-primes")]
25pub use primes_generic::*;
26
/// The prime 2 as a dynamic [`ValidPrime`].
pub const TWO: ValidPrime = ValidPrime::new(2);
28
29/// A trait that represents a prime number. There are currently two kinds of structs that implement
30/// this trait: static primes and `ValidPrime`, the dynamic prime.
31///
32/// The methods in this trait take a `self` receiver so that the dynamic prime `ValidPrime` can
33/// implement it. We could also have a `&self` receiver, but since `Prime` is a supertrait of
34/// `Copy`, the two are equivalent. Using `self` might also be useful in the future if we ever want
35/// to play with autoref specialization.
36///
37/// The fact that e.g. `P2::to_u32` is hardcoded to return 2 means that a condition like `p.to_u32()
38/// == 2` (or even better, just `p == 2`) will reduce to `true` at compile time, and allow the
39/// compiler to eliminate an entire branch, while also leaving that check in for when the prime is
40/// chosen at runtime.
41#[allow(private_bounds)]
42pub trait Prime:
43    Debug
44    + Clone
45    + Copy
46    + Display
47    + Hash
48    + PartialEq
49    + Eq
50    + PartialEq<u32>
51    + PartialOrd<u32>
52    + Add<u32, Output = u32>
53    + Sub<u32, Output = u32>
54    + Mul<u32, Output = u32>
55    + Div<u32, Output = u32>
56    + Rem<u32, Output = u32>
57    + Shl<u32, Output = u32>
58    + Shr<u32, Output = u32>
59    + Serialize
60    + for<'de> Deserialize<'de>
61    + crate::MaybeArbitrary<Option<NonZeroU32>>
62    + 'static
63{
64    fn as_i32(self) -> i32;
65    fn to_dyn(self) -> ValidPrime;
66
67    fn as_u32(self) -> u32 {
68        self.as_i32() as u32
69    }
70
71    fn as_usize(self) -> usize {
72        self.as_u32() as usize
73    }
74
75    /// Computes the sum mod p. This takes care of overflow.
76    fn sum(self, n1: u32, n2: u32) -> u32 {
77        let n1 = n1 as u64;
78        let n2 = n2 as u64;
79        let p = self.as_u32() as u64;
80        let sum = (n1 + n2) % p;
81        sum as u32
82    }
83
84    /// Computes the product mod p. This takes care of overflow.
85    fn product(self, n1: u32, n2: u32) -> u32 {
86        let n1 = n1 as u64;
87        let n2 = n2 as u64;
88        let p = self.as_u32() as u64;
89        let product = (n1 * n2) % p;
90        product as u32
91    }
92
93    fn inverse(self, k: u32) -> u32 {
94        inverse(self, k)
95    }
96
97    fn pow(self, exp: u32) -> u32 {
98        self.as_u32().pow(exp)
99    }
100
101    fn pow_mod(self, mut b: u32, mut e: u32) -> u32 {
102        assert!(self.as_u32() > 0);
103        let mut result: u32 = 1;
104        while e > 0 {
105            if (e & 1) == 1 {
106                result = self.product(result, b);
107            }
108            b = self.product(b, b);
109            e >>= 1;
110        }
111        result
112    }
113}
114
/// Error returned when a number (or a string) cannot be turned into a prime.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PrimeError {
    /// The input string could not be parsed as a `u32`; carries the underlying parse error.
    NotAnInteger(std::num::ParseIntError),
    /// The input was an integer, but not one the requested prime type accepts.
    InvalidPrime(u32),
}

impl std::fmt::Display for PrimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            PrimeError::InvalidPrime(p) => write!(f, "{p} is not a valid prime"),
            PrimeError::NotAnInteger(ref s) => write!(f, "Not an integer: {s}"),
        }
    }
}
129
/// Defines a zero-sized static prime type `$pn` whose value is the integer literal `$p`.
///
/// The generated type gets a `Prime` impl with `as_i32` hardcoded to `$p`. It also gets
/// `Serialize`/`Deserialize` as a bare `u32`, a constant proptest `Arbitrary` strategy (behind the
/// `proptest` feature), and the `MaybeArbitrary` marker impl. The remaining supertraits `Prime`
/// requires (operators and comparisons with `u32`, `Display`) are NOT generated here; see
/// `impl_prime_ops!`. Deserialization also relies on a `TryFrom<u32>` impl such as the one
/// `impl_try_from!` generates.
macro_rules! def_prime_static {
    ($pn:ident, $p:literal) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $pn;

        impl Prime for $pn {
            // Returning a literal lets comparisons such as `p == 2` fold away at compile time.
            #[inline]
            fn as_i32(self) -> i32 {
                $p
            }

            #[inline]
            fn to_dyn(self) -> ValidPrime {
                ValidPrime::new($p)
            }
        }

        impl Serialize for $pn {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                // Serialized as the plain numeric value of the prime.
                self.as_u32().serialize(serializer)
            }
        }

        impl<'de> Deserialize<'de> for $pn {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                // Read any `u32`, then validate it through `$pn::try_from`; with the
                // `impl_try_from!` impl, any value other than `$p` is rejected.
                let p: u32 = u32::deserialize(deserializer)?;
                $pn::try_from(p).map_err(D::Error::custom)
            }
        }

        #[cfg(feature = "proptest")]
        impl proptest::arbitrary::Arbitrary for $pn {
            type Parameters = Option<NonZeroU32>;
            type Strategy = proptest::strategy::Just<$pn>;

            fn arbitrary_with(_max: Self::Parameters) -> Self::Strategy {
                // This doesn't honor the max parameter, but that should be fine as long as the
                // static primes are small enough and/or the max is large enough. There's no such
                // thing as an empty strategy, so the best we could do is return a strategy that
                // always rejects. This would cause local failures that may make tests fail.
                proptest::strategy::Just($pn)
            }
        }

        // Marker impl needed to satisfy the `crate::MaybeArbitrary` supertrait bound of `Prime`.
        impl crate::MaybeArbitrary<Option<NonZeroU32>> for $pn {}
    };
}
183
/// Implements the binary operator trait `$trt` (method `$mth`, operator `$operator`) between the
/// prime type `$pn` and `u32` in every direction: `u32 op $pn`, `$pn op u32` and `$pn op $pn`.
/// Each one converts the prime with `as_u32` and produces a `u32`. It also implements the
/// compound assignment `u32 op= $pn` through `$trt_assign`/`$mth_assign`.
macro_rules! impl_op_pn_u32 {
    ($pn:ty, $trt:ident, $mth:ident, $trt_assign:ident, $mth_assign:ident, $operator:tt) => {
        impl $trt<$pn> for u32 {
            type Output = u32;

            fn $mth(self, other: $pn) -> Self::Output {
                self $operator other.as_u32()
            }
        }

        impl $trt<u32> for $pn {
            type Output = u32;

            fn $mth(self, other: u32) -> Self::Output {
                self.as_u32() $operator other
            }
        }

        impl $trt<$pn> for $pn {
            type Output = u32;

            fn $mth(self, other: $pn) -> Self::Output {
                self.as_u32() $operator other.as_u32()
            }
        }

        // Only `u32 op= $pn` is provided: the result of an operation on a prime is a `u32`, not
        // a prime, so `$pn op= …` could not store it.
        impl $trt_assign<$pn> for u32 {
            fn $mth_assign(&mut self, other: $pn) {
                *self = *self $operator other;
            }
        }
    };
}
217
/// Implements, for the prime type `$pn`, the `u32`-related supertraits that `Prime` requires:
/// - every arithmetic and shift operator with `u32` (via `impl_op_pn_u32!`);
/// - `PartialEq<u32>` and `PartialOrd<u32>`, comparing the numeric value;
/// - `From<$pn> for u32`;
/// - `Display`, which prints the numeric value like a `u32`.
macro_rules! impl_prime_ops {
    ($pn:ty) => {
        impl_op_pn_u32!($pn, Add, add, AddAssign, add_assign, +);
        impl_op_pn_u32!($pn, Sub, sub, SubAssign, sub_assign, -);
        impl_op_pn_u32!($pn, Mul, mul, MulAssign, mul_assign, *);
        impl_op_pn_u32!($pn, Div, div, DivAssign, div_assign, /);
        impl_op_pn_u32!($pn, Rem, rem, RemAssign, rem_assign, %);
        impl_op_pn_u32!($pn, Shl, shl, ShlAssign, shl_assign, <<);
        impl_op_pn_u32!($pn, Shr, shr, ShrAssign, shr_assign, >>);

        impl PartialEq<u32> for $pn {
            fn eq(&self, other: &u32) -> bool {
                self.as_u32() == *other
            }
        }

        impl PartialOrd<u32> for $pn {
            fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
                self.as_u32().partial_cmp(other)
            }
        }

        impl From<$pn> for u32 {
            fn from(value: $pn) -> u32 {
                value.as_u32()
            }
        }

        impl std::fmt::Display for $pn {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate to `u32`'s formatter so width/fill/alignment flags are honored.
                <u32 as std::fmt::Display>::fmt(&self.as_u32(), f)
            }
        }
    };
}
253
/// Implements `TryFrom<u32>` for a static (unit-struct) prime type. The conversion succeeds only
/// when the input equals the prime's value; otherwise it fails with `PrimeError::InvalidPrime`.
macro_rules! impl_try_from {
    ($p:tt) => {
        impl_try_from!(@ $p, $p);
    };
    // We need the type both as a type and as an expression. Taking a single `tt` and forwarding
    // it twice lets the unit struct's name be reparsed as both a `ty` and an `expr`.
    (@ $pn:ty, $pn_ex:expr) => {
        impl std::convert::TryFrom<u32> for $pn {
            type Error = PrimeError;

            fn try_from(p: u32) -> Result<Self, PrimeError> {
                // Compares via the `PartialEq<u32>` impl (see `impl_prime_ops!`).
                if $pn_ex == p {
                    Ok($pn_ex)
                } else {
                    Err(PrimeError::InvalidPrime(p))
                }
            }
        }
    };
}
273
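// `macro_rules!` macros are only textually scoped, so the submodules declared at the top of this
// file (before the macro definitions) cannot see them. Importing them by path here makes them
// nameable as e.g. `super::def_prime_static!` (presumably how `primes_2`/`primes_generic` use them).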
275use def_prime_static;
276use impl_op_pn_u32;
277use impl_prime_ops;
278use impl_try_from;
279
/// Compute `b^e mod p`. This is a const version of `Prime::pow_mod`.
///
/// The result is always fully reduced, i.e. it lies in `0..p`. In particular `b^0` yields
/// `1 % p`, which is `0` when `p == 1`. Intermediate products are formed in `u64`, so no
/// combination of `u32` inputs can overflow.
///
/// # Panics
///
/// Panics if `p` is zero.
pub const fn power_mod(p: u32, b: u32, mut e: u32) -> u32 {
    // We can't use Prime::product because const traits are still unstable
    assert!(p > 0);
    let p = p as u64;
    // Start from `1 % p` rather than `1` so that the `e == 0` case is reduced mod `p` as well.
    let mut result: u64 = 1 % p;
    let mut base = b as u64;
    while e > 0 {
        if (e & 1) == 1 {
            // `result < p <= u32::MAX` and `base <= u32::MAX`, so this product fits in a `u64`.
            result = result * base % p;
        }
        base = base * base % p;
        e >>= 1;
    }
    result as u32
}
294
/// Compute the base 2 log of a number, rounded down to the nearest integer.
///
/// # Panics
///
/// Panics if `n` is zero, whose logarithm is undefined.
///
/// # Example
/// ```
/// # use fp::prime::log2;
/// assert_eq!(0, log2(0b1));
/// assert_eq!(1, log2(0b10));
/// assert_eq!(1, log2(0b11));
/// assert_eq!(3, log2(0b1011));
/// ```
pub const fn log2(n: usize) -> usize {
    // `ilog2` panics on zero in every build profile. The hand-rolled
    // `BITS - 1 - leading_zeros` form underflows for zero instead: a panic in debug builds, but a
    // silent `usize::MAX` in release builds.
    n.ilog2() as usize
}
308
309/// Discrete log base p of n.
310pub fn logp<P: Prime>(p: P, mut n: u32) -> u32 {
311    let mut result = 0u32;
312    while n > 0 {
313        n /= p.as_u32();
314        result += 1;
315    }
316    result
317}
318
319/// Factor $n$ as $p^k m$. Returns $(k, m)$.
320pub fn factor_pk<P: Prime>(p: P, mut n: u32) -> (u32, u32) {
321    if n == 0 {
322        return (0, 0);
323    }
324    let mut k = 0;
325    while n.is_multiple_of(p.as_u32()) {
326        n /= p.as_u32();
327        k += 1;
328    }
329    (k, n)
330}
331
/// Multiplicative inverse of `k` modulo the prime `p`.
///
/// For `p <= MAX_PRIME` the answer is read from the precomputed `INVERSE_TABLE`. Larger primes
/// fall back to Fermat's little theorem: `k^(p - 2) mod p`.
///
/// # Panics
///
/// Panics unless `0 < k < p`.
pub fn inverse<P: Prime>(p: P, k: u32) -> u32 {
    use crate::constants::{INVERSE_TABLE, MAX_PRIME, PRIME_TO_INDEX_MAP};
    assert!(k > 0 && p > k);

    if p <= MAX_PRIME as u32 {
        // LLVM doesn't understand the inequality is transitive
        // (it cannot derive `k < MAX_PRIME` from `k < p` and `p <= MAX_PRIME`), so the bounds
        // check on `k` is skipped manually.
        // SAFETY: the assert above gives `k < p`, and this branch gives `p <= MAX_PRIME`.
        // NOTE(review): soundness relies on every row of `INVERSE_TABLE` (defined in
        // `crate::constants`) having length at least `MAX_PRIME`; confirm there.
        unsafe { *INVERSE_TABLE[PRIME_TO_INDEX_MAP[p.as_usize()]].get_unchecked(k as usize) }
    } else {
        power_mod(p.as_u32(), k, p.as_u32() - 2)
    }
}
344
345#[inline(always)]
346pub fn minus_one_to_the_n<P: Prime>(p: P, i: i32) -> u32 {
347    if i % 2 == 0 { 1 } else { p - 1 }
348}
349
/// Unit tests for the prime utilities of this module and its `binomial`/`iter` submodules.
#[cfg(test)]
pub(crate) mod tests {
    use super::{Prime, ValidPrime, binomial::Binomial, inverse, iter::BinomialIterator};
    use crate::{
        constants::PRIMES,
        prime::{PrimeError, is_prime},
    };

    // Every prime below 2^16 (according to `is_prime`) must give a `ValidPrime` equal to itself.
    #[test]
    fn validprime_test() {
        for p in (0..(1 << 16)).filter(|&p| is_prime(p)) {
            assert_eq!(ValidPrime::new(p), p);
        }
    }

    // Both `TryFrom<u32>` and string parsing must reject a non-prime, and parsing a non-integer
    // must surface the underlying `ParseIntError`.
    #[test]
    fn validprime_invalid() {
        assert_eq!(
            ValidPrime::try_from(4).unwrap_err(),
            PrimeError::InvalidPrime(4)
        );
        assert_eq!(
            "4".parse::<ValidPrime>().unwrap_err(),
            PrimeError::InvalidPrime(4)
        );
        assert_eq!(
            "4.0".parse::<ValidPrime>().unwrap_err(),
            PrimeError::NotAnInteger("4.0".parse::<u32>().unwrap_err())
        );
    }

    // For every prime in `PRIMES`, `inverse(p, k)` must invert each nonzero residue `k`.
    #[test]
    fn inverse_test() {
        for &p in PRIMES.iter() {
            let p = ValidPrime::new(p);
            for k in 1..p.as_u32() {
                assert_eq!((inverse(p, k) * k) % p, 1);
            }
        }
    }

    // Each entry is `[p, n, k, expected value of (n choose k) mod p]`.
    #[test]
    fn binomial_test() {
        let entries = [[2, 2, 1, 0], [2, 3, 1, 1], [3, 1090, 730, 1], [7, 3, 2, 3]];

        for entry in &entries {
            assert_eq!(
                entry[3],
                u32::binomial(ValidPrime::new(entry[0]), entry[1], entry[2])
            );
        }
    }

    // `binomial(l + m, m)` must agree with the multinomial coefficient of `[l, m]` mod p.
    #[test]
    fn binomial_vs_monomial() {
        for &p in &[2, 3, 5, 7, 11] {
            let p = ValidPrime::new(p);
            for l in 0..20 {
                for m in 0..20 {
                    assert_eq!(u32::binomial(p, l + m, m), u32::multinomial(p, &mut [l, m]))
                }
            }
        }
    }

    /// Exact `n choose j` in plain `u32` arithmetic: `(j+1)···n` divided by `1, 2, …, n-j` in
    /// turn. Only valid while the intermediate product fits in a `u32`; `binomial_cmp` keeps
    /// `n < 12`.
    fn binomial_full(n: u32, j: u32) -> u32 {
        let mut res = 1;
        for k in j + 1..=n {
            res *= k;
        }
        for k in 1..=(n - j) {
            res /= k;
        }
        res
    }

    // Compare the modular binomial implementations against the exact value for small `n`.
    #[test]
    fn binomial_cmp() {
        for n in 0..12 {
            for j in 0..=n {
                let ans = binomial_full(n, j);
                for &p in &[2, 3, 5, 7, 11] {
                    assert_eq!(
                        u32::binomial(ValidPrime::new(p), n, j),
                        ans % p,
                        "{n} choose {j} mod {p}"
                    );
                }
                assert_eq!(u32::binomial4(n, j), ans % 4, "{n} choose {j} mod 4");
                // binomial4_rec is only called on large n. It does not handle the n = 0, 1 cases
                // correctly.
                if n > 1 {
                    assert_eq!(
                        u32::binomial4_rec(n, j),
                        ans % 4,
                        "{n} choose {j} mod 4 rec"
                    );
                }
            }
        }
    }

    // `BinomialIterator::new(4)` is expected to yield, in increasing order, integers with exactly
    // four bits set.
    #[test]
    fn binomial_iterator() {
        let mut iter = BinomialIterator::new(4);
        assert_eq!(iter.next(), Some(0b1111));
        assert_eq!(iter.next(), Some(0b10111));
        assert_eq!(iter.next(), Some(0b11011));
        assert_eq!(iter.next(), Some(0b11101));
        assert_eq!(iter.next(), Some(0b11110));
        assert_eq!(iter.next(), Some(0b100111));
        assert_eq!(iter.next(), Some(0b101011));
        assert_eq!(iter.next(), Some(0b101101));
        assert_eq!(iter.next(), Some(0b101110));
        assert_eq!(iter.next(), Some(0b110011));
    }
}