1use std::{
2 fmt::{Debug, Display},
3 hash::Hash,
4 num::NonZeroU32,
5 ops::{
6 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr,
7 ShrAssign, Sub, SubAssign,
8 },
9};
10
11use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error};
12
13pub mod binomial;
14pub mod iter;
15
16#[cfg(not(feature = "odd-primes"))]
17pub mod primes_2;
18#[cfg(feature = "odd-primes")]
19pub mod primes_generic;
20
21pub use binomial::Binomial;
22#[cfg(not(feature = "odd-primes"))]
23pub use primes_2::*;
24#[cfg(feature = "odd-primes")]
25pub use primes_generic::*;
26
/// The prime 2 as a [`ValidPrime`], usable in `const` contexts.
pub const TWO: ValidPrime = ValidPrime::new(2);
28
29#[allow(private_bounds)]
42pub trait Prime:
43 Debug
44 + Clone
45 + Copy
46 + Display
47 + Hash
48 + PartialEq
49 + Eq
50 + PartialEq<u32>
51 + PartialOrd<u32>
52 + Add<u32, Output = u32>
53 + Sub<u32, Output = u32>
54 + Mul<u32, Output = u32>
55 + Div<u32, Output = u32>
56 + Rem<u32, Output = u32>
57 + Shl<u32, Output = u32>
58 + Shr<u32, Output = u32>
59 + Serialize
60 + for<'de> Deserialize<'de>
61 + crate::MaybeArbitrary<Option<NonZeroU32>>
62 + 'static
63{
64 fn as_i32(self) -> i32;
65 fn to_dyn(self) -> ValidPrime;
66
67 fn as_u32(self) -> u32 {
68 self.as_i32() as u32
69 }
70
71 fn as_usize(self) -> usize {
72 self.as_u32() as usize
73 }
74
75 fn sum(self, n1: u32, n2: u32) -> u32 {
77 let n1 = n1 as u64;
78 let n2 = n2 as u64;
79 let p = self.as_u32() as u64;
80 let sum = (n1 + n2) % p;
81 sum as u32
82 }
83
84 fn product(self, n1: u32, n2: u32) -> u32 {
86 let n1 = n1 as u64;
87 let n2 = n2 as u64;
88 let p = self.as_u32() as u64;
89 let product = (n1 * n2) % p;
90 product as u32
91 }
92
93 fn inverse(self, k: u32) -> u32 {
94 inverse(self, k)
95 }
96
97 fn pow(self, exp: u32) -> u32 {
98 self.as_u32().pow(exp)
99 }
100
101 fn pow_mod(self, mut b: u32, mut e: u32) -> u32 {
102 assert!(self.as_u32() > 0);
103 let mut result: u32 = 1;
104 while e > 0 {
105 if (e & 1) == 1 {
106 result = self.product(result, b);
107 }
108 b = self.product(b, b);
109 e >>= 1;
110 }
111 result
112 }
113}
114
/// Errors produced when turning a `u32` or a string into a prime type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrimeError {
    /// The input string could not be parsed as a `u32`; the parse error is kept as the source.
    NotAnInteger(std::num::ParseIntError),
    /// The input was an integer, but not a prime accepted by the target type.
    InvalidPrime(u32),
}

impl std::fmt::Display for PrimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotAnInteger(s) => write!(f, "Not an integer: {s}"),
            Self::InvalidPrime(p) => write!(f, "{p} is not a valid prime"),
        }
    }
}

// Implementing `Error` lets callers propagate `PrimeError` with `?` into
// `Box<dyn Error>` / error-reporting crates, and exposes the underlying parse failure
// through `source()`.
impl std::error::Error for PrimeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::NotAnInteger(err) => Some(err),
            Self::InvalidPrime(_) => None,
        }
    }
}
129
/// Defines a unit struct `$pn` standing for the fixed prime `$p` and implements [`Prime`],
/// serde (de)serialization, proptest's `Arbitrary` (behind the `proptest` feature) and
/// `crate::MaybeArbitrary` for it.
///
/// `macro_rules!` resolves item and trait names at the invocation site. So `Prime`,
/// `ValidPrime`, `Serialize`/`Deserialize`/`Serializer`/`Deserializer`, `NonZeroU32` and
/// `serde::de::Error` (needed for `D::Error::custom`) must be in scope where the macro is
/// used. The generated `Deserialize` impl also relies on a `TryFrom<u32>` impl for `$pn`,
/// for example one produced by `impl_try_from!`.
macro_rules! def_prime_static {
    ($pn:ident, $p:literal) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $pn;

        impl Prime for $pn {
            #[inline]
            fn as_i32(self) -> i32 {
                $p
            }

            #[inline]
            fn to_dyn(self) -> ValidPrime {
                ValidPrime::new($p)
            }
        }

        // Serialized as the bare `u32` value of the prime.
        impl Serialize for $pn {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                serializer.serialize_u32(self.as_u32())
            }
        }

        // Reads a `u32` and converts it with `TryFrom<u32>`; a conversion failure is
        // surfaced as a custom deserializer error carrying its `Display` text.
        impl<'de> Deserialize<'de> for $pn {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                let raw = u32::deserialize(deserializer)?;
                Self::try_from(raw).map_err(D::Error::custom)
            }
        }

        #[cfg(feature = "proptest")]
        impl proptest::arbitrary::Arbitrary for $pn {
            type Parameters = Option<NonZeroU32>;
            type Strategy = proptest::strategy::Just<$pn>;

            // A unit struct has exactly one value, so the parameter is ignored.
            fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
                proptest::strategy::Just(Self)
            }
        }

        impl crate::MaybeArbitrary<Option<NonZeroU32>> for $pn {}
    };
}
183
/// Implements the binary operator trait `$trt` (method `$mth`, token `$operator`) for every
/// pairing of `$pn` and `u32` (`$pn op u32`, `$pn op $pn` and `u32 op $pn`), always with
/// `u32` output. It also implements `$trt_assign` (method `$mth_assign`) for `u32 op= $pn`.
///
/// Each `$pn` operand is first converted with `as_u32()`, so the result is exactly the
/// corresponding primitive `u32` operation (including its overflow behaviour). The trait
/// names are resolved at the invocation site and must be in scope there.
macro_rules! impl_op_pn_u32 {
    ($pn:ty, $trt:ident, $mth:ident, $trt_assign:ident, $mth_assign:ident, $operator:tt) => {
        // `$pn` on the left-hand side.
        impl $trt<u32> for $pn {
            type Output = u32;

            fn $mth(self, rhs: u32) -> u32 {
                let lhs: u32 = self.as_u32();
                lhs $operator rhs
            }
        }

        impl $trt<$pn> for $pn {
            type Output = u32;

            fn $mth(self, rhs: $pn) -> u32 {
                let (lhs, rhs): (u32, u32) = (self.as_u32(), rhs.as_u32());
                lhs $operator rhs
            }
        }

        // `u32` on the left-hand side.
        impl $trt<$pn> for u32 {
            type Output = u32;

            fn $mth(self, rhs: $pn) -> u32 {
                let rhs: u32 = rhs.as_u32();
                self $operator rhs
            }
        }

        impl $trt_assign<$pn> for u32 {
            fn $mth_assign(&mut self, rhs: $pn) {
                *self = *self $operator rhs.as_u32();
            }
        }
    };
}
217
/// Gives a prime type `$pn` its `u32` interoperability. It generates the arithmetic and
/// shift operators (via `impl_op_pn_u32!`), comparison against `u32`, conversion into
/// `u32`, and a `Display` impl that prints the numeric value. Every generated impl goes
/// through `$pn`'s `as_u32()`.
macro_rules! impl_prime_ops {
    ($pn:ty) => {
        impl_op_pn_u32!($pn, Add, add, AddAssign, add_assign, +);
        impl_op_pn_u32!($pn, Sub, sub, SubAssign, sub_assign, -);
        impl_op_pn_u32!($pn, Mul, mul, MulAssign, mul_assign, *);
        impl_op_pn_u32!($pn, Div, div, DivAssign, div_assign, /);
        impl_op_pn_u32!($pn, Rem, rem, RemAssign, rem_assign, %);
        impl_op_pn_u32!($pn, Shl, shl, ShlAssign, shl_assign, <<);
        impl_op_pn_u32!($pn, Shr, shr, ShrAssign, shr_assign, >>);

        impl From<$pn> for u32 {
            fn from(prime: $pn) -> Self {
                prime.as_u32()
            }
        }

        impl PartialEq<u32> for $pn {
            fn eq(&self, other: &u32) -> bool {
                *other == self.as_u32()
            }
        }

        // `u32` is totally ordered, so the comparison is always `Some`.
        impl PartialOrd<u32> for $pn {
            fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
                Some(self.as_u32().cmp(other))
            }
        }

        // Formats exactly like the underlying `u32`, honouring width/fill/alignment flags.
        impl std::fmt::Display for $pn {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.as_u32(), f)
            }
        }
    };
}
253
/// Implements `TryFrom<u32>` for a prime type. The conversion succeeds exactly when the
/// input equals the type's value, and otherwise fails with `PrimeError::InvalidPrime`
/// carrying the input.
///
/// The single-token form `impl_try_from!(P)` uses `P` both as the type and as the value
/// returned on success, which fits unit structs. The internal `@` arm takes the type and
/// the value expression separately. The value must support `PartialEq<u32>`.
macro_rules! impl_try_from {
    ($p:tt) => {
        impl_try_from!(@ $p, $p);
    };
    (@ $pn:ty, $pn_ex:expr) => {
        impl std::convert::TryFrom<u32> for $pn {
            type Error = PrimeError;

            fn try_from(value: u32) -> Result<Self, PrimeError> {
                if $pn_ex != value {
                    return Err(PrimeError::InvalidPrime(value));
                }
                Ok($pn_ex)
            }
        }
    };
}
273
274use def_prime_static;
276use impl_op_pn_u32;
277use impl_prime_ops;
278use impl_try_from;
279
/// Computes `b^e mod p` by binary (square-and-multiply) exponentiation. Being a `const fn`,
/// it is usable in `const` contexts.
///
/// Intermediate products are formed in `u64`, so arbitrary `u32` inputs are handled without
/// overflow, and `b` does not need to be reduced beforehand. The result always lies in
/// `0..p`; in particular every power is `0` when `p == 1`.
///
/// # Panics
///
/// Panics if `p == 0`.
pub const fn power_mod(p: u32, mut b: u32, mut e: u32) -> u32 {
    assert!(p > 0);
    let modulus = p as u64;
    // Seed with `1 mod p` rather than `1`. When `e == 0` the loop never runs and the seed is
    // returned as-is, and for `p == 1` the residue of 1 is 0, not 1.
    let mut result: u32 = 1 % p;
    while e > 0 {
        if (e & 1) == 1 {
            result = ((result as u64) * (b as u64) % modulus) as u32;
        }
        b = ((b as u64) * (b as u64) % modulus) as u32;
        e >>= 1;
    }
    result
}
294
/// Returns `floor(log2(n))`, i.e. the index of the highest set bit of `n`.
///
/// `n` is expected to be nonzero. For `n == 0` the final subtraction underflows. That is a
/// compile-time error in `const` evaluation and a panic when overflow checks are enabled,
/// and it wraps otherwise.
pub const fn log2(n: usize) -> usize {
    let width = usize::BITS as usize;
    let zeros = n.leading_zeros() as usize;
    width - 1 - zeros
}
308
309pub fn logp<P: Prime>(p: P, mut n: u32) -> u32 {
311 let mut result = 0u32;
312 while n > 0 {
313 n /= p.as_u32();
314 result += 1;
315 }
316 result
317}
318
319pub fn factor_pk<P: Prime>(p: P, mut n: u32) -> (u32, u32) {
321 if n == 0 {
322 return (0, 0);
323 }
324 let mut k = 0;
325 while n.is_multiple_of(p.as_u32()) {
326 n /= p.as_u32();
327 k += 1;
328 }
329 (k, n)
330}
331
332pub fn inverse<P: Prime>(p: P, k: u32) -> u32 {
334 use crate::constants::{INVERSE_TABLE, MAX_PRIME, PRIME_TO_INDEX_MAP};
335 assert!(k > 0 && p > k);
336
337 if p <= MAX_PRIME as u32 {
338 unsafe { *INVERSE_TABLE[PRIME_TO_INDEX_MAP[p.as_usize()]].get_unchecked(k as usize) }
340 } else {
341 power_mod(p.as_u32(), k, p.as_u32() - 2)
342 }
343}
344
345#[inline(always)]
346pub fn minus_one_to_the_n<P: Prime>(p: P, i: i32) -> u32 {
347 if i % 2 == 0 { 1 } else { p - 1 }
348}
349
#[cfg(test)]
pub(crate) mod tests {
    use super::{Prime, ValidPrime, binomial::Binomial, inverse, iter::BinomialIterator};
    use crate::{
        constants::PRIMES,
        prime::{PrimeError, is_prime},
    };

    /// Every prime below 2^16 is accepted by `ValidPrime::new` and compares equal to its value.
    #[test]
    fn validprime_test() {
        (0..(1 << 16))
            .filter(|&candidate| is_prime(candidate))
            .for_each(|p| assert_eq!(ValidPrime::new(p), p));
    }

    /// Composites and non-integers are rejected with the matching `PrimeError` variant.
    #[test]
    fn validprime_invalid() {
        let composite = PrimeError::InvalidPrime(4);
        assert_eq!(ValidPrime::try_from(4).unwrap_err(), composite);
        assert_eq!("4".parse::<ValidPrime>().unwrap_err(), composite);

        let parse_failure = "4.0".parse::<u32>().unwrap_err();
        assert_eq!(
            "4.0".parse::<ValidPrime>().unwrap_err(),
            PrimeError::NotAnInteger(parse_failure)
        );
    }

    /// `inverse(p, k) * k ≡ 1 (mod p)` for every prime in `PRIMES` and every `k` in `1..p`.
    #[test]
    fn inverse_test() {
        for p in PRIMES.iter().map(|&raw| ValidPrime::new(raw)) {
            for k in 1..p.as_u32() {
                assert_eq!((inverse(p, k) * k) % p, 1);
            }
        }
    }

    /// Spot checks of `binomial`, written as rows `[p, n, k, expected]`.
    #[test]
    fn binomial_test() {
        let cases = [[2, 2, 1, 0], [2, 3, 1, 1], [3, 1090, 730, 1], [7, 3, 2, 3]];
        for [p, n, k, expected] in cases {
            assert_eq!(expected, u32::binomial(ValidPrime::new(p), n, k));
        }
    }

    /// `binomial(p, l + m, m)` agrees with the two-part `multinomial(p, [l, m])`.
    #[test]
    fn binomial_vs_monomial() {
        for raw in [2, 3, 5, 7, 11] {
            let p = ValidPrime::new(raw);
            for l in 0..20 {
                for m in 0..20 {
                    let mut parts = [l, m];
                    assert_eq!(u32::binomial(p, l + m, m), u32::multinomial(p, &mut parts));
                }
            }
        }
    }

    /// Exact `n choose j` over the integers, computed as `(j+1)...n` divided step by step
    /// by `1, 2, ..., n-j`. Each division is exact, and the values stay within `u32` for
    /// the small `n` used here.
    fn binomial_full(n: u32, j: u32) -> u32 {
        let numerator: u32 = (j + 1..=n).product();
        (1..=(n - j)).fold(numerator, |acc, k| acc / k)
    }

    /// Compares the modular binomial routines against exact values reduced mod `p` and mod 4.
    #[test]
    fn binomial_cmp() {
        for n in 0..12 {
            for j in 0..=n {
                let exact = binomial_full(n, j);
                for p in [2, 3, 5, 7, 11] {
                    assert_eq!(
                        u32::binomial(ValidPrime::new(p), n, j),
                        exact % p,
                        "{n} choose {j} mod {p}"
                    );
                }
                assert_eq!(u32::binomial4(n, j), exact % 4, "{n} choose {j} mod 4");
                // NOTE(review): `binomial4_rec` is only checked for n > 1. Presumably it
                // does not support smaller n; confirm against its definition.
                if n > 1 {
                    assert_eq!(
                        u32::binomial4_rec(n, j),
                        exact % 4,
                        "{n} choose {j} mod 4 rec"
                    );
                }
            }
        }
    }

    /// The first ten values of `BinomialIterator::new(4)` are the smallest integers with
    /// exactly four set bits, in increasing order.
    #[test]
    fn binomial_iterator() {
        let expected = [
            0b1111, 0b10111, 0b11011, 0b11101, 0b11110, 0b100111, 0b101011, 0b101101, 0b101110,
            0b110011,
        ];
        let mut iter = BinomialIterator::new(4);
        for want in expected {
            assert_eq!(iter.next(), Some(want));
        }
    }
}