1use std::{fmt, io, ops::Range};
2
3use aligned_vec::AVec;
4use either::Either;
5use itertools::Itertools;
6use maybe_rayon::prelude::*;
7use serde::{Deserialize, Deserializer, Serialize};
8
9use super::{QuasiInverse, Subspace};
10use crate::{
11 field::{Field, Fp, field_internal::FieldInternal},
12 limb::Limb,
13 matrix::m4ri::M4riTable,
14 prime::{self, ValidPrime},
15 vector::{FpSlice, FpSliceMut, FpVector},
16};
17
/// A dense matrix over the prime field `F_p`, stored row-major in a single limb buffer.
///
/// Row `i` occupies the limbs `i * stride..(i + 1) * stride` of `data`. The buffer holds
/// `physical_rows >= rows` rows; rows past `rows` are padding (over `F_2` with at least 32 rows,
/// `get_physical_rows` rounds the row count up to a multiple of 64).
#[derive(Clone, Serialize)]
pub struct Matrix {
    /// The field the entries live in.
    fp: Fp<ValidPrime>,
    /// Number of logical rows.
    rows: usize,
    /// Number of rows allocated in `data`; never less than `rows`.
    physical_rows: usize,
    /// Number of logical columns.
    columns: usize,
    /// Packed entries; `data.len() == physical_rows * stride`.
    data: AVec<Limb>,
    /// Limbs per row. At least `fp.number(columns)`, and larger when extra column capacity has
    /// been reserved (`new_with_capacity`, `extend_column_capacity`).
    stride: usize,
    /// Either empty, or one entry per column: after `row_reduce`, the row holding that column's
    /// pivot, or a negative value (`-1` as initialized) when the column has no pivot.
    pub(crate) pivots: Vec<isize>,
}
36
37impl<'de> Deserialize<'de> for Matrix {
42 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43 where
44 D: Deserializer<'de>,
45 {
46 use serde::de::Error;
47
48 #[derive(Deserialize)]
49 struct Raw {
50 fp: Fp<ValidPrime>,
51 rows: usize,
52 physical_rows: usize,
53 columns: usize,
54 data: AVec<Limb>,
55 stride: usize,
56 pivots: Vec<isize>,
57 }
58
59 let raw = Raw::deserialize(deserializer)?;
60 let expected_stride = raw.fp.number(raw.columns);
61 if raw.stride != expected_stride {
62 return Err(D::Error::custom(format!(
63 "Matrix stride {} does not match expected {} for columns={}",
64 raw.stride, expected_stride, raw.columns,
65 )));
66 }
67 if raw.physical_rows < raw.rows {
68 return Err(D::Error::custom(format!(
69 "Matrix physical_rows {} less than rows {}",
70 raw.physical_rows, raw.rows,
71 )));
72 }
73 if raw.data.len() != raw.physical_rows * raw.stride {
74 return Err(D::Error::custom(format!(
75 "Matrix data length {} does not match physical_rows*stride = {}*{} = {}",
76 raw.data.len(),
77 raw.physical_rows,
78 raw.stride,
79 raw.physical_rows * raw.stride,
80 )));
81 }
82 if !raw.pivots.is_empty() && raw.pivots.len() != raw.columns {
84 return Err(D::Error::custom(format!(
85 "Matrix pivots length {} must be 0 or columns = {}",
86 raw.pivots.len(),
87 raw.columns,
88 )));
89 }
90 Ok(Self {
91 fp: raw.fp,
92 rows: raw.rows,
93 physical_rows: raw.physical_rows,
94 columns: raw.columns,
95 data: raw.data,
96 stride: raw.stride,
97 pivots: raw.pivots,
98 })
99 }
100}
101
102impl PartialEq for Matrix {
103 fn eq(&self, other: &Self) -> bool {
104 self.data == other.data
105 }
106}
107
108impl Eq for Matrix {}
109
110impl Matrix {
111 pub fn new(p: ValidPrime, rows: usize, columns: usize) -> Self {
114 Self::new_with_capacity(p, rows, columns, rows, columns)
115 }
116
117 pub fn new_with_capacity(
118 p: ValidPrime,
119 rows: usize,
120 columns: usize,
121 rows_capacity: usize,
122 columns_capacity: usize,
123 ) -> Self {
124 let fp = Fp::new(p);
125 let stride = fp.number(columns_capacity);
126 let physical_rows = get_physical_rows(p, rows_capacity);
127 let mut data = AVec::with_capacity(0, physical_rows * stride);
128 data.resize(physical_rows * stride, 0);
129
130 Self {
131 fp,
132 rows,
133 physical_rows,
134 columns,
135 data,
136 stride,
137 pivots: Vec::new(),
138 }
139 }
140
141 pub fn from_data(p: ValidPrime, rows: usize, columns: usize, mut data: Vec<Limb>) -> Self {
142 let fp = Fp::new(p);
143 let stride = fp.number(columns);
144 let physical_rows = get_physical_rows(p, rows);
145 data.resize(physical_rows * stride, 0);
146 Self {
147 fp,
148 rows,
149 physical_rows,
150 columns,
151 data: AVec::from_iter(0, data),
152 stride,
153 pivots: Vec::new(),
154 }
155 }
156
157 pub fn identity(p: ValidPrime, dim: usize) -> Self {
158 let mut matrix = Self::new(p, dim, dim);
159 matrix.as_slice_mut().add_identity();
160 matrix
161 }
162
163 pub fn from_bytes(
164 p: ValidPrime,
165 rows: usize,
166 columns: usize,
167 buffer: &mut impl io::Read,
168 ) -> io::Result<Self> {
169 let fp = Fp::new(p);
170 let stride = fp.number(columns);
171 let physical_rows = get_physical_rows(p, rows);
172 let mut data: AVec<Limb> = aligned_vec::avec![0; stride * physical_rows];
173 for row_idx in 0..rows {
174 let limb_range = row_to_limb_range(row_idx, stride);
175 crate::limb::from_bytes(&mut data[limb_range], buffer)?;
176 }
177 Ok(Self {
178 fp,
179 rows,
180 physical_rows,
181 columns,
182 data,
183 stride,
184 pivots: Vec::new(),
185 })
186 }
187
188 pub fn to_bytes(&self, data: &mut impl io::Write) -> io::Result<()> {
189 let limbs_per_row = self.fp.number(self.columns);
190 for row_idx in 0..self.rows() {
191 let row_range = row_idx * self.stride..row_idx * self.stride + limbs_per_row;
192 crate::limb::to_bytes(&self.data[row_range], data)?;
193 }
194 Ok(())
195 }
196
197 pub(crate) fn write_pivot(v: &[isize], buffer: &mut impl io::Write) -> io::Result<()> {
199 if cfg!(all(target_endian = "little", target_pointer_width = "64")) {
200 let buf: &[u8] =
201 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 8) };
202 buffer.write_all(buf)
203 } else {
204 use byteorder::{LittleEndian, WriteBytesExt};
205 for &i in v {
206 buffer.write_i64::<LittleEndian>(i as i64)?;
207 }
208 Ok(())
209 }
210 }
211
212 pub(crate) fn read_pivot(dim: usize, data: &mut impl io::Read) -> io::Result<Vec<isize>> {
214 if cfg!(all(target_endian = "little", target_pointer_width = "64")) {
215 let mut image = vec![0; dim];
216 let buf: &mut [u8] =
217 unsafe { std::slice::from_raw_parts_mut(image.as_mut_ptr() as *mut u8, dim * 8) };
218 data.read_exact(buf)?;
219 Ok(image)
220 } else {
221 use byteorder::{LittleEndian, ReadBytesExt};
222 let mut image = Vec::with_capacity(dim);
223 for _ in 0..dim {
224 image.push(data.read_i64::<LittleEndian>()? as isize);
225 }
226 Ok(image)
227 }
228 }
229}
230
231impl Matrix {
232 pub fn prime(&self) -> ValidPrime {
233 self.fp.characteristic()
234 }
235
236 pub fn rows(&self) -> usize {
238 self.rows
239 }
240
241 pub(crate) fn physical_rows(&self) -> usize {
243 self.physical_rows
244 }
245
246 pub fn columns(&self) -> usize {
248 self.columns
249 }
250
251 pub(crate) fn stride(&self) -> usize {
252 self.stride
253 }
254
255 pub(crate) fn data(&self) -> &[Limb] {
256 &self.data
257 }
258
259 pub(crate) fn data_mut(&mut self) -> &mut [Limb] {
260 &mut self.data
261 }
262
263 pub fn initialize_pivots(&mut self) {
265 self.pivots.clear();
266 self.pivots.resize(self.columns, -1);
267 }
268
269 pub fn pivots(&self) -> &[isize] {
270 &self.pivots
271 }
272
273 pub fn pivots_mut(&mut self) -> &mut [isize] {
274 &mut self.pivots
275 }
276
277 pub fn from_rows(p: ValidPrime, input: Vec<FpVector>, columns: usize) -> Self {
280 let fp = Fp::new(p);
281 let rows = input.len();
282 let stride = fp.number(columns);
283 let physical_rows = get_physical_rows(p, rows);
284 let mut data = AVec::with_capacity(0, physical_rows * stride);
285 for row in &input {
286 data.extend_from_slice(row.limbs());
287 }
288 data.resize(physical_rows * stride, 0);
290 Self {
291 fp,
292 rows,
293 physical_rows,
294 columns,
295 data,
296 stride,
297 pivots: Vec::new(),
298 }
299 }
300
301 pub fn from_row(p: ValidPrime, row: FpVector, columns: usize) -> Self {
303 Self::from_rows(p, vec![row], columns)
304 }
305
306 pub fn from_vec(p: ValidPrime, input: &[Vec<u32>]) -> Self {
319 let fp = Fp::new(p);
320 let rows = input.len();
321 if rows == 0 {
322 return Self::new(p, 0, 0);
323 }
324 let columns = input[0].len();
325 let stride = fp.number(columns);
326 let physical_rows = get_physical_rows(p, rows);
327 let mut data = AVec::with_capacity(0, physical_rows * stride);
328 for row in input {
329 for chunk in row.chunks(fp.entries_per_limb()) {
330 data.push(fp.pack(chunk.iter().map(|x| fp.element(*x))));
331 }
332 }
333 data.resize(physical_rows * stride, 0);
335 Self {
336 fp,
337 rows,
338 physical_rows,
339 columns,
340 data,
341 stride,
342 pivots: Vec::new(),
343 }
344 }
345
346 pub fn to_vec(&self) -> Vec<Vec<u32>> {
355 self.data
356 .iter()
357 .chunks(self.stride)
358 .into_iter()
359 .map(|row| {
360 row.flat_map(|&limb| self.fp.unpack(limb).map(|x| x.val()))
361 .take(self.columns())
362 .collect()
363 })
364 .collect()
365 }
366
367 pub fn augmented_from_vec(p: ValidPrime, input: &[Vec<u32>]) -> (usize, Self) {
382 let rows = input.len();
383 let cols = input[0].len();
384 let padded_cols = FpVector::padded_len(p, cols);
385 let mut m = Self::new(p, rows, padded_cols + rows);
386
387 for (i, row) in input.iter().enumerate() {
388 for (j, &value) in row.iter().enumerate() {
389 m.row_mut(i).set_entry(j, value);
390 }
391 }
392 m.slice_mut(0, rows, padded_cols, padded_cols + rows)
393 .add_identity();
394 (padded_cols, m)
395 }
396
397 pub fn is_zero(&self) -> bool {
398 self.data.iter().all(|limb| *limb == 0)
399 }
400
401 pub fn set_to_zero(&mut self) {
402 for limb in self.data.iter_mut() {
403 *limb = 0;
404 }
405 }
406
407 pub fn assign(&mut self, other: &Self) {
408 self.data = other.data.clone();
409 }
410
411 pub fn as_slice_mut(&mut self) -> MatrixSliceMut<'_> {
412 self.slice_mut(0, self.rows(), 0, self.columns())
413 }
414
415 pub fn slice_mut(
416 &mut self,
417 row_start: usize,
418 row_end: usize,
419 col_start: usize,
420 col_end: usize,
421 ) -> MatrixSliceMut<'_> {
422 let row_range = row_start..row_end;
423 let limb_range = row_range_to_limb_range(&row_range, self.stride);
424 MatrixSliceMut {
425 fp: self.fp,
426 rows: row_range.len(),
427 data: &mut self.data[limb_range],
428 col_start,
429 col_end,
430 stride: self.stride,
431 }
432 }
433
434 pub fn row(&self, row: usize) -> FpSlice<'_> {
435 let limb_range = row_to_limb_range(row, self.stride);
436 FpSlice::new(self.prime(), &self.data[limb_range], 0, self.columns)
437 }
438
439 pub fn row_mut(&mut self, row: usize) -> FpSliceMut<'_> {
440 let limb_range = row_to_limb_range(row, self.stride);
441 FpSliceMut::new(self.prime(), &mut self.data[limb_range], 0, self.columns)
442 }
443}
444
445impl Matrix {
446 pub fn iter(&self) -> impl Iterator<Item = FpSlice<'_>> {
447 (0..self.rows()).map(move |row_idx| self.row(row_idx))
448 }
449
450 pub fn iter_mut(&mut self) -> impl Iterator<Item = FpSliceMut<'_>> {
451 let p = self.prime();
452 let columns = self.columns;
453 let logical_rows = self.rows;
454
455 if self.stride == 0 {
456 Either::Left(std::iter::empty())
457 } else {
458 let rows = self
459 .data
460 .chunks_mut(self.stride)
461 .take(logical_rows) .map(move |row| FpSliceMut::new(p, row, 0, columns));
463 Either::Right(rows)
464 }
465 }
466
467 pub fn maybe_par_iter_mut(
468 &mut self,
469 ) -> impl MaybeIndexedParallelIterator<Item = FpSliceMut<'_>> {
470 let p = self.prime();
471 let columns = self.columns;
472 let logical_rows = self.rows;
473
474 if self.stride == 0 {
475 Either::Left(maybe_rayon::empty())
476 } else {
477 let rows = self
478 .data
479 .maybe_par_chunks_mut(self.stride)
480 .take(logical_rows) .map(move |row| FpSliceMut::new(p, row, 0, columns));
482 Either::Right(rows)
483 }
484 }
485}
486
487impl fmt::Display for Matrix {
488 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
497 if f.alternate() {
498 write!(f, "{:#}", self.iter().format("\n"))
499 } else {
500 let mut it = self.iter();
501 if let Some(x) = it.next() {
502 write!(f, "[\n {x}")?;
503 } else {
504 return write!(f, "[]");
505 }
506 for x in it {
507 write!(f, ",\n {x}")?;
508 }
509 write!(f, "\n]")
510 }
511 }
512}
513
514impl fmt::Debug for Matrix {
515 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
516 <Self as fmt::Display>::fmt(self, f)
517 }
518}
519
impl Matrix {
    /// Adds `c` times row `source` to row `target`.
    ///
    /// # Panics
    ///
    /// Panics if `target == source` or if either index is not below `self.rows()`.
    pub fn safe_row_op(&mut self, target: usize, source: usize, c: u32) {
        assert_ne!(target, source);
        assert!(source < self.rows());
        assert!(target < self.rows());

        // SAFETY: the assertions above guarantee two distinct, in-bounds rows.
        let (mut target, source) = unsafe { self.split_borrow(target, source) };
        target.add(source.as_slice(), c)
    }

    /// Eliminates the entry of row `target` in `pivot_column` using row `source`.
    ///
    /// Reads `coef = target[pivot_column]`; if it is non-zero, adds `(prime - coef)` times
    /// `source` to `target` through `add_offset(.., pivot_column)`. This clears the entry when
    /// `source` has a 1 in `pivot_column`. NOTE(review): `add_offset` presumably skips the limbs
    /// before `pivot_column`, so `source` is assumed to be zero there — confirm. `prime` is
    /// presumably expected to equal `self.prime()`.
    ///
    /// # Safety
    ///
    /// `target` and `source` must be distinct (only checked by a debug assertion), and `source`
    /// must index a row inside the limb buffer: `split_borrow` does not bounds-check it.
    pub unsafe fn row_op(
        &mut self,
        target: usize,
        source: usize,
        pivot_column: usize,
        prime: ValidPrime,
    ) {
        debug_assert_ne!(target, source);
        let coef = self.row(target).entry(pivot_column);
        if coef == 0 {
            return;
        }
        // SAFETY: upheld by the caller (distinct rows, `source` in bounds); `target` was just
        // bounds-checked by `self.row(target)`.
        let (mut target, source) = unsafe { self.split_borrow(target, source) };
        target.add_offset(source.as_slice(), prime - coef, pivot_column);
    }

    /// Same as [`Matrix::row_op`], but adds the whole of `source` instead of starting at
    /// `pivot_column`, so no assumption is made about `source` before that column.
    ///
    /// # Safety
    ///
    /// Same contract as [`Matrix::row_op`]: `target != source` and `source` in bounds.
    pub unsafe fn row_op_naive(
        &mut self,
        target: usize,
        source: usize,
        pivot_column: usize,
        prime: ValidPrime,
    ) {
        debug_assert_ne!(target, source);
        let coef = self.row(target).entry(pivot_column);
        if coef == 0 {
            return;
        }
        // SAFETY: see `row_op`.
        let (mut target, source) = unsafe { self.split_borrow(target, source) };
        target.add(source.as_slice(), prime - coef);
    }

    /// Returns mutable views of rows `i` and `j` simultaneously.
    ///
    /// # Safety
    ///
    /// `i` and `j` must be distinct, otherwise the two returned slices alias the same limbs.
    /// Both must be below `self.physical_rows()` (the buffer holds `physical_rows * stride`
    /// limbs): the pointer arithmetic below is not bounds-checked.
    pub(crate) unsafe fn split_borrow(
        &mut self,
        i: usize,
        j: usize,
    ) -> (FpSliceMut<'_>, FpSliceMut<'_>) {
        let ptr = self.data.as_mut_ptr();
        // SAFETY: the caller guarantees both rows are in bounds and distinct, so the two
        // `stride`-limb ranges are valid and disjoint.
        let row1 = unsafe { std::slice::from_raw_parts_mut(ptr.add(i * self.stride), self.stride) };
        let row2 = unsafe { std::slice::from_raw_parts_mut(ptr.add(j * self.stride), self.stride) };
        (
            FpSliceMut::new(self.prime(), row1, 0, self.columns),
            FpSliceMut::new(self.prime(), row2, 0, self.columns),
        )
    }

    /// Swaps rows `i` and `j` limb by limb (all `stride` limbs, including spare capacity).
    ///
    /// Panics (via slice indexing) if either row lies outside the limb buffer.
    pub fn swap_rows(&mut self, i: usize, j: usize) {
        for limb_idx in 0..self.stride {
            self.data
                .swap(i * self.stride + limb_idx, j * self.stride + limb_idx);
        }
    }

    /// Forward Gaussian elimination visiting columns in the order yielded by `permutation`.
    ///
    /// For each column, the first row at or below the current pivot position with a non-zero
    /// entry is swapped into place and scaled so the pivot is 1, and the rows below are cleared
    /// in that column. Rows above the pivot are not touched, so the result is in echelon form
    /// but not necessarily reduced. `self.pivots` is not updated.
    ///
    /// Returns the columns in which a pivot was found, in the order they were found.
    pub fn find_pivots_permutation<T: Iterator<Item = usize>>(
        &mut self,
        permutation: T,
    ) -> Vec<usize> {
        let p = self.prime();
        let rows = self.rows();
        let mut pivots = Vec::with_capacity(rows);

        if rows == 0 {
            return pivots;
        }

        // Index of the row that will receive the next pivot.
        let mut pivot: usize = 0;
        for pivot_column in permutation {
            // `rows` doubles as "not found".
            let mut pivot_row = rows;
            for i in pivot..rows {
                if self.row(i).entry(pivot_column) != 0 {
                    pivot_row = i;
                    break;
                }
            }
            if pivot_row == rows {
                continue;
            }

            pivots.push(pivot_column);

            self.swap_rows(pivot, pivot_row);
            let c = self.row(pivot).entry(pivot_column);
            let c_inv = prime::inverse(p, c);
            self.row_mut(pivot).scale(c_inv);
            // Rows `pivot + 1..=pivot_row` already have a zero in this column: `pivot_row` was
            // the first non-zero row, and the row swapped into `pivot_row` came from `pivot`.
            for i in pivot_row + 1..rows {
                // SAFETY: `i > pivot_row >= pivot`, so the rows differ, and both are < `rows`.
                unsafe { self.row_op_naive(i, pivot, pivot_column, p) };
            }
            pivot += 1;
        }
        pivots
    }

    /// Row reduces the matrix in place to reduced row echelon form and returns its rank.
    ///
    /// Afterwards the non-zero rows are packed at the top, ordered by pivot column, the
    /// remaining rows are zero, and `self.pivots[c]` is the row holding the pivot of column `c`
    /// (or `-1` if column `c` has no pivot).
    ///
    /// Over `F_2` row operations are batched through an `M4riTable` (Method of Four Russians)
    /// of at most `k = min(8, log2(1 + rows) * 3 / 4)` pivot rows; over odd primes each pivot
    /// is scaled to 1 and cleared from every other row directly.
    pub fn row_reduce(&mut self) -> usize {
        let p = self.prime();
        self.initialize_pivots();

        // NOTE(review): rows without a pivot are collected here but the vector is never read.
        let mut empty_rows = Vec::with_capacity(self.rows());

        if p == 2 {
            let k = std::cmp::min(8, crate::prime::log2(1 + self.rows()) * 3 / 4);
            let mut table = M4riTable::new(k, self.columns());

            for i in 0..self.rows() {
                // Presumably reduces row `i` against the pivot rows currently in the table.
                table.reduce_naive(&mut *self, i);

                if let Some((c, _)) = self.row(i).first_nonzero() {
                    self.pivots[c] = i as isize;
                    // Clear column `c` from the pivot rows already in the table.
                    for &row in table.rows() {
                        // SAFETY: table rows are earlier pivot rows, so `row != i`, and both
                        // are below `self.rows()`.
                        unsafe {
                            self.row_op(row, i, c, p);
                        }
                    }
                    table.add(c, i);

                    if table.len() == k {
                        // Table full: apply it to the rows above the batch and to all rows
                        // below `i`, then start a new batch.
                        table.generate(self);
                        for j in 0..table.rows()[0] {
                            table.reduce(self.row_mut(j).limbs_mut());
                        }
                        for j in i + 1..self.rows() {
                            table.reduce(self.row_mut(j).limbs_mut());
                        }
                        table.clear();
                    }
                } else {
                    empty_rows.push(i);
                }
            }
            // Flush a partially filled final batch into the rows above it.
            if !table.is_empty() {
                table.generate(self);
                for j in 0..table.rows()[0] {
                    table.reduce(self.row_mut(j).limbs_mut());
                }
                table.clear();
            }
        } else {
            for i in 0..self.rows() {
                if let Some((c, v)) = self.row(i).first_nonzero() {
                    self.pivots[c] = i as isize;
                    self.row_mut(i).scale(prime::inverse(p, v));
                    for j in 0..self.rows() {
                        if i == j {
                            continue;
                        }
                        // SAFETY: `j != i` and both are below `self.rows()`.
                        unsafe {
                            self.row_op(j, i, c, p);
                        }
                    }
                } else {
                    empty_rows.push(i);
                }
            }
        }

        // Rebuild the buffer with the pivot rows packed at the top in pivot-column order; every
        // other row (including padding) becomes zero.
        let old_len = self.data.len();
        let old_data = std::mem::replace(&mut self.data, aligned_vec::avec![0; old_len]);

        let mut new_row_idx = 0;
        for old_row in self.pivots.iter_mut().filter(|row| **row >= 0) {
            let old_row_idx = *old_row as usize;
            let old_limb_range = row_to_limb_range(old_row_idx, self.stride);
            let new_limb_range = row_to_limb_range(new_row_idx, self.stride);
            self.data[new_limb_range].copy_from_slice(&old_data[old_limb_range]);
            *old_row = new_row_idx as isize;
            new_row_idx += 1;
        }

        new_row_idx
    }
}
757
758impl Matrix {
759 pub fn find_first_row_in_block(&self, first_column: usize) -> usize {
762 self.pivots[first_column..]
763 .iter()
764 .find(|&&x| x >= 0)
765 .map(|x| *x as usize)
766 .unwrap_or_else(|| self.rows())
767 }
768
769 pub fn compute_quasi_inverse(
796 &self,
797 last_target_col: usize,
798 first_source_col: usize,
799 ) -> QuasiInverse {
800 let p = self.prime();
801 let columns = self.columns();
802 let source_columns = columns - first_source_col;
803 let first_kernel_row = self.find_first_row_in_block(first_source_col);
804 let mut preimage = Self::new(p, first_kernel_row, source_columns);
805 for i in 0..first_kernel_row {
806 preimage
807 .row_mut(i)
808 .assign(self.row(i).restrict(first_source_col, columns));
809 }
810 QuasiInverse::new(Some(self.pivots()[..last_target_col].to_vec()), preimage)
811 }
812
813 pub fn compute_image(&self, last_target_col: usize, first_source_col: usize) -> Subspace {
842 let p = self.prime();
843 let first_kernel_row = self.find_first_row_in_block(first_source_col);
844 let mut image_matrix = Self::new(p, first_kernel_row, last_target_col);
845 for i in 0..first_kernel_row {
846 image_matrix
847 .row_mut(i)
848 .assign(self.row(i).restrict(0, last_target_col));
849 }
850 image_matrix.pivots = self.pivots()[..last_target_col].to_vec();
851 Subspace::from_matrix(image_matrix)
852 }
853
854 pub fn compute_kernel(&self, first_source_column: usize) -> Subspace {
890 let p = self.prime();
891 let rows = self.rows();
892 let columns = self.columns();
893 let source_dimension = columns - first_source_column;
894 let column_to_pivot_row = self.pivots();
895
896 let first_kernel_row = self.find_first_row_in_block(first_source_column);
898 let kernel_dimension = rows - first_kernel_row;
900 let mut kernel = Self::new(p, kernel_dimension, source_dimension);
901 kernel.initialize_pivots();
902
903 if kernel_dimension == 0 {
904 return Subspace::from_matrix(kernel);
905 }
906
907 for i in 0..source_dimension {
909 kernel.pivots_mut()[i] =
911 column_to_pivot_row[i + first_source_column] - first_kernel_row as isize;
912 }
913 for (i, mut row) in kernel.iter_mut().enumerate() {
915 row.assign(
916 self.row(first_kernel_row + i)
917 .restrict(first_source_column, first_source_column + source_dimension),
918 );
919 }
920 Subspace::from_matrix(kernel)
921 }
922
923 pub fn extend_column_dimension(&mut self, columns: usize) {
924 if columns > self.columns {
925 self.extend_column_capacity(columns);
926 self.columns = columns;
927 self.pivots.resize(columns, -1);
928 }
929 }
930
931 pub fn extend_column_capacity(&mut self, columns: usize) {
932 let new_stride = self.fp.number(columns);
933 if new_stride > self.stride {
934 self.data.resize(new_stride * self.physical_rows, 0);
935 for row_idx in (0..self.physical_rows).rev() {
937 let old_row_start = row_idx * self.stride;
938 let new_row_start = row_idx * new_stride;
939 let new_row_zero_part =
940 row_idx * new_stride + self.stride..(row_idx + 1) * new_stride;
941 unsafe {
943 std::ptr::copy(
944 &raw const self.data[old_row_start],
945 &raw mut self.data[new_row_start],
946 self.stride,
947 );
948 }
949 for limb in &mut self.data[new_row_zero_part] {
950 *limb = 0;
951 }
952 }
953 self.stride = new_stride;
954 }
955 }
956
957 pub fn add_row(&mut self) -> FpSliceMut<'_> {
959 if self.rows + 1 > self.physical_rows {
961 let new_physical_rows = get_physical_rows(self.prime(), self.rows + 1);
962 self.data.resize(new_physical_rows * self.stride, 0);
963 self.physical_rows = new_physical_rows;
964 }
965 self.rows += 1;
966 self.row_mut(self.rows - 1)
967 }
968
969 pub fn extend_to_surjection(
987 &mut self,
988 start_column: usize,
989 end_column: usize,
990 extra_column_capacity: usize,
991 ) -> Vec<usize> {
992 let mut added_pivots = Vec::new();
993 self.extend_column_capacity(self.columns + extra_column_capacity);
994
995 for (i, &pivot) in self.pivots.clone()[start_column..end_column]
996 .iter()
997 .enumerate()
998 {
999 if pivot >= 0 {
1000 continue;
1001 }
1002 let mut new_row = self.add_row();
1003 new_row.set_entry(i, 1);
1004 added_pivots.push(i);
1005 }
1006 added_pivots
1007 }
1008
1009 pub fn extend_image(
1020 &mut self,
1021 start_column: usize,
1022 end_column: usize,
1023 desired_image: &Subspace,
1024 extra_column_capacity: usize,
1025 ) -> Vec<usize> {
1026 let mut added_pivots = Vec::new();
1027 let desired_pivots = desired_image.pivots();
1028 let early_end_column = std::cmp::min(end_column, desired_pivots.len() + start_column);
1029
1030 self.extend_column_capacity(self.columns + extra_column_capacity);
1031
1032 for i in start_column..early_end_column {
1033 debug_assert!(
1034 self.pivots()[i] < 0 || desired_pivots[i - start_column] >= 0,
1035 "current_pivots : {:?}, desired_pivots : {:?}",
1036 self.pivots(),
1037 desired_pivots
1038 );
1039 if self.pivots()[i] >= 0 || desired_pivots[i - start_column] < 0 {
1040 continue;
1041 }
1042 let kernel_vector_row = desired_pivots[i - start_column] as usize;
1044 let new_image = desired_image.row(kernel_vector_row);
1045
1046 let mut new_row = self.add_row();
1047 new_row
1048 .slice_mut(
1049 start_column,
1050 start_column + desired_image.ambient_dimension(),
1051 )
1052 .assign(new_image);
1053
1054 added_pivots.push(i);
1055 }
1056 added_pivots
1057 }
1058
1059 pub fn apply(&self, mut result: FpSliceMut, coeff: u32, input: FpSlice) {
1077 debug_assert_eq!(input.len(), self.rows());
1078 for i in 0..input.len() {
1079 result.add(self.row(i), (coeff * input.entry(i)) % self.prime());
1080 }
1081 }
1082
1083 pub fn trim(&mut self, row_start: usize, row_end: usize, col_start: usize) {
1084 let mut new = Self::new(self.prime(), row_end - row_start, self.columns - col_start);
1085 for (i, mut row) in new.iter_mut().enumerate() {
1086 row.assign(self.row(row_start + i).restrict(col_start, self.columns));
1087 }
1088 std::mem::swap(self, &mut new);
1089 }
1090
1091 pub fn rotate_down(&mut self, range: Range<usize>, shift: usize) {
1093 let limb_range = row_range_to_limb_range(&range, self.stride);
1094 self.data[limb_range].rotate_right(shift * self.stride)
1095 }
1096}
1097
1098impl std::ops::MulAssign<u32> for Matrix {
1099 fn mul_assign(&mut self, rhs: u32) {
1100 #[allow(clippy::suspicious_op_assign_impl)]
1101 let rhs = rhs % self.prime();
1102 for mut row in self.iter_mut() {
1103 row.scale(rhs);
1104 }
1105 }
1106}
1107
1108impl std::ops::AddAssign<&Self> for Matrix {
1109 fn add_assign(&mut self, rhs: &Self) {
1110 assert_eq!(self.prime(), rhs.prime());
1111 assert_eq!(self.columns(), rhs.columns());
1112 assert_eq!(self.rows(), rhs.rows());
1113
1114 for (i, mut row) in self.iter_mut().enumerate() {
1115 row.add(rhs.row(i), 1);
1116 }
1117 }
1118}
1119
#[cfg(feature = "proptest")]
pub mod arbitrary {
    use proptest::prelude::*;

    use super::*;
    use crate::{
        field::Fp,
        vector::{FqVector, arbitrary::FqVectorArbParams},
    };

    /// Largest row count drawn by the default parameters.
    pub const MAX_ROWS: usize = 100;
    /// Largest column count drawn by the default parameters.
    pub const MAX_COLUMNS: usize = 100;

    /// Parameters controlling random [`Matrix`] generation.
    #[derive(Debug, Clone)]
    pub struct MatrixArbParams {
        /// Fixed prime, or `None` to draw one at random.
        pub p: Option<ValidPrime>,
        /// Strategy for the number of rows.
        pub rows: BoxedStrategy<usize>,
        /// Strategy for the number of columns.
        pub columns: BoxedStrategy<usize>,
    }

    impl Default for MatrixArbParams {
        fn default() -> Self {
            let rows = (1..=MAX_ROWS).boxed();
            let columns = (1..=MAX_COLUMNS).boxed();
            Self {
                p: None,
                rows,
                columns,
            }
        }
    }

    impl Arbitrary for Matrix {
        type Parameters = MatrixArbParams;
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
            let MatrixArbParams { p, rows, columns } = args;
            let prime_strategy = match p {
                Some(fixed) => Just(fixed).boxed(),
                None => any::<ValidPrime>().boxed(),
            };

            (prime_strategy, rows, columns)
                .prop_flat_map(|(p, num_rows, num_columns)| {
                    let vector_strategy = any_with::<FqVector<_>>(FqVectorArbParams {
                        fq: Some(Fp::new(p)),
                        len: Just(num_columns).boxed(),
                    })
                    .prop_map(|v| -> FpVector { v.into() });

                    let vectors = proptest::collection::vec(vector_strategy, num_rows);
                    (Just(p), vectors, Just(num_columns))
                })
                .prop_map(|(p, vectors, num_columns)| Self::from_rows(p, vectors, num_columns))
                .boxed()
        }
    }

    impl Matrix {
        /// Strategy producing random matrices in reduced row echelon form.
        ///
        /// A random matrix is drawn, then a random increasing list of pivot columns (at most
        /// `min(rows, columns)` of them) is chosen and the rows are rewritten to match it.
        pub fn arbitrary_rref_with(args: MatrixArbParams) -> impl Strategy<Value = Self> {
            Self::arbitrary_with(args)
                .prop_flat_map(|m| {
                    let all_columns: Vec<usize> = (0..m.columns()).collect();
                    let max_pivots = m.rows().min(m.columns());
                    let pivot_cols = proptest::sample::subsequence(all_columns, 0..=max_pivots);
                    (Just(m), pivot_cols)
                })
                .prop_map(|(mut m, pivot_cols)| {
                    // Row `i` gets zeros before its pivot and a 1 at it; rows past the last
                    // pivot become zero.
                    for (row_idx, mut row) in m.iter_mut().enumerate() {
                        match pivot_cols.get(row_idx) {
                            Some(&col_idx) => {
                                row.slice_mut(0, col_idx).set_to_zero();
                                row.set_entry(col_idx, 1);
                            }
                            None => {
                                row.set_to_zero();
                            }
                        }
                    }
                    // Clear each pivot column in all rows above its pivot.
                    for (row_idx, &col_idx) in pivot_cols.iter().enumerate() {
                        m.iter_mut().take(row_idx).for_each(|mut row| {
                            row.set_entry(col_idx, 0);
                        });
                    }
                    m
                })
        }

        /// [`Matrix::arbitrary_rref_with`] with default parameters.
        pub fn arbitrary_rref() -> impl Strategy<Value = Self> {
            Self::arbitrary_rref_with(MatrixArbParams::default())
        }
    }
}
1233
/// Limb indices occupied by row `row` in a buffer with `stride` limbs per row.
fn row_to_limb_range(row: usize, stride: usize) -> Range<usize> {
    row * stride..(row + 1) * stride
}
1237
/// Limb indices occupied by the rows in `row_range` in a buffer with `stride` limbs per row.
fn row_range_to_limb_range(row_range: &Range<usize>, stride: usize) -> Range<usize> {
    let Range { start, end } = row_range;
    (start * stride)..(end * stride)
}
1241
1242fn get_physical_rows(p: ValidPrime, rows: usize) -> usize {
1243 if p == 2 && rows >= 32 {
1244 rows.next_multiple_of(64)
1247 } else {
1248 rows
1251 }
1252}
1253
/// A [`Matrix`] whose columns are split into `N` consecutive segments.
///
/// Segment `i` spans the columns `start[i]..end[i]`. Each segment after the first starts at the
/// previous segment's start plus that segment's width padded by `FpVector::padded_len`, so
/// there may be unused columns between `end[i]` and `start[i + 1]`.
#[derive(Clone)]
pub struct AugmentedMatrix<const N: usize> {
    /// One past the last column of each segment.
    pub end: [usize; N],
    /// First column of each segment.
    pub start: [usize; N],
    /// The underlying matrix (also reachable through `Deref`/`DerefMut`).
    pub inner: Matrix,
}
1273
1274impl<const N: usize> AugmentedMatrix<N> {
1275 pub fn new(p: ValidPrime, rows: usize, columns: [usize; N]) -> Self {
1276 let mut start = [0; N];
1277 let mut end = [0; N];
1278 for i in 1..N {
1279 start[i] = start[i - 1] + FpVector::padded_len(p, columns[i - 1]);
1280 }
1281 for i in 0..N {
1282 end[i] = start[i] + columns[i];
1283 }
1284
1285 Self {
1286 inner: Matrix::new(p, rows, end[N - 1]),
1287 start,
1288 end,
1289 }
1290 }
1291
1292 pub fn new_with_capacity(
1293 p: ValidPrime,
1294 rows: usize,
1295 columns: &[usize],
1296 row_capacity: usize,
1297 extra_column_capacity: usize,
1298 ) -> Self {
1299 let mut start = [0; N];
1300 let mut end = [0; N];
1301 for i in 1..N {
1302 start[i] = start[i - 1] + FpVector::padded_len(p, columns[i - 1]);
1303 }
1304 for i in 0..N {
1305 end[i] = start[i] + columns[i];
1306 }
1307
1308 Self {
1309 inner: Matrix::new_with_capacity(
1310 p,
1311 rows,
1312 end[N - 1],
1313 row_capacity,
1314 end[N - 1] + extra_column_capacity,
1315 ),
1316 start,
1317 end,
1318 }
1319 }
1320
1321 pub fn segment(&mut self, start: usize, end: usize) -> MatrixSliceMut<'_> {
1322 let rows = self.inner.rows();
1323 let start = self.start[start];
1324 let end = self.end[end];
1325 self.slice_mut(0, rows, start, end)
1326 }
1327
1328 pub fn row_segment_mut(&mut self, i: usize, start: usize, end: usize) -> FpSliceMut<'_> {
1329 let start_idx = self.start[start];
1330 let end_idx = self.end[end];
1331 let limb_range = row_to_limb_range(i, self.stride);
1332 FpSliceMut::new(self.prime(), &mut self.data[limb_range], start_idx, end_idx)
1333 }
1334
1335 pub fn row_segment(&self, i: usize, start: usize, end: usize) -> FpSlice<'_> {
1336 let start_idx = self.start[start];
1337 let end_idx = self.end[end];
1338 self.row(i).restrict(start_idx, end_idx)
1339 }
1340
1341 pub fn into_matrix(self) -> Matrix {
1342 self.inner
1343 }
1344
1345 pub fn into_tail_segment(
1346 mut self,
1347 row_start: usize,
1348 row_end: usize,
1349 segment_start: usize,
1350 ) -> Matrix {
1351 self.inner
1352 .trim(row_start, row_end, self.start[segment_start]);
1353 self.inner
1354 }
1355
1356 pub fn compute_kernel(&self) -> Subspace {
1357 self.inner.compute_kernel(self.start[N - 1])
1358 }
1359
1360 pub fn extend_column_dimension(&mut self, columns: usize) {
1361 if columns > self.columns {
1362 self.end[N - 1] += columns - self.columns;
1363 self.inner.extend_column_dimension(columns);
1364 }
1365 }
1366}
1367
1368impl<const N: usize> std::ops::Deref for AugmentedMatrix<N> {
1369 type Target = Matrix;
1370
1371 fn deref(&self) -> &Matrix {
1372 &self.inner
1373 }
1374}
1375
1376impl<const N: usize> std::ops::DerefMut for AugmentedMatrix<N> {
1377 fn deref_mut(&mut self) -> &mut Matrix {
1378 &mut self.inner
1379 }
1380}
1381
1382impl AugmentedMatrix<2> {
1383 pub fn compute_image(&self) -> Subspace {
1384 self.inner.compute_image(self.end[0], self.start[1])
1385 }
1386
1387 pub fn compute_quasi_inverse(&self) -> QuasiInverse {
1388 self.inner.compute_quasi_inverse(self.end[0], self.start[1])
1389 }
1390}
1391
1392impl AugmentedMatrix<3> {
1393 pub fn drop_first(mut self) -> AugmentedMatrix<2> {
1394 let offset = self.start[1];
1395 for mut row in self.inner.iter_mut() {
1396 row.shl_assign(offset);
1397 }
1398 self.inner.columns -= offset;
1399 AugmentedMatrix::<2> {
1400 inner: self.inner,
1401 start: [self.start[1] - offset, self.start[2] - offset],
1402 end: [self.end[1] - offset, self.end[2] - offset],
1403 }
1404 }
1405
1406 pub fn compute_quasi_inverses(mut self) -> (QuasiInverse, QuasiInverse) {
1413 let p = self.prime();
1414 let stride = self.stride;
1415
1416 let source_columns = self.end[2] - self.start[2];
1417
1418 if self.end[0] == 0 {
1419 let cc_qi = QuasiInverse::new(None, Matrix::new(p, 0, source_columns));
1420 let res_qi = Matrix::compute_quasi_inverse(&self, self.end[1], self.start[2]);
1421 (cc_qi, res_qi)
1422 } else {
1423 let mut cc_preimage = Matrix::new(p, self.end[0], source_columns);
1424 for i in 0..self.end[0] {
1425 cc_preimage
1426 .row_mut(i)
1427 .assign(self.row(i).restrict(self.start[2], self.end[2]));
1428 }
1429 let cm_qi = QuasiInverse::new(None, cc_preimage);
1430
1431 let first_kernel_row = self.find_first_row_in_block(self.start[2]);
1432 self.rows = first_kernel_row;
1433 self.data.truncate(first_kernel_row * stride);
1434
1435 let mut res_matrix = self.drop_first();
1436 res_matrix.row_reduce();
1437 let res_qi = res_matrix.compute_quasi_inverse();
1438
1439 (cm_qi, res_qi)
1440 }
1441 }
1442}
1443
/// A mutable view of a rectangular block of a [`Matrix`]: `rows` consecutive rows restricted to
/// the columns `col_start..col_end`.
///
/// `data` holds the limbs of exactly those rows, `stride` limbs per row, and column indices
/// are still those of the parent matrix.
pub struct MatrixSliceMut<'a> {
    /// The field the entries live in.
    fp: Fp<ValidPrime>,
    /// Limbs of the rows in view, row-major.
    data: &'a mut [Limb],
    /// Number of rows in view.
    rows: usize,
    /// First column in view.
    col_start: usize,
    /// One past the last column in view.
    col_end: usize,
    /// Limbs per row in `data`.
    stride: usize,
}
1452
1453impl<'a> MatrixSliceMut<'a> {
1454 pub fn prime(&self) -> ValidPrime {
1455 self.fp.characteristic()
1456 }
1457
1458 pub fn columns(&self) -> usize {
1459 self.col_end - self.col_start
1460 }
1461
1462 pub fn rows(&self) -> usize {
1463 self.rows
1464 }
1465
1466 pub fn row_slice<'b: 'a>(&'b mut self, row_start: usize, row_end: usize) -> MatrixSliceMut<'b> {
1467 let limb_range = row_range_to_limb_range(&(row_start..row_end), self.stride);
1468 Self {
1469 fp: self.fp,
1470 data: &mut self.data[limb_range],
1471 rows: row_end - row_start,
1472 col_start: self.col_start,
1473 col_end: self.col_end,
1474 stride: self.stride,
1475 }
1476 }
1477
1478 pub fn iter(&self) -> impl Iterator<Item = FpSlice<'_>> + '_ {
1479 let start = self.col_start;
1480 let end = self.col_end;
1481 (0..self.rows).map(move |row_idx| {
1482 let limb_range = row_to_limb_range(row_idx, self.stride);
1483 FpSlice::new(self.prime(), &self.data[limb_range], start, end)
1484 })
1485 }
1486
1487 pub fn iter_mut(&mut self) -> impl Iterator<Item = FpSliceMut<'_>> + '_ {
1488 let p = self.prime();
1489 let start = self.col_start;
1490 let end = self.col_end;
1491
1492 if self.stride == 0 {
1493 Either::Left(std::iter::empty())
1494 } else {
1495 let rows = self
1496 .data
1497 .chunks_mut(self.stride)
1498 .map(move |row| FpSliceMut::new(p, row, start, end));
1499 Either::Right(rows)
1500 }
1501 }
1502
1503 pub fn maybe_par_iter_mut(
1504 &mut self,
1505 ) -> impl MaybeIndexedParallelIterator<Item = FpSliceMut<'_>> + '_ {
1506 let p = self.prime();
1507 let start = self.col_start;
1508 let end = self.col_end;
1509
1510 if self.stride == 0 {
1511 Either::Left(maybe_rayon::empty())
1512 } else {
1513 let rows = self
1514 .data
1515 .maybe_par_chunks_mut(self.stride)
1516 .map(move |row| FpSliceMut::new(p, row, start, end));
1517 Either::Right(rows)
1518 }
1519 }
1520
1521 pub fn row(&mut self, row: usize) -> FpSlice<'_> {
1522 let limb_range = row_to_limb_range(row, self.stride);
1523 FpSlice::new(
1524 self.prime(),
1525 &self.data[limb_range],
1526 self.col_start,
1527 self.col_end,
1528 )
1529 }
1530
1531 pub fn row_mut(&mut self, row: usize) -> FpSliceMut<'_> {
1532 let limb_range = row_to_limb_range(row, self.stride);
1533 FpSliceMut::new(
1534 self.prime(),
1535 &mut self.data[limb_range],
1536 self.col_start,
1537 self.col_end,
1538 )
1539 }
1540
1541 pub fn add_identity(&mut self) {
1542 debug_assert_eq!(self.rows(), self.columns());
1543 for row_idx in 0..self.rows {
1544 self.row_mut(row_idx).add_basis_element(row_idx, 1);
1545 }
1546 }
1547
1548 pub fn add_masked(&mut self, other: &Matrix, mask: &[usize]) {
1550 assert_eq!(self.rows(), other.rows());
1551
1552 for (mut l, r) in self.iter_mut().zip(other.iter()) {
1553 l.add_masked(r, 1, mask);
1554 }
1555 }
1556}
1557
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Each segment of an `AugmentedMatrix` must be exactly as wide as requested, whatever
    /// padding is inserted between segments.
    #[test]
    fn test_augmented_matrix() {
        for cols in [[1, 0, 5], [4, 6, 2], [129, 4, 64], [64, 64, 102]] {
            test_augmented_matrix_inner(cols);
        }
    }

    fn test_augmented_matrix_inner(cols: [usize; 3]) {
        let mut aug = AugmentedMatrix::<3>::new(ValidPrime::new(2), 3, cols);
        for (segment, &width) in cols.iter().enumerate() {
            assert_eq!(aug.segment(segment, segment).columns(), width);
        }
    }

    /// Row reduction over F_2 of a fixed 7×15 matrix against a known RREF and pivot table.
    #[test]
    fn test_row_reduce_2() {
        let p = ValidPrime::new(2);
        let input = [
            vec![0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1],
            vec![0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            vec![1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            vec![1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
            vec![0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1],
        ];
        let expected_rows = [
            [1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1],
            [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
        ];
        let expected_pivots = [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 6, -1, -1];

        let mut m = Matrix::from_vec(p, &input);
        println!("{m}");
        m.row_reduce();
        for (i, expected) in expected_rows.iter().enumerate() {
            assert_eq!(m.row(i).iter().collect::<Vec<_>>(), *expected);
        }
        assert_eq!(m.pivots(), &expected_pivots)
    }

    proptest! {
        // Row reducing a matrix that is already in RREF must leave it unchanged.
        #[test]
        fn test_arbitrary_rref(m in Matrix::arbitrary_rref()) {
            let mut reduced = m.clone();
            reduced.row_reduce();
            prop_assert_eq!(m, reduced);
        }
    }
}