// fp/vector/impl_fqslicemut.rs

1use std::cmp::Ordering;
2
3use itertools::Itertools;
4
5use super::inner::{FqSlice, FqSliceMut, FqVector};
6use crate::{
7    constants,
8    field::{Field, element::FieldElement},
9    limb::Limb,
10    prime::{Prime, ValidPrime},
11};
12
13impl<'a, F: Field> FqSliceMut<'a, F> {
14    pub fn prime(&self) -> ValidPrime {
15        self.fq().characteristic().to_dyn()
16    }
17
18    pub fn add_basis_element(&mut self, index: usize, value: FieldElement<F>) {
19        assert_eq!(self.fq(), value.field());
20        if self.fq().q() == 2 {
21            let pair = self.fq().limb_bit_index_pair(index + self.start());
22            self.limbs_mut()[pair.limb] ^= self.fq().encode(value) << pair.bit_index;
23        } else {
24            let mut x = self.as_slice().entry(index);
25            x += value;
26            self.set_entry(index, x);
27        }
28    }
29
30    pub fn set_entry(&mut self, index: usize, value: FieldElement<F>) {
31        assert_eq!(self.fq(), value.field());
32        assert!(index < self.as_slice().len());
33        let bit_mask = self.fq().bitmask();
34        let limb_index = self.fq().limb_bit_index_pair(index + self.start());
35        let mut result = self.limbs()[limb_index.limb];
36        result &= !(bit_mask << limb_index.bit_index);
37        result |= self.fq().encode(value) << limb_index.bit_index;
38        self.limbs_mut()[limb_index.limb] = result;
39    }
40
41    fn reduce_limbs(&mut self) {
42        let fq = self.fq();
43        if fq.q() != 2 {
44            let limb_range = self.as_slice().limb_range();
45
46            for limb in self.limbs_mut()[limb_range].iter_mut() {
47                *limb = fq.reduce(*limb);
48            }
49        }
50    }
51
52    pub fn scale(&mut self, c: FieldElement<F>) {
53        assert_eq!(self.fq(), c.field());
54        let fq = self.fq();
55
56        if fq.q() == 2 {
57            if c == fq.zero() {
58                self.set_to_zero();
59            }
60            return;
61        }
62
63        let limb_range = self.as_slice().limb_range();
64        if limb_range.is_empty() {
65            return;
66        }
67        let (min_mask, max_mask) = self.as_slice().limb_masks();
68
69        let limb = self.limbs()[limb_range.start];
70        let masked_limb = limb & min_mask;
71        let rest_limb = limb & !min_mask;
72        self.limbs_mut()[limb_range.start] = fq.fma_limb(0, masked_limb, c.clone()) | rest_limb;
73
74        let inner_range = self.as_slice().limb_range_inner();
75        for limb in self.limbs_mut()[inner_range].iter_mut() {
76            *limb = fq.fma_limb(0, *limb, c.clone());
77        }
78        if limb_range.len() > 1 {
79            let full_limb = self.limbs()[limb_range.end - 1];
80            let masked_limb = full_limb & max_mask;
81            let rest_limb = full_limb & !max_mask;
82            self.limbs_mut()[limb_range.end - 1] = fq.fma_limb(0, masked_limb, c) | rest_limb;
83        }
84        self.reduce_limbs();
85    }
86
87    pub fn set_to_zero(&mut self) {
88        let limb_range = self.as_slice().limb_range();
89        if limb_range.is_empty() {
90            return;
91        }
92        let (min_mask, max_mask) = self.as_slice().limb_masks();
93        self.limbs_mut()[limb_range.start] &= !min_mask;
94
95        let inner_range = self.as_slice().limb_range_inner();
96        for limb in self.limbs_mut()[inner_range].iter_mut() {
97            *limb = 0;
98        }
99        self.limbs_mut()[limb_range.end - 1] &= !max_mask;
100    }
101
102    pub fn add(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
103        assert_eq!(self.fq(), c.field());
104        assert_eq!(self.fq(), other.fq());
105
106        if self.as_slice().is_empty() {
107            return;
108        }
109
110        if self.fq().q() == 2 {
111            if c != self.fq().zero() {
112                match self.as_slice().offset().cmp(&other.offset()) {
113                    Ordering::Equal => self.add_shift_none(other, self.fq().one()),
114                    Ordering::Less => self.add_shift_left(other, self.fq().one()),
115                    Ordering::Greater => self.add_shift_right(other, self.fq().one()),
116                };
117            }
118        } else {
119            match self.as_slice().offset().cmp(&other.offset()) {
120                Ordering::Equal => self.add_shift_none(other, c),
121                Ordering::Less => self.add_shift_left(other, c),
122                Ordering::Greater => self.add_shift_right(other, c),
123            };
124        }
125    }
126
127    pub fn add_offset(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, offset: usize) {
128        self.slice_mut(offset, self.as_slice().len())
129            .add(other.restrict(offset, other.len()), c)
130    }
131
132    /// Adds v otimes w to self.
133    pub fn add_tensor(
134        &mut self,
135        offset: usize,
136        coeff: FieldElement<F>,
137        left: FqSlice<F>,
138        right: FqSlice<F>,
139    ) {
140        assert_eq!(self.fq(), coeff.field());
141        assert_eq!(self.fq(), left.fq());
142        assert_eq!(self.fq(), right.fq());
143
144        let right_dim = right.len();
145
146        for (i, v) in left.iter_nonzero() {
147            let entry = v * coeff.clone();
148            self.slice_mut(offset + i * right_dim, offset + (i + 1) * right_dim)
149                .add(right, entry);
150        }
151    }
152
153    /// TODO: improve efficiency
154    pub fn assign(&mut self, other: FqSlice<'_, F>) {
155        assert_eq!(self.fq(), other.fq());
156        if self.as_slice().offset() != other.offset() {
157            self.set_to_zero();
158            self.add(other, self.fq().one());
159            return;
160        }
161        let target_range = self.as_slice().limb_range();
162        let source_range = other.limb_range();
163
164        if target_range.is_empty() {
165            return;
166        }
167
168        let (min_mask, max_mask) = other.limb_masks();
169
170        let result = other.limbs()[source_range.start] & min_mask;
171        self.limbs_mut()[target_range.start] &= !min_mask;
172        self.limbs_mut()[target_range.start] |= result;
173
174        let target_inner_range = self.as_slice().limb_range_inner();
175        let source_inner_range = other.limb_range_inner();
176        if !target_inner_range.is_empty() && !source_inner_range.is_empty() {
177            self.limbs_mut()[target_inner_range]
178                .clone_from_slice(&other.limbs()[source_inner_range]);
179        }
180
181        let result = other.limbs()[source_range.end - 1] & max_mask;
182        self.limbs_mut()[target_range.end - 1] &= !max_mask;
183        self.limbs_mut()[target_range.end - 1] |= result;
184    }
185
186    /// Shifts the entries of `self` to the left by `shift` entries.
187    pub fn shl_assign(&mut self, shift: usize) {
188        if shift == 0 {
189            return;
190        }
191        if self.start() == 0 && shift.is_multiple_of(self.fq().entries_per_limb()) {
192            let limb_shift = shift / self.fq().entries_per_limb();
193            *self.end_mut() -= shift;
194            let new_num_limbs = self.fq().number(self.end());
195            for idx in 0..new_num_limbs {
196                self.limbs_mut()[idx] = self.limbs()[idx + limb_shift];
197            }
198        } else {
199            unimplemented!()
200        }
201    }
202
203    /// Adds `c` * `other` to `self`. `other` must have the same length, offset, and prime as self.
204    pub fn add_shift_none(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
205        assert_eq!(self.fq(), c.field());
206        assert_eq!(self.fq(), other.fq());
207        let fq = self.fq();
208
209        let target_range = self.as_slice().limb_range();
210        let source_range = other.limb_range();
211
212        let (min_mask, max_mask) = other.limb_masks();
213
214        self.limbs_mut()[target_range.start] = fq.fma_limb(
215            self.limbs()[target_range.start],
216            other.limbs()[source_range.start] & min_mask,
217            c.clone(),
218        );
219        self.limbs_mut()[target_range.start] = fq.reduce(self.limbs()[target_range.start]);
220
221        let target_inner_range = self.as_slice().limb_range_inner();
222        let source_inner_range = other.limb_range_inner();
223        if !source_inner_range.is_empty() {
224            for (left, right) in self.limbs_mut()[target_inner_range]
225                .iter_mut()
226                .zip_eq(&other.limbs()[source_inner_range])
227            {
228                *left = fq.fma_limb(*left, *right, c.clone());
229                *left = fq.reduce(*left);
230            }
231        }
232        if source_range.len() > 1 {
233            // The first and last limbs are distinct, so we process the last.
234            self.limbs_mut()[target_range.end - 1] = fq.fma_limb(
235                self.limbs()[target_range.end - 1],
236                other.limbs()[source_range.end - 1] & max_mask,
237                c,
238            );
239            self.limbs_mut()[target_range.end - 1] = fq.reduce(self.limbs()[target_range.end - 1]);
240        }
241    }
242
243    fn add_shift_left(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
244        struct AddShiftLeftData {
245            offset_shift: usize,
246            tail_shift: usize,
247            zero_bits: usize,
248            min_source_limb: usize,
249            min_target_limb: usize,
250            number_of_source_limbs: usize,
251            number_of_target_limbs: usize,
252            min_mask: Limb,
253            max_mask: Limb,
254        }
255
256        impl AddShiftLeftData {
257            fn new<F: Field>(fq: F, target: FqSlice<'_, F>, source: FqSlice<'_, F>) -> Self {
258                debug_assert!(target.prime() == source.prime());
259                debug_assert!(target.offset() <= source.offset());
260                debug_assert!(
261                    target.len() == source.len(),
262                    "self.dim {} not equal to other.dim {}",
263                    target.len(),
264                    source.len()
265                );
266                let offset_shift = source.offset() - target.offset();
267                let bit_length = fq.bit_length();
268                let entries_per_limb = fq.entries_per_limb();
269                let usable_bits_per_limb = bit_length * entries_per_limb;
270                let tail_shift = usable_bits_per_limb - offset_shift;
271                let zero_bits = constants::BITS_PER_LIMB - usable_bits_per_limb;
272                let source_range = source.limb_range();
273                let target_range = target.limb_range();
274                let min_source_limb = source_range.start;
275                let min_target_limb = target_range.start;
276                let number_of_source_limbs = source_range.len();
277                let number_of_target_limbs = target_range.len();
278                let (min_mask, max_mask) = source.limb_masks();
279
280                Self {
281                    offset_shift,
282                    tail_shift,
283                    zero_bits,
284                    min_source_limb,
285                    min_target_limb,
286                    number_of_source_limbs,
287                    number_of_target_limbs,
288                    min_mask,
289                    max_mask,
290                }
291            }
292
293            fn mask_first_limb<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
294                (other.limbs()[i] & self.min_mask) >> self.offset_shift
295            }
296
297            fn mask_middle_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
298                other.limbs()[i] >> self.offset_shift
299            }
300
301            fn mask_middle_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
302                (other.limbs()[i] << (self.tail_shift + self.zero_bits)) >> self.zero_bits
303            }
304
305            fn mask_last_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
306                let source_limb_masked = other.limbs()[i] & self.max_mask;
307                source_limb_masked << self.tail_shift
308            }
309
310            fn mask_last_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
311                let source_limb_masked = other.limbs()[i] & self.max_mask;
312                source_limb_masked >> self.offset_shift
313            }
314        }
315
316        let dat = AddShiftLeftData::new(self.fq(), self.as_slice(), other);
317        let mut i = 0;
318        {
319            self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
320                self.limbs()[i + dat.min_target_limb],
321                dat.mask_first_limb(other, i + dat.min_source_limb),
322                c.clone(),
323            );
324        }
325        for i in 1..dat.number_of_source_limbs - 1 {
326            self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
327                self.limbs()[i + dat.min_target_limb],
328                dat.mask_middle_limb_a(other, i + dat.min_source_limb),
329                c.clone(),
330            );
331            self.limbs_mut()[i + dat.min_target_limb - 1] = self.fq().fma_limb(
332                self.limbs()[i + dat.min_target_limb - 1],
333                dat.mask_middle_limb_b(other, i + dat.min_source_limb),
334                c.clone(),
335            );
336            self.limbs_mut()[i + dat.min_target_limb - 1] =
337                self.fq().reduce(self.limbs()[i + dat.min_target_limb - 1]);
338        }
339        i = dat.number_of_source_limbs - 1;
340        if i > 0 {
341            self.limbs_mut()[i + dat.min_target_limb - 1] = self.fq().fma_limb(
342                self.limbs()[i + dat.min_target_limb - 1],
343                dat.mask_last_limb_a(other, i + dat.min_source_limb),
344                c.clone(),
345            );
346            self.limbs_mut()[i + dat.min_target_limb - 1] =
347                self.fq().reduce(self.limbs()[i + dat.min_target_limb - 1]);
348            if dat.number_of_source_limbs == dat.number_of_target_limbs {
349                self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
350                    self.limbs()[i + dat.min_target_limb],
351                    dat.mask_last_limb_b(other, i + dat.min_source_limb),
352                    c,
353                );
354                self.limbs_mut()[i + dat.min_target_limb] =
355                    self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
356            }
357        } else {
358            self.limbs_mut()[i + dat.min_target_limb] =
359                self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
360        }
361    }
362
363    fn add_shift_right(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
364        struct AddShiftRightData {
365            offset_shift: usize,
366            tail_shift: usize,
367            zero_bits: usize,
368            min_source_limb: usize,
369            min_target_limb: usize,
370            number_of_source_limbs: usize,
371            number_of_target_limbs: usize,
372            min_mask: Limb,
373            max_mask: Limb,
374        }
375
376        impl AddShiftRightData {
377            fn new<F: Field>(fq: F, target: FqSlice<'_, F>, source: FqSlice<'_, F>) -> Self {
378                debug_assert!(target.prime() == source.prime());
379                debug_assert!(target.offset() >= source.offset());
380                debug_assert!(
381                    target.len() == source.len(),
382                    "self.dim {} not equal to other.dim {}",
383                    target.len(),
384                    source.len()
385                );
386                let offset_shift = target.offset() - source.offset();
387                let bit_length = fq.bit_length();
388                let entries_per_limb = fq.entries_per_limb();
389                let usable_bits_per_limb = bit_length * entries_per_limb;
390                let tail_shift = usable_bits_per_limb - offset_shift;
391                let zero_bits = constants::BITS_PER_LIMB - usable_bits_per_limb;
392                let source_range = source.limb_range();
393                let target_range = target.limb_range();
394                let min_source_limb = source_range.start;
395                let min_target_limb = target_range.start;
396                let number_of_source_limbs = source_range.len();
397                let number_of_target_limbs = target_range.len();
398                let (min_mask, max_mask) = source.limb_masks();
399                Self {
400                    offset_shift,
401                    tail_shift,
402                    zero_bits,
403                    min_source_limb,
404                    min_target_limb,
405                    number_of_source_limbs,
406                    number_of_target_limbs,
407                    min_mask,
408                    max_mask,
409                }
410            }
411
412            fn mask_first_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
413                let source_limb_masked = other.limbs()[i] & self.min_mask;
414                (source_limb_masked << (self.offset_shift + self.zero_bits)) >> self.zero_bits
415            }
416
417            fn mask_first_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
418                let source_limb_masked = other.limbs()[i] & self.min_mask;
419                source_limb_masked >> self.tail_shift
420            }
421
422            fn mask_middle_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
423                (other.limbs()[i] << (self.offset_shift + self.zero_bits)) >> self.zero_bits
424            }
425
426            fn mask_middle_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
427                other.limbs()[i] >> self.tail_shift
428            }
429
430            fn mask_last_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
431                let source_limb_masked = other.limbs()[i] & self.max_mask;
432                source_limb_masked << self.offset_shift
433            }
434
435            fn mask_last_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
436                let source_limb_masked = other.limbs()[i] & self.max_mask;
437                source_limb_masked >> self.tail_shift
438            }
439        }
440
441        let dat = AddShiftRightData::new(self.fq(), self.as_slice(), other);
442        let mut i = 0;
443        {
444            self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
445                self.limbs()[i + dat.min_target_limb],
446                dat.mask_first_limb_a(other, i + dat.min_source_limb),
447                c.clone(),
448            );
449            self.limbs_mut()[i + dat.min_target_limb] =
450                self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
451            if dat.number_of_target_limbs > 1 {
452                self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
453                    self.limbs()[i + dat.min_target_limb + 1],
454                    dat.mask_first_limb_b(other, i + dat.min_source_limb),
455                    c.clone(),
456                );
457            }
458        }
459        for i in 1..dat.number_of_source_limbs - 1 {
460            self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
461                self.limbs()[i + dat.min_target_limb],
462                dat.mask_middle_limb_a(other, i + dat.min_source_limb),
463                c.clone(),
464            );
465            self.limbs_mut()[i + dat.min_target_limb] =
466                self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
467            self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
468                self.limbs()[i + dat.min_target_limb + 1],
469                dat.mask_middle_limb_b(other, i + dat.min_source_limb),
470                c.clone(),
471            );
472        }
473        i = dat.number_of_source_limbs - 1;
474        if i > 0 {
475            self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
476                self.limbs()[i + dat.min_target_limb],
477                dat.mask_last_limb_a(other, i + dat.min_source_limb),
478                c.clone(),
479            );
480            self.limbs_mut()[i + dat.min_target_limb] =
481                self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
482            if dat.number_of_target_limbs > dat.number_of_source_limbs {
483                self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
484                    self.limbs()[i + dat.min_target_limb + 1],
485                    dat.mask_last_limb_b(other, i + dat.min_source_limb),
486                    c.clone(),
487                );
488            }
489        }
490        if dat.number_of_target_limbs > dat.number_of_source_limbs {
491            self.limbs_mut()[i + dat.min_target_limb + 1] =
492                self.fq().reduce(self.limbs()[i + dat.min_target_limb + 1]);
493        }
494    }
495
496    /// Given a mask v, add the `v[i]`th entry of `other` to the `i`th entry of `self`.
497    pub fn add_masked(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, mask: &[usize]) {
498        // TODO: If this ends up being a bottleneck, try to use PDEP/PEXT
499        assert_eq!(self.fq(), c.field());
500        assert_eq!(self.fq(), other.fq());
501        assert_eq!(self.as_slice().len(), mask.len());
502        for (i, &x) in mask.iter().enumerate() {
503            let entry = other.entry(x);
504            if entry != self.fq().zero() {
505                self.add_basis_element(i, entry * c.clone());
506            }
507        }
508    }
509
510    /// Given a mask v, add the `i`th entry of `other` to the `v[i]`th entry of `self`.
511    pub fn add_unmasked(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, mask: &[usize]) {
512        assert_eq!(self.fq(), c.field());
513        assert_eq!(self.fq(), other.fq());
514        assert!(other.len() <= mask.len());
515        for (i, v) in other.iter_nonzero() {
516            self.add_basis_element(mask[i], v * c.clone());
517        }
518    }
519
520    pub fn slice_mut(&mut self, start: usize, end: usize) -> FqSliceMut<'_, F> {
521        assert!(start <= end && end <= self.as_slice().len());
522        let orig_start = self.start();
523
524        FqSliceMut::new(
525            self.fq(),
526            self.limbs_mut(),
527            orig_start + start,
528            orig_start + end,
529        )
530    }
531
532    #[inline]
533    #[must_use]
534    pub fn as_slice(&self) -> FqSlice<'_, F> {
535        FqSlice::new(self.fq(), self.limbs(), self.start(), self.end())
536    }
537
538    /// Generates a version of itself with a shorter lifetime
539    #[inline]
540    #[must_use]
541    pub fn copy(&mut self) -> FqSliceMut<'_, F> {
542        let start = self.start();
543        let end = self.end();
544
545        FqSliceMut::new(self.fq(), self.limbs_mut(), start, end)
546    }
547}
548
549impl<'a, F: Field> From<&'a mut FqVector<F>> for FqSliceMut<'a, F> {
550    fn from(v: &'a mut FqVector<F>) -> Self {
551        v.slice_mut(0, v.len())
552    }
553}