fp/vector/fp_wrapper/
mod.rs

1//! This module provides convenience wrappers around the contents of [`crate::vector::inner`] in the
2//! special case where the field is a prime field. The main purpose is to put [`FqVector`] for
3//! different fields `Fp<P>` into a single enum, and to simplify scalars to just `u32`s instead of
4//! rather unwieldy `FieldElement<Fp<P>>`s. It does the same for the various slice structs.
5//!
6//! The main magic occurs in the macros, such as `dispatch_vector_inner`, which we use to provide
7//! wrapper functions around the `FqVector` functions. To maintain consistency, we define the API
8//! in this file irrespective of whether the `odd-primes` feature is enabled or not, and it is the
9//! macros that will take care of making the distinction.
10//!
11//! Note: Since we still want to have scalars simply be `u32`s, even when `odd-primes` is disabled,
12//! we can't simply define `type FpVector = FqVector<Fp<2>>` like we previously did: we need to use
13//! a transparent wrapper.
14
15use std::{convert::TryInto, io};
16
17use itertools::Itertools;
18use serde::{Deserialize, Deserializer, Serialize, Serializer};
19
20use super::iter::{FqVectorIterator, FqVectorNonZeroIterator};
21use crate::{
22    field::{Fp, field_internal::FieldInternal},
23    limb::Limb,
24    prime::Prime,
25    vector::inner::{FqSlice, FqSliceMut, FqVector},
26};
27
mod helpers;

// The wrapper API below is written once; which macro implementation backs it is chosen by the
// `odd-primes` feature. With the feature, `macros_generic` is used; without it, `macros_2`
// (per the module docs, the p = 2-only transparent-wrapper variant). Both modules export the
// same macro names, so the rest of the file is feature-agnostic.
#[cfg(feature = "odd-primes")]
#[macro_use]
mod macros_generic;
#[cfg(feature = "odd-primes")]
#[allow(unused_imports)] // Nightly regression
use macros_generic::{dispatch_struct, dispatch_vector, impl_try_into, use_primes};

#[cfg(not(feature = "odd-primes"))]
#[macro_use]
mod macros_2;
#[cfg(not(feature = "odd-primes"))]
#[allow(unused_imports)] // Nightly regression
use macros_2::{dispatch_struct, dispatch_vector, impl_try_into, use_primes};

// NOTE(review): `ValidPrime` is used throughout this file without an explicit `use`, so this
// invocation presumably brings the prime types into scope — confirm in the macro definition.
use_primes!();
45
// Public wrapper types. Each `dispatch_struct!` declares `Name` as a wrapper around the inner
// type named after `from`. Per the module docs, this is an enum over the supported primes
// when `odd-primes` is enabled and a transparent wrapper otherwise. The attributes given here
// are presumably applied to the generated type — NOTE(review): confirm in the macro modules.

// Owned vector over F_p.
dispatch_struct! {
    #[derive(Debug, Hash, Eq, PartialEq, Clone)]
    pub FpVector from FqVector
}

// Read-only view into a range of limbs (constructed from `&'a [Limb]`); cheap to copy.
dispatch_struct! {
    #[derive(Debug, Copy, Clone)]
    pub FpSlice<'a> from FqSlice
}

// Mutable view into a range of limbs (constructed from `&'a mut [Limb]`).
dispatch_struct! {
    #[derive(Debug)]
    pub FpSliceMut<'a> from FqSliceMut
}

// Iterator over all entries, yielding `u32`.
dispatch_struct! {
    pub FpVectorIterator<'a> from FqVectorIterator
}

// Iterator over nonzero entries, yielding `(index, value)` pairs.
dispatch_struct! {
    pub FpVectorNonZeroIterator<'a> from FqVectorNonZeroIterator
}
68
// Inherent methods of `FpVector`. Almost all of them are generated by `dispatch_vector!`,
// which forwards each listed signature to the same-named method of the wrapped `FqVector`.
//
// DSL markers, as far as can be read from this file (NOTE(review): confirm against
// `macros_generic` / `macros_2`):
// - `@name`: every such method has a `u32` scalar in its signature, so the marker presumably
//   makes the macro convert between `u32` and the inner field-element type.
// - `-> (dispatch T)`: the inner return value is wrapped into the wrapper type `T`.
// - `-> (from FqVector)` / `-> (from io FqVector)`: constructors that build the inner vector
//   (inside an `io::Result` for `io`) and wrap it.
// - `-> (T)`: presumably the return type `T` is passed through unchanged.
impl FpVector {
    dispatch_vector! {
        pub fn prime(&self) -> ValidPrime;
        pub fn len(&self) -> usize;
        pub fn is_empty(&self) -> bool;
        pub fn @scale(&mut self, c: u32);
        pub fn set_to_zero(&mut self);
        pub fn @entry(&self, index: usize) -> u32;
        pub fn @set_entry(&mut self, index: usize, value: u32);
        pub fn assign(&mut self, other: &Self);
        pub fn assign_partial(&mut self, other: &Self);
        pub fn @add(&mut self, other: &Self, c: u32);
        pub fn @add_offset(&mut self, other: &Self, c: u32, offset: usize);
        pub fn slice(&self, start: usize, end: usize) -> (dispatch FpSlice<'_>);
        pub fn as_slice(&self) -> (dispatch FpSlice<'_>);
        pub fn slice_mut(&mut self, start: usize, end: usize) -> (dispatch FpSliceMut<'_>);
        pub fn as_slice_mut(&mut self) -> (dispatch FpSliceMut<'_>);
        pub fn is_zero(&self) -> bool;
        pub fn iter(&self) -> (dispatch FpVectorIterator<'_>);
        pub fn iter_nonzero(&self) -> (dispatch FpVectorNonZeroIterator<'_>);
        pub fn extend_len(&mut self, dim: usize);
        pub fn set_scratch_vector_size(&mut self, dim: usize);
        pub fn @add_basis_element(&mut self, index: usize, value: u32);
        pub fn @copy_from_slice(&mut self, slice: &[u32]);
        pub fn @add_truncate(&mut self, other: &Self, c: u32) -> (Option<()>);
        pub fn sign_rule(&self, other: &Self) -> bool;
        pub fn @add_carry(&mut self, other: &Self, c: u32, rest: &mut [Self]) -> bool;
        pub fn @first_nonzero(&self) -> (Option<(usize, u32)>);
        pub fn density(&self) -> f32;

        // Raw limb storage, exposed crate-wide for the matrix code.
        pub(crate) fn limbs(&self) -> (&[Limb]);

        // Constructors: build the inner `FqVector` for prime `p` and wrap it.
        pub fn new<P: Prime>(p: P, len: usize) -> (from FqVector);
        pub fn new_with_capacity<P: Prime>(p: P, len: usize, capacity: usize) -> (from FqVector);

        // Binary (de)serialization of the limb data; `from_bytes` is a fallible constructor.
        pub fn update_from_bytes(&mut self, data: &mut impl io::Read) -> (io::Result<()>);
        pub fn from_bytes<P: Prime>(p: P, len: usize, data: &mut impl io::Read) -> (from io FqVector);
        pub fn to_bytes(&self, buffer: &mut impl io::Write) -> (io::Result<()>);
    }

    /// Creates a vector over `F_p` of length `slice.len()` whose entries are copied from
    /// `slice` via `copy_from_slice`.
    pub fn from_slice<P: Prime>(p: P, slice: &[u32]) -> Self {
        let mut v = Self::new(p, slice.len());
        v.copy_from_slice(slice);
        v
    }

    // Convenient for some matrix methods
    // Number of limbs used to store `len` entries over `F_p` (delegates to `Fp::number`).
    pub(crate) fn num_limbs(p: ValidPrime, len: usize) -> usize {
        Fp::new(p).number(len)
    }

    // Convenient for some matrix methods
    // `len` rounded up to fill whole limbs: the limb count times the entries stored per limb.
    pub(crate) fn padded_len(p: ValidPrime, len: usize) -> usize {
        Self::num_limbs(p, len) * Fp::new(p).entries_per_limb()
    }
}
125
// Inherent methods of the read-only view `FpSlice`, forwarded to the wrapped `FqSlice` by
// `dispatch_vector!`. `@` marks methods with a `u32` scalar in their signature; `(dispatch T)`
// wraps the inner result into wrapper type `T`. NOTE(review): DSL semantics inferred from usage.
//
// Methods taking `self` by value (`iter`, `iter_nonzero`, `restrict`, `to_owned`) are cheap
// because `FpSlice` is `Copy`. They return values tied to `'a`, the lifetime of the
// underlying limbs, rather than to the slice value itself.
impl<'a> FpSlice<'a> {
    dispatch_vector! {
        // Crate-internal constructor over `limbs`, restricted to `start..end`
        // (presumably entry indices — confirm in `FqSlice::new`).
        pub(crate) fn new<P: Prime>(p: P, limbs: &'a [Limb], start: usize, end: usize) -> (from FqSlice);
        pub fn prime(&self) -> ValidPrime;
        pub fn len(&self) -> usize;
        pub fn is_empty(&self) -> bool;
        pub fn @entry(&self, index: usize) -> u32;
        pub fn iter(self) -> (dispatch FpVectorIterator<'a>);
        pub fn iter_nonzero(self) -> (dispatch FpVectorNonZeroIterator<'a>);
        pub fn @first_nonzero(&self) -> (Option<(usize, u32)>);
        pub fn is_zero(&self) -> bool;
        pub fn restrict(self, start: usize, end: usize) -> (dispatch FpSlice<'a>);
        pub fn to_owned(self) -> (dispatch FpVector);

        pub(crate) fn limbs(&self) -> (&[Limb]);
    }
}
143
// Inherent methods of the mutable view `FpSliceMut`, forwarded to the wrapped `FqSliceMut` by
// `dispatch_vector!`. `@` marks methods with a `u32` scalar in their signature; `(dispatch T)`
// wraps the inner result into wrapper type `T`. `FpSlice` arguments are passed as wrappers and
// presumably unwrapped by the macro. NOTE(review): DSL semantics inferred from usage.
impl<'a> FpSliceMut<'a> {
    dispatch_vector! {
        // Crate-internal constructor over `limbs`, restricted to `start..end`.
        pub(crate) fn new<P: Prime>(p: P, limbs: &'a mut [Limb], start: usize, end: usize) -> (from FqSliceMut);
        pub fn prime(&self) -> ValidPrime;
        pub fn @scale(&mut self, c: u32);
        pub fn set_to_zero(&mut self);
        pub fn @add(&mut self, other: FpSlice, c: u32);
        pub fn @add_offset(&mut self, other: FpSlice, c: u32, offset: usize);
        pub fn assign(&mut self, other: FpSlice);
        pub fn shl_assign(&mut self, shift: usize);
        pub fn @set_entry(&mut self, index: usize, value: u32);
        pub fn as_slice(&self) -> (dispatch FpSlice<'_>);
        pub fn slice_mut(&mut self, start: usize, end: usize) -> (dispatch FpSliceMut<'_>);
        pub fn @add_basis_element(&mut self, index: usize, value: u32);
        // Reborrow: a shorter-lived `FpSliceMut` borrowing from `self`, so the original can be
        // used again afterwards.
        pub fn copy(&mut self) -> (dispatch FpSliceMut<'_>);
        pub fn @add_masked(&mut self, other: FpSlice, c: u32, mask: &[usize]);
        pub fn @add_unmasked(&mut self, other: FpSlice, c: u32, mask: &[usize]);
        // `@left` marks an argument needing special handling by the macro; the exact meaning is
        // defined in the macro modules (NOTE(review): not visible here).
        pub fn @add_tensor(&mut self, offset: usize, coeff: u32, @left: FpSlice, right: FpSlice);

        pub(crate) fn limbs(&self) -> (&[Limb]);
        pub(crate) fn limbs_mut(&mut self) -> (&mut [Limb]);
    }
}
167
// Extra inherent method on the entry iterator, forwarded to `FqVectorIterator`.
impl FpVectorIterator<'_> {
    dispatch_vector! {
        // Presumably advances the iterator by `n` entries — confirm in `FqVectorIterator`.
        pub fn skip_n(&mut self, n: usize);
    }
}
173
174impl std::fmt::Display for FpVector {
175    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
176        self.as_slice().fmt(f)
177    }
178}
179
180impl std::fmt::Display for FpSlice<'_> {
181    /// # Example
182    /// ```
183    /// # use fp::vector::FpVector;
184    /// # use fp::prime::ValidPrime;
185    /// let v = FpVector::from_slice(ValidPrime::new(2), &[0, 1, 0]);
186    /// assert_eq!(&format!("{v}"), "[0, 1, 0]");
187    /// assert_eq!(&format!("{v:#}"), "010");
188    /// ```
189    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
190        if f.alternate() {
191            for v in self.iter() {
192                // If self.p >= 11, this will look funky
193                write!(f, "{v}")?;
194            }
195            Ok(())
196        } else {
197            write!(f, "[{}]", self.iter().format(", "))
198        }
199    }
200}
201
202impl From<&FpVector> for Vec<u32> {
203    fn from(v: &FpVector) -> Self {
204        v.iter().collect()
205    }
206}
207
208impl std::ops::AddAssign<&Self> for FpVector {
209    fn add_assign(&mut self, other: &Self) {
210        self.add(other, 1);
211    }
212}
213
// Iteration over all entries as `u32`. Only `next` is forwarded (to `FqVectorIterator`), so
// `size_hint` and every other provided method use `Iterator`'s default implementations.
impl Iterator for FpVectorIterator<'_> {
    type Item = u32;

    dispatch_vector! {
        // `@` presumably converts the inner field element back to `u32`.
        fn @next(&mut self) -> (Option<u32>);
    }
}
221
// Iteration over nonzero entries as `(index, value)` pairs. Only `next` is forwarded (to
// `FqVectorNonZeroIterator`); all other `Iterator` methods use their default implementations.
impl Iterator for FpVectorNonZeroIterator<'_> {
    type Item = (usize, u32);

    dispatch_vector! {
        // `@` presumably converts the inner field element in each pair back to `u32`.
        fn @next(&mut self) -> (Option<(usize, u32)>);
    }
}
229
230impl<'a> IntoIterator for &'a FpVector {
231    type IntoIter = FpVectorIterator<'a>;
232    type Item = u32;
233
234    fn into_iter(self) -> Self::IntoIter {
235        self.iter()
236    }
237}
238
// Conversion impls between the wrapper types and the underlying `Fq*` types, generated by the
// feature-selected macro module; the concrete impls are not visible in this file.
// `impl_from!` reaches this scope through `#[macro_use]` on that module rather than an explicit
// `use`.
impl_from!();
impl_try_into!();
241
242// `FpVector`'s serde format routes through `FqVector<Fp<ValidPrime>>`, giving a uniform
243// representation across all primes: the prime is encoded in the data, so round-tripping works
244// without out-of-band context. This is what the zarr save system relies on.
245//
246// Callers whose wire format must stay a flat `Vec<u32>` (e.g. the sseq_gui web frontend, which
247// reads vectors as plain JS arrays) should declare their fields as `Vec<u32>` and convert at the
248// boundary rather than relying on `FpVector`'s own serde impl.
249impl Serialize for FpVector {
250    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
251    where
252        S: Serializer,
253    {
254        use crate::field::Fp;
255        let p = self.prime();
256        let fq = Fp::new(p);
257        let v = FqVector::from_raw_parts(fq, self.len(), self.limbs().to_vec());
258        v.serialize(serializer)
259    }
260}
261
262impl<'de> Deserialize<'de> for FpVector {
263    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
264    where
265        D: Deserializer<'de>,
266    {
267        use crate::field::{Field, Fp};
268        let v: FqVector<Fp<ValidPrime>> = FqVector::deserialize(deserializer)?;
269        let p = v.fq().characteristic();
270        // Reconstruct an `FpVector` by round-tripping through the binary limb format. The
271        // intermediate byte buffer is small and only allocated on deserialize.
272        let mut bytes = Vec::new();
273        v.to_bytes(&mut bytes).map_err(serde::de::Error::custom)?;
274        Self::from_bytes(p, v.len(), &mut &bytes[..]).map_err(serde::de::Error::custom)
275    }
276}
277
278impl<'a, 'b> From<&'a mut FpSliceMut<'b>> for FpSliceMut<'a> {
279    fn from(slice: &'a mut FpSliceMut<'b>) -> Self {
280        slice.copy()
281    }
282}
283
284impl<'a, 'b> From<&'a FpSlice<'b>> for FpSlice<'a> {
285    fn from(slice: &'a FpSlice<'b>) -> Self {
286        *slice
287    }
288}
289
290impl<'a, 'b> From<&'a FpSliceMut<'b>> for FpSlice<'a> {
291    fn from(slice: &'a FpSliceMut<'b>) -> Self {
292        slice.as_slice()
293    }
294}
295
296impl<'a> From<&'a FpVector> for FpSlice<'a> {
297    fn from(v: &'a FpVector) -> Self {
298        v.as_slice()
299    }
300}
301
302impl<'a> From<&'a mut FpVector> for FpSliceMut<'a> {
303    fn from(v: &'a mut FpVector) -> Self {
304        v.as_slice_mut()
305    }
306}
307
#[cfg(test)]
mod tests {
    use rand::Rng;
    use rstest::rstest;

    use crate::{prime::ValidPrime, vector::FpVector};

    /// Returns `dimension` values drawn uniformly at random from `0..p`.
    fn random_vector(p: u32, dimension: usize) -> Vec<u32> {
        let mut rng = rand::rng();
        (0..dimension).map(|_| rng.random_range(0..p)).collect()
    }

    /// Checks `FpVector::add_carry` against ordinary integer addition.
    ///
    /// Two lists of random integers are written in base `p` across `E_MAX + 1` vectors
    /// (vector `e` holds digit `e`, least significant first). Adding `w` into `v` one digit at
    /// a time, passing the remaining digits of `v` as `rest`, must agree with integer addition
    /// modulo `p^(E_MAX + 1)`.
    #[rstest]
    #[trace]
    fn test_add_carry(#[values(2)] p: u32, #[values(10, 20, 70, 100, 1000)] dim: usize) {
        use std::fmt::Write;

        let p = ValidPrime::new(p);
        const E_MAX: usize = 4;
        // `E_MAX + 1` base-`p` digits (indices `0..=E_MAX`) represent integers modulo
        // `p^(E_MAX + 1)`. Derive the modulus from `E_MAX` instead of hard-coding the number of
        // factors, so the two cannot drift apart if `E_MAX` changes.
        let mut modulus: u32 = 1;
        for _ in 0..=E_MAX {
            modulus *= p;
        }

        let mut v: Vec<FpVector> = (0..=E_MAX).map(|_| FpVector::new(p, dim)).collect();
        let mut w: Vec<FpVector> = (0..=E_MAX).map(|_| FpVector::new(p, dim)).collect();
        let v_arr = random_vector(modulus, dim);
        let w_arr = random_vector(modulus, dim);
        // Spread each integer over the vectors, one base-`p` digit per vector.
        for i in 0..dim {
            let mut ev = v_arr[i];
            let mut ew = w_arr[i];
            for e in 0..=E_MAX {
                v[e].set_entry(i, ev % p);
                w[e].set_entry(i, ew % p);
                ev /= p;
                ew /= p;
            }
        }

        println!("in  : {v_arr:?}");
        for (e, val) in v.iter().enumerate() {
            println!("in {e}: {val}");
        }
        println!();

        println!("in  : {w_arr:?}");
        for (e, val) in w.iter().enumerate() {
            println!("in {e}: {val}");
        }
        println!();

        // Add digit `e` of `w` into digit `e` of `v`, handing the higher digits of `v` over so
        // carries can propagate into them.
        for e in 0..=E_MAX {
            let (first, rest) = v[e..].split_at_mut(1);
            first[0].add_carry(&w[e], 1, rest);
        }

        // Reassemble the integers from their digits, most significant first.
        let mut vec_result = vec![0; dim];
        for (i, entry) in vec_result.iter_mut().enumerate() {
            for e in (0..=E_MAX).rev() {
                *entry *= p;
                *entry += v[e].entry(i);
            }
        }

        for (e, val) in v.iter().enumerate() {
            println!("out{e}: {val}");
        }
        println!();

        let comparison_result: Vec<u32> = v_arr
            .iter()
            .zip(&w_arr)
            .map(|(&a, &b)| (a + b) % modulus)
            .collect();
        println!("out : {comparison_result:?}");

        // Collect every mismatch into one message so a failure reports all positions at once.
        let mut diffs_str = String::new();
        for (i, (got, expected)) in vec_result.iter().zip(&comparison_result).enumerate() {
            if got != expected {
                let _ = write!(
                    diffs_str,
                    "\nIn position {} expected {} got {}. v[i] = {}, w[i] = {}.",
                    i, expected, got, v_arr[i], w_arr[i]
                );
            }
        }
        assert!(diffs_str.is_empty(), "{}", diffs_str);
    }
}