1use std::cmp::Ordering;
2
3use itertools::Itertools;
4
5use super::inner::{FqSlice, FqSliceMut, FqVector};
6use crate::{
7 constants,
8 field::{Field, element::FieldElement},
9 limb::Limb,
10 prime::{Prime, ValidPrime},
11};
12
13impl<'a, F: Field> FqSliceMut<'a, F> {
14 pub fn prime(&self) -> ValidPrime {
15 self.fq().characteristic().to_dyn()
16 }
17
18 pub fn add_basis_element(&mut self, index: usize, value: FieldElement<F>) {
19 assert_eq!(self.fq(), value.field());
20 if self.fq().q() == 2 {
21 let pair = self.fq().limb_bit_index_pair(index + self.start());
22 self.limbs_mut()[pair.limb] ^= self.fq().encode(value) << pair.bit_index;
23 } else {
24 let mut x = self.as_slice().entry(index);
25 x += value;
26 self.set_entry(index, x);
27 }
28 }
29
30 pub fn set_entry(&mut self, index: usize, value: FieldElement<F>) {
31 assert_eq!(self.fq(), value.field());
32 assert!(index < self.as_slice().len());
33 let bit_mask = self.fq().bitmask();
34 let limb_index = self.fq().limb_bit_index_pair(index + self.start());
35 let mut result = self.limbs()[limb_index.limb];
36 result &= !(bit_mask << limb_index.bit_index);
37 result |= self.fq().encode(value) << limb_index.bit_index;
38 self.limbs_mut()[limb_index.limb] = result;
39 }
40
41 fn reduce_limbs(&mut self) {
42 let fq = self.fq();
43 if fq.q() != 2 {
44 let limb_range = self.as_slice().limb_range();
45
46 for limb in self.limbs_mut()[limb_range].iter_mut() {
47 *limb = fq.reduce(*limb);
48 }
49 }
50 }
51
52 pub fn scale(&mut self, c: FieldElement<F>) {
53 assert_eq!(self.fq(), c.field());
54 let fq = self.fq();
55
56 if fq.q() == 2 {
57 if c == fq.zero() {
58 self.set_to_zero();
59 }
60 return;
61 }
62
63 let limb_range = self.as_slice().limb_range();
64 if limb_range.is_empty() {
65 return;
66 }
67 let (min_mask, max_mask) = self.as_slice().limb_masks();
68
69 let limb = self.limbs()[limb_range.start];
70 let masked_limb = limb & min_mask;
71 let rest_limb = limb & !min_mask;
72 self.limbs_mut()[limb_range.start] = fq.fma_limb(0, masked_limb, c.clone()) | rest_limb;
73
74 let inner_range = self.as_slice().limb_range_inner();
75 for limb in self.limbs_mut()[inner_range].iter_mut() {
76 *limb = fq.fma_limb(0, *limb, c.clone());
77 }
78 if limb_range.len() > 1 {
79 let full_limb = self.limbs()[limb_range.end - 1];
80 let masked_limb = full_limb & max_mask;
81 let rest_limb = full_limb & !max_mask;
82 self.limbs_mut()[limb_range.end - 1] = fq.fma_limb(0, masked_limb, c) | rest_limb;
83 }
84 self.reduce_limbs();
85 }
86
87 pub fn set_to_zero(&mut self) {
88 let limb_range = self.as_slice().limb_range();
89 if limb_range.is_empty() {
90 return;
91 }
92 let (min_mask, max_mask) = self.as_slice().limb_masks();
93 self.limbs_mut()[limb_range.start] &= !min_mask;
94
95 let inner_range = self.as_slice().limb_range_inner();
96 for limb in self.limbs_mut()[inner_range].iter_mut() {
97 *limb = 0;
98 }
99 self.limbs_mut()[limb_range.end - 1] &= !max_mask;
100 }
101
102 pub fn add(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
103 assert_eq!(self.fq(), c.field());
104 assert_eq!(self.fq(), other.fq());
105
106 if self.as_slice().is_empty() {
107 return;
108 }
109
110 if self.fq().q() == 2 {
111 if c != self.fq().zero() {
112 match self.as_slice().offset().cmp(&other.offset()) {
113 Ordering::Equal => self.add_shift_none(other, self.fq().one()),
114 Ordering::Less => self.add_shift_left(other, self.fq().one()),
115 Ordering::Greater => self.add_shift_right(other, self.fq().one()),
116 };
117 }
118 } else {
119 match self.as_slice().offset().cmp(&other.offset()) {
120 Ordering::Equal => self.add_shift_none(other, c),
121 Ordering::Less => self.add_shift_left(other, c),
122 Ordering::Greater => self.add_shift_right(other, c),
123 };
124 }
125 }
126
127 pub fn add_offset(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, offset: usize) {
128 self.slice_mut(offset, self.as_slice().len())
129 .add(other.restrict(offset, other.len()), c)
130 }
131
132 pub fn add_tensor(
134 &mut self,
135 offset: usize,
136 coeff: FieldElement<F>,
137 left: FqSlice<F>,
138 right: FqSlice<F>,
139 ) {
140 assert_eq!(self.fq(), coeff.field());
141 assert_eq!(self.fq(), left.fq());
142 assert_eq!(self.fq(), right.fq());
143
144 let right_dim = right.len();
145
146 for (i, v) in left.iter_nonzero() {
147 let entry = v * coeff.clone();
148 self.slice_mut(offset + i * right_dim, offset + (i + 1) * right_dim)
149 .add(right, entry);
150 }
151 }
152
153 pub fn assign(&mut self, other: FqSlice<'_, F>) {
155 assert_eq!(self.fq(), other.fq());
156 if self.as_slice().offset() != other.offset() {
157 self.set_to_zero();
158 self.add(other, self.fq().one());
159 return;
160 }
161 let target_range = self.as_slice().limb_range();
162 let source_range = other.limb_range();
163
164 if target_range.is_empty() {
165 return;
166 }
167
168 let (min_mask, max_mask) = other.limb_masks();
169
170 let result = other.limbs()[source_range.start] & min_mask;
171 self.limbs_mut()[target_range.start] &= !min_mask;
172 self.limbs_mut()[target_range.start] |= result;
173
174 let target_inner_range = self.as_slice().limb_range_inner();
175 let source_inner_range = other.limb_range_inner();
176 if !target_inner_range.is_empty() && !source_inner_range.is_empty() {
177 self.limbs_mut()[target_inner_range]
178 .clone_from_slice(&other.limbs()[source_inner_range]);
179 }
180
181 let result = other.limbs()[source_range.end - 1] & max_mask;
182 self.limbs_mut()[target_range.end - 1] &= !max_mask;
183 self.limbs_mut()[target_range.end - 1] |= result;
184 }
185
186 pub fn shl_assign(&mut self, shift: usize) {
188 if shift == 0 {
189 return;
190 }
191 if self.start() == 0 && shift.is_multiple_of(self.fq().entries_per_limb()) {
192 let limb_shift = shift / self.fq().entries_per_limb();
193 *self.end_mut() -= shift;
194 let new_num_limbs = self.fq().number(self.end());
195 for idx in 0..new_num_limbs {
196 self.limbs_mut()[idx] = self.limbs()[idx + limb_shift];
197 }
198 } else {
199 unimplemented!()
200 }
201 }
202
203 pub fn add_shift_none(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
205 assert_eq!(self.fq(), c.field());
206 assert_eq!(self.fq(), other.fq());
207 let fq = self.fq();
208
209 let target_range = self.as_slice().limb_range();
210 let source_range = other.limb_range();
211
212 let (min_mask, max_mask) = other.limb_masks();
213
214 self.limbs_mut()[target_range.start] = fq.fma_limb(
215 self.limbs()[target_range.start],
216 other.limbs()[source_range.start] & min_mask,
217 c.clone(),
218 );
219 self.limbs_mut()[target_range.start] = fq.reduce(self.limbs()[target_range.start]);
220
221 let target_inner_range = self.as_slice().limb_range_inner();
222 let source_inner_range = other.limb_range_inner();
223 if !source_inner_range.is_empty() {
224 for (left, right) in self.limbs_mut()[target_inner_range]
225 .iter_mut()
226 .zip_eq(&other.limbs()[source_inner_range])
227 {
228 *left = fq.fma_limb(*left, *right, c.clone());
229 *left = fq.reduce(*left);
230 }
231 }
232 if source_range.len() > 1 {
233 self.limbs_mut()[target_range.end - 1] = fq.fma_limb(
235 self.limbs()[target_range.end - 1],
236 other.limbs()[source_range.end - 1] & max_mask,
237 c,
238 );
239 self.limbs_mut()[target_range.end - 1] = fq.reduce(self.limbs()[target_range.end - 1]);
240 }
241 }
242
243 fn add_shift_left(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
244 struct AddShiftLeftData {
245 offset_shift: usize,
246 tail_shift: usize,
247 zero_bits: usize,
248 min_source_limb: usize,
249 min_target_limb: usize,
250 number_of_source_limbs: usize,
251 number_of_target_limbs: usize,
252 min_mask: Limb,
253 max_mask: Limb,
254 }
255
256 impl AddShiftLeftData {
257 fn new<F: Field>(fq: F, target: FqSlice<'_, F>, source: FqSlice<'_, F>) -> Self {
258 debug_assert!(target.prime() == source.prime());
259 debug_assert!(target.offset() <= source.offset());
260 debug_assert!(
261 target.len() == source.len(),
262 "self.dim {} not equal to other.dim {}",
263 target.len(),
264 source.len()
265 );
266 let offset_shift = source.offset() - target.offset();
267 let bit_length = fq.bit_length();
268 let entries_per_limb = fq.entries_per_limb();
269 let usable_bits_per_limb = bit_length * entries_per_limb;
270 let tail_shift = usable_bits_per_limb - offset_shift;
271 let zero_bits = constants::BITS_PER_LIMB - usable_bits_per_limb;
272 let source_range = source.limb_range();
273 let target_range = target.limb_range();
274 let min_source_limb = source_range.start;
275 let min_target_limb = target_range.start;
276 let number_of_source_limbs = source_range.len();
277 let number_of_target_limbs = target_range.len();
278 let (min_mask, max_mask) = source.limb_masks();
279
280 Self {
281 offset_shift,
282 tail_shift,
283 zero_bits,
284 min_source_limb,
285 min_target_limb,
286 number_of_source_limbs,
287 number_of_target_limbs,
288 min_mask,
289 max_mask,
290 }
291 }
292
293 fn mask_first_limb<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
294 (other.limbs()[i] & self.min_mask) >> self.offset_shift
295 }
296
297 fn mask_middle_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
298 other.limbs()[i] >> self.offset_shift
299 }
300
301 fn mask_middle_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
302 (other.limbs()[i] << (self.tail_shift + self.zero_bits)) >> self.zero_bits
303 }
304
305 fn mask_last_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
306 let source_limb_masked = other.limbs()[i] & self.max_mask;
307 source_limb_masked << self.tail_shift
308 }
309
310 fn mask_last_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
311 let source_limb_masked = other.limbs()[i] & self.max_mask;
312 source_limb_masked >> self.offset_shift
313 }
314 }
315
316 let dat = AddShiftLeftData::new(self.fq(), self.as_slice(), other);
317 let mut i = 0;
318 {
319 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
320 self.limbs()[i + dat.min_target_limb],
321 dat.mask_first_limb(other, i + dat.min_source_limb),
322 c.clone(),
323 );
324 }
325 for i in 1..dat.number_of_source_limbs - 1 {
326 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
327 self.limbs()[i + dat.min_target_limb],
328 dat.mask_middle_limb_a(other, i + dat.min_source_limb),
329 c.clone(),
330 );
331 self.limbs_mut()[i + dat.min_target_limb - 1] = self.fq().fma_limb(
332 self.limbs()[i + dat.min_target_limb - 1],
333 dat.mask_middle_limb_b(other, i + dat.min_source_limb),
334 c.clone(),
335 );
336 self.limbs_mut()[i + dat.min_target_limb - 1] =
337 self.fq().reduce(self.limbs()[i + dat.min_target_limb - 1]);
338 }
339 i = dat.number_of_source_limbs - 1;
340 if i > 0 {
341 self.limbs_mut()[i + dat.min_target_limb - 1] = self.fq().fma_limb(
342 self.limbs()[i + dat.min_target_limb - 1],
343 dat.mask_last_limb_a(other, i + dat.min_source_limb),
344 c.clone(),
345 );
346 self.limbs_mut()[i + dat.min_target_limb - 1] =
347 self.fq().reduce(self.limbs()[i + dat.min_target_limb - 1]);
348 if dat.number_of_source_limbs == dat.number_of_target_limbs {
349 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
350 self.limbs()[i + dat.min_target_limb],
351 dat.mask_last_limb_b(other, i + dat.min_source_limb),
352 c,
353 );
354 self.limbs_mut()[i + dat.min_target_limb] =
355 self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
356 }
357 } else {
358 self.limbs_mut()[i + dat.min_target_limb] =
359 self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
360 }
361 }
362
363 fn add_shift_right(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>) {
364 struct AddShiftRightData {
365 offset_shift: usize,
366 tail_shift: usize,
367 zero_bits: usize,
368 min_source_limb: usize,
369 min_target_limb: usize,
370 number_of_source_limbs: usize,
371 number_of_target_limbs: usize,
372 min_mask: Limb,
373 max_mask: Limb,
374 }
375
376 impl AddShiftRightData {
377 fn new<F: Field>(fq: F, target: FqSlice<'_, F>, source: FqSlice<'_, F>) -> Self {
378 debug_assert!(target.prime() == source.prime());
379 debug_assert!(target.offset() >= source.offset());
380 debug_assert!(
381 target.len() == source.len(),
382 "self.dim {} not equal to other.dim {}",
383 target.len(),
384 source.len()
385 );
386 let offset_shift = target.offset() - source.offset();
387 let bit_length = fq.bit_length();
388 let entries_per_limb = fq.entries_per_limb();
389 let usable_bits_per_limb = bit_length * entries_per_limb;
390 let tail_shift = usable_bits_per_limb - offset_shift;
391 let zero_bits = constants::BITS_PER_LIMB - usable_bits_per_limb;
392 let source_range = source.limb_range();
393 let target_range = target.limb_range();
394 let min_source_limb = source_range.start;
395 let min_target_limb = target_range.start;
396 let number_of_source_limbs = source_range.len();
397 let number_of_target_limbs = target_range.len();
398 let (min_mask, max_mask) = source.limb_masks();
399 Self {
400 offset_shift,
401 tail_shift,
402 zero_bits,
403 min_source_limb,
404 min_target_limb,
405 number_of_source_limbs,
406 number_of_target_limbs,
407 min_mask,
408 max_mask,
409 }
410 }
411
412 fn mask_first_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
413 let source_limb_masked = other.limbs()[i] & self.min_mask;
414 (source_limb_masked << (self.offset_shift + self.zero_bits)) >> self.zero_bits
415 }
416
417 fn mask_first_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
418 let source_limb_masked = other.limbs()[i] & self.min_mask;
419 source_limb_masked >> self.tail_shift
420 }
421
422 fn mask_middle_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
423 (other.limbs()[i] << (self.offset_shift + self.zero_bits)) >> self.zero_bits
424 }
425
426 fn mask_middle_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
427 other.limbs()[i] >> self.tail_shift
428 }
429
430 fn mask_last_limb_a<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
431 let source_limb_masked = other.limbs()[i] & self.max_mask;
432 source_limb_masked << self.offset_shift
433 }
434
435 fn mask_last_limb_b<F: Field>(&self, other: FqSlice<'_, F>, i: usize) -> Limb {
436 let source_limb_masked = other.limbs()[i] & self.max_mask;
437 source_limb_masked >> self.tail_shift
438 }
439 }
440
441 let dat = AddShiftRightData::new(self.fq(), self.as_slice(), other);
442 let mut i = 0;
443 {
444 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
445 self.limbs()[i + dat.min_target_limb],
446 dat.mask_first_limb_a(other, i + dat.min_source_limb),
447 c.clone(),
448 );
449 self.limbs_mut()[i + dat.min_target_limb] =
450 self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
451 if dat.number_of_target_limbs > 1 {
452 self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
453 self.limbs()[i + dat.min_target_limb + 1],
454 dat.mask_first_limb_b(other, i + dat.min_source_limb),
455 c.clone(),
456 );
457 }
458 }
459 for i in 1..dat.number_of_source_limbs - 1 {
460 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
461 self.limbs()[i + dat.min_target_limb],
462 dat.mask_middle_limb_a(other, i + dat.min_source_limb),
463 c.clone(),
464 );
465 self.limbs_mut()[i + dat.min_target_limb] =
466 self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
467 self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
468 self.limbs()[i + dat.min_target_limb + 1],
469 dat.mask_middle_limb_b(other, i + dat.min_source_limb),
470 c.clone(),
471 );
472 }
473 i = dat.number_of_source_limbs - 1;
474 if i > 0 {
475 self.limbs_mut()[i + dat.min_target_limb] = self.fq().fma_limb(
476 self.limbs()[i + dat.min_target_limb],
477 dat.mask_last_limb_a(other, i + dat.min_source_limb),
478 c.clone(),
479 );
480 self.limbs_mut()[i + dat.min_target_limb] =
481 self.fq().reduce(self.limbs()[i + dat.min_target_limb]);
482 if dat.number_of_target_limbs > dat.number_of_source_limbs {
483 self.limbs_mut()[i + dat.min_target_limb + 1] = self.fq().fma_limb(
484 self.limbs()[i + dat.min_target_limb + 1],
485 dat.mask_last_limb_b(other, i + dat.min_source_limb),
486 c.clone(),
487 );
488 }
489 }
490 if dat.number_of_target_limbs > dat.number_of_source_limbs {
491 self.limbs_mut()[i + dat.min_target_limb + 1] =
492 self.fq().reduce(self.limbs()[i + dat.min_target_limb + 1]);
493 }
494 }
495
496 pub fn add_masked(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, mask: &[usize]) {
498 assert_eq!(self.fq(), c.field());
500 assert_eq!(self.fq(), other.fq());
501 assert_eq!(self.as_slice().len(), mask.len());
502 for (i, &x) in mask.iter().enumerate() {
503 let entry = other.entry(x);
504 if entry != self.fq().zero() {
505 self.add_basis_element(i, entry * c.clone());
506 }
507 }
508 }
509
510 pub fn add_unmasked(&mut self, other: FqSlice<'_, F>, c: FieldElement<F>, mask: &[usize]) {
512 assert_eq!(self.fq(), c.field());
513 assert_eq!(self.fq(), other.fq());
514 assert!(other.len() <= mask.len());
515 for (i, v) in other.iter_nonzero() {
516 self.add_basis_element(mask[i], v * c.clone());
517 }
518 }
519
520 pub fn slice_mut(&mut self, start: usize, end: usize) -> FqSliceMut<'_, F> {
521 assert!(start <= end && end <= self.as_slice().len());
522 let orig_start = self.start();
523
524 FqSliceMut::new(
525 self.fq(),
526 self.limbs_mut(),
527 orig_start + start,
528 orig_start + end,
529 )
530 }
531
532 #[inline]
533 #[must_use]
534 pub fn as_slice(&self) -> FqSlice<'_, F> {
535 FqSlice::new(self.fq(), self.limbs(), self.start(), self.end())
536 }
537
538 #[inline]
540 #[must_use]
541 pub fn copy(&mut self) -> FqSliceMut<'_, F> {
542 let start = self.start();
543 let end = self.end();
544
545 FqSliceMut::new(self.fq(), self.limbs_mut(), start, end)
546 }
547}
548
549impl<'a, F: Field> From<&'a mut FqVector<F>> for FqSliceMut<'a, F> {
550 fn from(v: &'a mut FqVector<F>) -> Self {
551 v.slice_mut(0, v.len())
552 }
553}