fp/matrix/
matrix_inner.rs

1use std::{fmt, io, ops::Range};
2
3use aligned_vec::AVec;
4use either::Either;
5use itertools::Itertools;
6use maybe_rayon::prelude::*;
7use serde::{Deserialize, Deserializer, Serialize};
8
9use super::{QuasiInverse, Subspace};
10use crate::{
11    field::{Field, Fp, field_internal::FieldInternal},
12    limb::Limb,
13    matrix::m4ri::M4riTable,
14    prime::{self, ValidPrime},
15    vector::{FpSlice, FpSliceMut, FpVector},
16};
17
/// A matrix! In particular, a matrix with values in F_p.
///
/// The way we store matrices means it is easier to perform row operations than column operations,
/// and the way we use matrices means we want our matrices to act on the right. Hence we think of
/// vectors as row vectors.
///
/// Entries are packed into limbs and all rows live in one flat buffer: row `i` starts at limb
/// `i * stride` of `data`.
#[derive(Clone, Serialize)]
pub struct Matrix {
    /// The field F_p the entries belong to.
    fp: Fp<ValidPrime>,
    /// The logical number of rows.
    rows: usize,
    /// The number of rows allocated in `data`. This is never smaller than `rows`.
    physical_rows: usize,
    /// The logical number of columns.
    columns: usize,
    /// The packed entries, `physical_rows * stride` limbs in total.
    data: AVec<Limb>,
    /// The number of limbs between the starts of two consecutive rows.
    stride: usize,
    /// The pivot columns of the matrix. `pivots[n]` is `k` if column `n` is the `k`th pivot
    /// column, and a negative number otherwise. Said negative number is often -1 but this is not
    /// guaranteed.
    pub(crate) pivots: Vec<isize>,
}
36
37// `Deserialize` is implemented manually rather than derived so that we can validate `Matrix`'s
38// internal invariants. Without these checks, malformed input could build a `Matrix` whose
39// accessors (`row`, `to_bytes`, ...) later panic on bounds-checked slice indexing into `data`,
40// escaping the `Deserialize` boundary instead of surfacing as a normal serde error.
41impl<'de> Deserialize<'de> for Matrix {
42    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43    where
44        D: Deserializer<'de>,
45    {
46        use serde::de::Error;
47
48        #[derive(Deserialize)]
49        struct Raw {
50            fp: Fp<ValidPrime>,
51            rows: usize,
52            physical_rows: usize,
53            columns: usize,
54            data: AVec<Limb>,
55            stride: usize,
56            pivots: Vec<isize>,
57        }
58
59        let raw = Raw::deserialize(deserializer)?;
60        let expected_stride = raw.fp.number(raw.columns);
61        if raw.stride != expected_stride {
62            return Err(D::Error::custom(format!(
63                "Matrix stride {} does not match expected {} for columns={}",
64                raw.stride, expected_stride, raw.columns,
65            )));
66        }
67        if raw.physical_rows < raw.rows {
68            return Err(D::Error::custom(format!(
69                "Matrix physical_rows {} less than rows {}",
70                raw.physical_rows, raw.rows,
71            )));
72        }
73        if raw.data.len() != raw.physical_rows * raw.stride {
74            return Err(D::Error::custom(format!(
75                "Matrix data length {} does not match physical_rows*stride = {}*{} = {}",
76                raw.data.len(),
77                raw.physical_rows,
78                raw.stride,
79                raw.physical_rows * raw.stride,
80            )));
81        }
82        // `pivots` is either empty (matrix not yet row-reduced) or has one entry per column.
83        if !raw.pivots.is_empty() && raw.pivots.len() != raw.columns {
84            return Err(D::Error::custom(format!(
85                "Matrix pivots length {} must be 0 or columns = {}",
86                raw.pivots.len(),
87                raw.columns,
88            )));
89        }
90        Ok(Self {
91            fp: raw.fp,
92            rows: raw.rows,
93            physical_rows: raw.physical_rows,
94            columns: raw.columns,
95            data: raw.data,
96            stride: raw.stride,
97            pivots: raw.pivots,
98        })
99    }
100}
101
102impl PartialEq for Matrix {
103    fn eq(&self, other: &Self) -> bool {
104        self.data == other.data
105    }
106}
107
108impl Eq for Matrix {}
109
110impl Matrix {
    /// Produces a new matrix over F_p with the specified number of rows and columns, initialized
    /// to the 0 matrix.
    ///
    /// This is [`Matrix::new_with_capacity`] with the capacities equal to the dimensions.
    pub fn new(p: ValidPrime, rows: usize, columns: usize) -> Self {
        Self::new_with_capacity(p, rows, columns, rows, columns)
    }
116
117    pub fn new_with_capacity(
118        p: ValidPrime,
119        rows: usize,
120        columns: usize,
121        rows_capacity: usize,
122        columns_capacity: usize,
123    ) -> Self {
124        let fp = Fp::new(p);
125        let stride = fp.number(columns_capacity);
126        let physical_rows = get_physical_rows(p, rows_capacity);
127        let mut data = AVec::with_capacity(0, physical_rows * stride);
128        data.resize(physical_rows * stride, 0);
129
130        Self {
131            fp,
132            rows,
133            physical_rows,
134            columns,
135            data,
136            stride,
137            pivots: Vec::new(),
138        }
139    }
140
141    pub fn from_data(p: ValidPrime, rows: usize, columns: usize, mut data: Vec<Limb>) -> Self {
142        let fp = Fp::new(p);
143        let stride = fp.number(columns);
144        let physical_rows = get_physical_rows(p, rows);
145        data.resize(physical_rows * stride, 0);
146        Self {
147            fp,
148            rows,
149            physical_rows,
150            columns,
151            data: AVec::from_iter(0, data),
152            stride,
153            pivots: Vec::new(),
154        }
155    }
156
    /// Produces the `dim x dim` identity matrix over F_p.
    pub fn identity(p: ValidPrime, dim: usize) -> Self {
        // Start from the zero matrix and add the identity onto it.
        let mut matrix = Self::new(p, dim, dim);
        matrix.as_slice_mut().add_identity();
        matrix
    }
162
163    pub fn from_bytes(
164        p: ValidPrime,
165        rows: usize,
166        columns: usize,
167        buffer: &mut impl io::Read,
168    ) -> io::Result<Self> {
169        let fp = Fp::new(p);
170        let stride = fp.number(columns);
171        let physical_rows = get_physical_rows(p, rows);
172        let mut data: AVec<Limb> = aligned_vec::avec![0; stride * physical_rows];
173        for row_idx in 0..rows {
174            let limb_range = row_to_limb_range(row_idx, stride);
175            crate::limb::from_bytes(&mut data[limb_range], buffer)?;
176        }
177        Ok(Self {
178            fp,
179            rows,
180            physical_rows,
181            columns,
182            data,
183            stride,
184            pivots: Vec::new(),
185        })
186    }
187
188    pub fn to_bytes(&self, data: &mut impl io::Write) -> io::Result<()> {
189        let limbs_per_row = self.fp.number(self.columns);
190        for row_idx in 0..self.rows() {
191            let row_range = row_idx * self.stride..row_idx * self.stride + limbs_per_row;
192            crate::limb::to_bytes(&self.data[row_range], data)?;
193        }
194        Ok(())
195    }
196
197    /// Read a vector of `isize`
198    pub(crate) fn write_pivot(v: &[isize], buffer: &mut impl io::Write) -> io::Result<()> {
199        if cfg!(all(target_endian = "little", target_pointer_width = "64")) {
200            let buf: &[u8] =
201                unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 8) };
202            buffer.write_all(buf)
203        } else {
204            use byteorder::{LittleEndian, WriteBytesExt};
205            for &i in v {
206                buffer.write_i64::<LittleEndian>(i as i64)?;
207            }
208            Ok(())
209        }
210    }
211
212    /// Read a vector of `isize` of length `dim`.
213    pub(crate) fn read_pivot(dim: usize, data: &mut impl io::Read) -> io::Result<Vec<isize>> {
214        if cfg!(all(target_endian = "little", target_pointer_width = "64")) {
215            let mut image = vec![0; dim];
216            let buf: &mut [u8] =
217                unsafe { std::slice::from_raw_parts_mut(image.as_mut_ptr() as *mut u8, dim * 8) };
218            data.read_exact(buf)?;
219            Ok(image)
220        } else {
221            use byteorder::{LittleEndian, ReadBytesExt};
222            let mut image = Vec::with_capacity(dim);
223            for _ in 0..dim {
224                image.push(data.read_i64::<LittleEndian>()? as isize);
225            }
226            Ok(image)
227        }
228    }
229}
230
231impl Matrix {
    /// Gets the characteristic of the field the matrix is defined over.
    pub fn prime(&self) -> ValidPrime {
        self.fp.characteristic()
    }

    /// Gets the number of rows in the matrix.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Gets the physical number of rows allocated (for BLAS operations).
    pub(crate) fn physical_rows(&self) -> usize {
        self.physical_rows
    }

    /// Gets the number of columns in the matrix.
    pub fn columns(&self) -> usize {
        self.columns
    }

    /// Gets the number of limbs between the starts of two consecutive rows.
    pub(crate) fn stride(&self) -> usize {
        self.stride
    }

    /// Gets the whole limb buffer, including any allocated rows beyond the logical ones.
    pub(crate) fn data(&self) -> &[Limb] {
        &self.data
    }

    /// Gets the whole limb buffer mutably, including any allocated rows beyond the logical ones.
    pub(crate) fn data_mut(&mut self) -> &mut [Limb] {
        &mut self.data
    }

    /// Set the pivots to -1 in every entry. This is called by [`Matrix::row_reduce`].
    ///
    /// Afterwards `pivots` has exactly one entry per column.
    pub fn initialize_pivots(&mut self) {
        self.pivots.clear();
        self.pivots.resize(self.columns, -1);
    }

    /// Gets the pivots of the matrix. The slice is empty if the pivots have never been
    /// initialized.
    pub fn pivots(&self) -> &[isize] {
        &self.pivots
    }

    /// Gets the pivots of the matrix mutably.
    pub fn pivots_mut(&mut self) -> &mut [isize] {
        &mut self.pivots
    }
276
277    /// Produces a Matrix from a vector of FpVectors. We pass in the number of columns because all
278    /// `0 x n` matrices will have an empty Vec, and we have to distinguish between them.
279    pub fn from_rows(p: ValidPrime, input: Vec<FpVector>, columns: usize) -> Self {
280        let fp = Fp::new(p);
281        let rows = input.len();
282        let stride = fp.number(columns);
283        let physical_rows = get_physical_rows(p, rows);
284        let mut data = AVec::with_capacity(0, physical_rows * stride);
285        for row in &input {
286            data.extend_from_slice(row.limbs());
287        }
288        // Pad with zeros for prime 2
289        data.resize(physical_rows * stride, 0);
290        Self {
291            fp,
292            rows,
293            physical_rows,
294            columns,
295            data,
296            stride,
297            pivots: Vec::new(),
298        }
299    }
300
    /// Produces a `1 x n` matrix from a single FpVector. This is a convenience function that
    /// forwards to [`Matrix::from_rows`].
    pub fn from_row(p: ValidPrime, row: FpVector, columns: usize) -> Self {
        Self::from_rows(p, vec![row], columns)
    }
305
306    /// Produces a Matrix from an `&[Vec<u32>]` object. If the number of rows is 0, the number
307    /// of columns is also assumed to be zero.
308    ///
309    /// # Example
310    /// ```
311    /// # use fp::prime::ValidPrime;
312    /// let p = ValidPrime::new(7);
313    /// # use fp::matrix::Matrix;
314    /// let input = [vec![1, 3, 6], vec![0, 3, 4]];
315    ///
316    /// let m = Matrix::from_vec(p, &input);
317    /// ```
318    pub fn from_vec(p: ValidPrime, input: &[Vec<u32>]) -> Self {
319        let fp = Fp::new(p);
320        let rows = input.len();
321        if rows == 0 {
322            return Self::new(p, 0, 0);
323        }
324        let columns = input[0].len();
325        let stride = fp.number(columns);
326        let physical_rows = get_physical_rows(p, rows);
327        let mut data = AVec::with_capacity(0, physical_rows * stride);
328        for row in input {
329            for chunk in row.chunks(fp.entries_per_limb()) {
330                data.push(fp.pack(chunk.iter().map(|x| fp.element(*x))));
331            }
332        }
333        // Pad with zeros for prime 2
334        data.resize(physical_rows * stride, 0);
335        Self {
336            fp,
337            rows,
338            physical_rows,
339            columns,
340            data,
341            stride,
342            pivots: Vec::new(),
343        }
344    }
345
346    /// ```
347    /// # use fp::matrix::Matrix;
348    /// # use fp::prime::TWO;
349    ///
350    /// let matrix_vec = vec![vec![1, 0], vec![0, 1]];
351    ///
352    /// assert_eq!(Matrix::from_vec(TWO, &matrix_vec).to_vec(), matrix_vec);
353    /// ```
354    pub fn to_vec(&self) -> Vec<Vec<u32>> {
355        self.data
356            .iter()
357            .chunks(self.stride)
358            .into_iter()
359            .map(|row| {
360                row.flat_map(|&limb| self.fp.unpack(limb).map(|x| x.val()))
361                    .take(self.columns())
362                    .collect()
363            })
364            .collect()
365    }
366
367    /// Produces a padded augmented matrix from an `&[Vec<u32>]` object (produces [A|0|I] from
368    /// A). Returns the matrix and the first column index of I.
369    ///
370    /// # Example
371    /// ```
372    /// # use fp::prime::ValidPrime;
373    /// let p = ValidPrime::new(7);
374    /// # use fp::matrix::Matrix;
375    /// # use fp::vector::FpVector;
376    /// let input = [vec![1, 3, 6], vec![0, 3, 4]];
377    ///
378    /// let (n, m) = Matrix::augmented_from_vec(p, &input);
379    /// assert!(n >= input[0].len());
380    /// ```
381    pub fn augmented_from_vec(p: ValidPrime, input: &[Vec<u32>]) -> (usize, Self) {
382        let rows = input.len();
383        let cols = input[0].len();
384        let padded_cols = FpVector::padded_len(p, cols);
385        let mut m = Self::new(p, rows, padded_cols + rows);
386
387        for (i, row) in input.iter().enumerate() {
388            for (j, &value) in row.iter().enumerate() {
389                m.row_mut(i).set_entry(j, value);
390            }
391        }
392        m.slice_mut(0, rows, padded_cols, padded_cols + rows)
393            .add_identity();
394        (padded_cols, m)
395    }
396
397    pub fn is_zero(&self) -> bool {
398        self.data.iter().all(|limb| *limb == 0)
399    }
400
401    pub fn set_to_zero(&mut self) {
402        for limb in self.data.iter_mut() {
403            *limb = 0;
404        }
405    }
406
407    pub fn assign(&mut self, other: &Self) {
408        self.data = other.data.clone();
409    }
410
411    pub fn as_slice_mut(&mut self) -> MatrixSliceMut<'_> {
412        self.slice_mut(0, self.rows(), 0, self.columns())
413    }
414
415    pub fn slice_mut(
416        &mut self,
417        row_start: usize,
418        row_end: usize,
419        col_start: usize,
420        col_end: usize,
421    ) -> MatrixSliceMut<'_> {
422        let row_range = row_start..row_end;
423        let limb_range = row_range_to_limb_range(&row_range, self.stride);
424        MatrixSliceMut {
425            fp: self.fp,
426            rows: row_range.len(),
427            data: &mut self.data[limb_range],
428            col_start,
429            col_end,
430            stride: self.stride,
431        }
432    }
433
434    pub fn row(&self, row: usize) -> FpSlice<'_> {
435        let limb_range = row_to_limb_range(row, self.stride);
436        FpSlice::new(self.prime(), &self.data[limb_range], 0, self.columns)
437    }
438
439    pub fn row_mut(&mut self, row: usize) -> FpSliceMut<'_> {
440        let limb_range = row_to_limb_range(row, self.stride);
441        FpSliceMut::new(self.prime(), &mut self.data[limb_range], 0, self.columns)
442    }
443}
444
445impl Matrix {
446    pub fn iter(&self) -> impl Iterator<Item = FpSlice<'_>> {
447        (0..self.rows()).map(move |row_idx| self.row(row_idx))
448    }
449
450    pub fn iter_mut(&mut self) -> impl Iterator<Item = FpSliceMut<'_>> {
451        let p = self.prime();
452        let columns = self.columns;
453        let logical_rows = self.rows;
454
455        if self.stride == 0 {
456            Either::Left(std::iter::empty())
457        } else {
458            let rows = self
459                .data
460                .chunks_mut(self.stride)
461                .take(logical_rows) // Only iterate over logical rows
462                .map(move |row| FpSliceMut::new(p, row, 0, columns));
463            Either::Right(rows)
464        }
465    }
466
467    pub fn maybe_par_iter_mut(
468        &mut self,
469    ) -> impl MaybeIndexedParallelIterator<Item = FpSliceMut<'_>> {
470        let p = self.prime();
471        let columns = self.columns;
472        let logical_rows = self.rows;
473
474        if self.stride == 0 {
475            Either::Left(maybe_rayon::empty())
476        } else {
477            let rows = self
478                .data
479                .maybe_par_chunks_mut(self.stride)
480                .take(logical_rows) // Only iterate over logical rows
481                .map(move |row| FpSliceMut::new(p, row, 0, columns));
482            Either::Right(rows)
483        }
484    }
485}
486
487impl fmt::Display for Matrix {
488    /// # Example
489    /// ```
490    /// # use fp::matrix::Matrix;
491    /// # use fp::prime::ValidPrime;
492    /// let m = Matrix::from_vec(ValidPrime::new(2), &[vec![0, 1, 0], vec![1, 1, 0]]);
493    /// assert_eq!(&format!("{m}"), "[\n    [0, 1, 0],\n    [1, 1, 0]\n]");
494    /// assert_eq!(&format!("{m:#}"), "010\n110");
495    /// ```
496    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
497        if f.alternate() {
498            write!(f, "{:#}", self.iter().format("\n"))
499        } else {
500            let mut it = self.iter();
501            if let Some(x) = it.next() {
502                write!(f, "[\n    {x}")?;
503            } else {
504                return write!(f, "[]");
505            }
506            for x in it {
507                write!(f, ",\n    {x}")?;
508            }
509            write!(f, "\n]")
510        }
511    }
512}
513
514impl fmt::Debug for Matrix {
515    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
516        <Self as fmt::Display>::fmt(self, f)
517    }
518}
519
520impl Matrix {
521    /// A no-nonsense, safe, row operation. Adds `c * self[source]` to `self[target]`.
522    pub fn safe_row_op(&mut self, target: usize, source: usize, c: u32) {
523        assert_ne!(target, source);
524        assert!(source < self.rows());
525        assert!(target < self.rows());
526
527        let (mut target, source) = unsafe { self.split_borrow(target, source) };
528        target.add(source.as_slice(), c)
529    }
530
531    /// Performs a row operation using `pivot_column` as the pivot column. This assumes that the
532    /// source row is zero in all columns before the pivot column.
533    ///
534    /// # Safety
535    /// `target` and `source` must be distinct and less that `vectors.len()`
536    pub unsafe fn row_op(
537        &mut self,
538        target: usize,
539        source: usize,
540        pivot_column: usize,
541        prime: ValidPrime,
542    ) {
543        debug_assert_ne!(target, source);
544        let coef = self.row(target).entry(pivot_column);
545        if coef == 0 {
546            return;
547        }
548        let (mut target, source) = unsafe { self.split_borrow(target, source) };
549        target.add_offset(source.as_slice(), prime - coef, pivot_column);
550    }
551
552    /// A version of [`Matrix::row_op`] without the zero assumption.
553    ///
554    /// # Safety
555    /// `target` and `source` must be distinct and less that `vectors.len()`
556    pub unsafe fn row_op_naive(
557        &mut self,
558        target: usize,
559        source: usize,
560        pivot_column: usize,
561        prime: ValidPrime,
562    ) {
563        debug_assert_ne!(target, source);
564        let coef = self.row(target).entry(pivot_column);
565        if coef == 0 {
566            return;
567        }
568        let (mut target, source) = unsafe { self.split_borrow(target, source) };
569        target.add(source.as_slice(), prime - coef);
570    }
571
572    /// Mutably borrows `x[i]` and `x[j]`.
573    ///
574    /// # Safety
575    /// `i` and `j` must be distinct and not out of bounds.
576    pub(crate) unsafe fn split_borrow(
577        &mut self,
578        i: usize,
579        j: usize,
580    ) -> (FpSliceMut<'_>, FpSliceMut<'_>) {
581        let ptr = self.data.as_mut_ptr();
582        let row1 = unsafe { std::slice::from_raw_parts_mut(ptr.add(i * self.stride), self.stride) };
583        let row2 = unsafe { std::slice::from_raw_parts_mut(ptr.add(j * self.stride), self.stride) };
584        (
585            FpSliceMut::new(self.prime(), row1, 0, self.columns),
586            FpSliceMut::new(self.prime(), row2, 0, self.columns),
587        )
588    }
589
590    pub fn swap_rows(&mut self, i: usize, j: usize) {
591        for limb_idx in 0..self.stride {
592            self.data
593                .swap(i * self.stride + limb_idx, j * self.stride + limb_idx);
594        }
595    }
596
597    /// This is very similar to row_reduce, except we only need to get to row echelon form, not
598    /// *reduced* row echelon form. It also returns the list of pivots instead.
599    pub fn find_pivots_permutation<T: Iterator<Item = usize>>(
600        &mut self,
601        permutation: T,
602    ) -> Vec<usize> {
603        let p = self.prime();
604        let rows = self.rows();
605        let mut pivots = Vec::with_capacity(rows);
606
607        if rows == 0 {
608            return pivots;
609        }
610
611        let mut pivot: usize = 0;
612        for pivot_column in permutation {
613            // Search down column for a nonzero entry.
614            let mut pivot_row = rows;
615            for i in pivot..rows {
616                if self.row(i).entry(pivot_column) != 0 {
617                    pivot_row = i;
618                    break;
619                }
620            }
621            if pivot_row == rows {
622                continue;
623            }
624
625            // Record position of pivot.
626            pivots.push(pivot_column);
627
628            // Pivot_row contains a row with a pivot in current column.
629            // Swap pivot row up.
630            self.swap_rows(pivot, pivot_row);
631            // println!("({}) <==> ({}): \n{}", pivot, pivot_row, self);
632
633            // // Divide pivot row by pivot entry
634            let c = self.row(pivot).entry(pivot_column);
635            let c_inv = prime::inverse(p, c);
636            self.row_mut(pivot).scale(c_inv);
637            // println!("({}) <== {} * ({}): \n{}", pivot, c_inv, pivot, self);
638
639            for i in pivot_row + 1..rows {
640                // Safety requires i != pivot, which follows from i > pivot_row >= pivot. They are
641                // both less than rows by construction
642                unsafe { self.row_op_naive(i, pivot, pivot_column, p) };
643            }
644            pivot += 1;
645        }
646        pivots
647    }
648
649    /// Perform row reduction to reduce it to reduced row echelon form. This modifies the matrix in
650    /// place and records the pivots in `column_to_pivot_row`. The way the pivots are recorded is
651    /// that `column_to_pivot_row[i]` is the row of the pivot if the `i`th row contains a pivot,
652    /// and `-1` otherwise.
653    ///
654    /// # Returns
655    /// The number of non-empty rows in the matrix
656    ///
657    /// # Arguments
658    ///  * `column_to_pivot_row` - A vector for the function to write the pivots into. The length
659    ///    should be at least as long as the number of columns (and the extra entries are ignored).
660    ///
661    /// # Example
662    /// ```
663    /// # use fp::prime::ValidPrime;
664    /// let p = ValidPrime::new(7);
665    /// # use fp::matrix::Matrix;
666    ///
667    /// let input = [vec![1, 3, 6], vec![0, 3, 4]];
668    ///
669    /// let result = [vec![1, 0, 2], vec![0, 1, 6]];
670    ///
671    /// let mut m = Matrix::from_vec(p, &input);
672    /// m.row_reduce();
673    ///
674    /// assert_eq!(m, Matrix::from_vec(p, &result));
675    /// ```
676    pub fn row_reduce(&mut self) -> usize {
677        let p = self.prime();
678        self.initialize_pivots();
679
680        let mut empty_rows = Vec::with_capacity(self.rows());
681
682        if p == 2 {
683            // the m4ri C library uses a similar formula but with a hard cap of 7 instead of 8
684            let k = std::cmp::min(8, crate::prime::log2(1 + self.rows()) * 3 / 4);
685            let mut table = M4riTable::new(k, self.columns());
686
687            for i in 0..self.rows() {
688                table.reduce_naive(&mut *self, i);
689
690                if let Some((c, _)) = self.row(i).first_nonzero() {
691                    self.pivots[c] = i as isize;
692                    for &row in table.rows() {
693                        unsafe {
694                            self.row_op(row, i, c, p);
695                        }
696                    }
697                    table.add(c, i);
698
699                    if table.len() == k {
700                        table.generate(self);
701                        for j in 0..table.rows()[0] {
702                            table.reduce(self.row_mut(j).limbs_mut());
703                        }
704                        for j in i + 1..self.rows() {
705                            table.reduce(self.row_mut(j).limbs_mut());
706                        }
707                        table.clear();
708                    }
709                } else {
710                    empty_rows.push(i);
711                }
712            }
713            if !table.is_empty() {
714                table.generate(self);
715                for j in 0..table.rows()[0] {
716                    table.reduce(self.row_mut(j).limbs_mut());
717                }
718                table.clear();
719            }
720        } else {
721            for i in 0..self.rows() {
722                if let Some((c, v)) = self.row(i).first_nonzero() {
723                    self.pivots[c] = i as isize;
724                    self.row_mut(i).scale(prime::inverse(p, v));
725                    for j in 0..self.rows() {
726                        if i == j {
727                            continue;
728                        }
729                        unsafe {
730                            self.row_op(j, i, c, p);
731                        }
732                    }
733                } else {
734                    empty_rows.push(i);
735                }
736            }
737        }
738
739        // Now reorder the vectors. There are O(n) in-place permutation algorithms but the way we
740        // get the permutation makes the naive strategy easier.
741        let old_len = self.data.len();
742        let old_data = std::mem::replace(&mut self.data, aligned_vec::avec![0; old_len]);
743
744        let mut new_row_idx = 0;
745        for old_row in self.pivots.iter_mut().filter(|row| **row >= 0) {
746            let old_row_idx = *old_row as usize;
747            let old_limb_range = row_to_limb_range(old_row_idx, self.stride);
748            let new_limb_range = row_to_limb_range(new_row_idx, self.stride);
749            self.data[new_limb_range].copy_from_slice(&old_data[old_limb_range]);
750            *old_row = new_row_idx as isize;
751            new_row_idx += 1;
752        }
753
754        new_row_idx
755    }
756}
757
758impl Matrix {
759    /// Given a row reduced matrix, find the first row whose pivot column is after (or at)
760    /// `first_column`.
761    pub fn find_first_row_in_block(&self, first_column: usize) -> usize {
762        self.pivots[first_column..]
763            .iter()
764            .find(|&&x| x >= 0)
765            .map(|x| *x as usize)
766            .unwrap_or_else(|| self.rows())
767    }
768
769    /// Computes the quasi-inverse of a matrix given a rref of [A|0|I], where 0 is the zero padding
770    /// as usual.
771    ///
772    /// # Arguments
773    ///  * `last_target_col` - the last column of A
774    ///  * `first_source_col` - the first column of I
775    ///
776    /// # Example
777    /// ```
778    /// # use fp::prime::ValidPrime;
779    /// let p = ValidPrime::new(3);
780    /// # use fp::matrix::Matrix;
781    /// # use fp::vector::FpVector;
782    /// let input = [
783    ///     vec![1, 2, 1, 1, 0],
784    ///     vec![1, 0, 2, 1, 1],
785    ///     vec![2, 2, 0, 2, 1],
786    /// ];
787    ///
788    /// let (padded_cols, mut m) = Matrix::augmented_from_vec(p, &input);
789    /// m.row_reduce();
790    /// let qi = m.compute_quasi_inverse(input[0].len(), padded_cols);
791    ///
792    /// let preimage = [vec![0, 1, 0], vec![0, 2, 2]];
793    /// assert_eq!(qi.preimage(), &Matrix::from_vec(p, &preimage));
794    /// ```
795    pub fn compute_quasi_inverse(
796        &self,
797        last_target_col: usize,
798        first_source_col: usize,
799    ) -> QuasiInverse {
800        let p = self.prime();
801        let columns = self.columns();
802        let source_columns = columns - first_source_col;
803        let first_kernel_row = self.find_first_row_in_block(first_source_col);
804        let mut preimage = Self::new(p, first_kernel_row, source_columns);
805        for i in 0..first_kernel_row {
806            preimage
807                .row_mut(i)
808                .assign(self.row(i).restrict(first_source_col, columns));
809        }
810        QuasiInverse::new(Some(self.pivots()[..last_target_col].to_vec()), preimage)
811    }
812
813    /// Computes the quasi-inverse of a matrix given a rref of [A|0|I], where 0 is the zero padding
814    /// as usual.
815    ///
816    /// # Arguments
817    ///  * `last_target_col` - the last column of A
818    ///  * `first_source_col` - the first column of I
819    ///
820    /// # Example
821    /// ```
822    /// # use fp::prime::ValidPrime;
823    /// let p = ValidPrime::new(3);
824    /// # use fp::matrix::Matrix;
825    /// # use fp::vector::FpVector;
826    /// let input = [
827    ///     vec![1, 2, 1, 1, 0],
828    ///     vec![1, 0, 2, 1, 1],
829    ///     vec![2, 2, 0, 2, 1],
830    /// ];
831    ///
832    /// let (padded_cols, mut m) = Matrix::augmented_from_vec(p, &input);
833    /// m.row_reduce();
834    ///
835    /// let computed_image = m.compute_image(input[0].len(), padded_cols);
836    ///
837    /// let image = [vec![1, 0, 2, 1, 1], vec![0, 1, 1, 0, 1]];
838    /// assert_eq!(*computed_image, Matrix::from_vec(p, &image));
839    /// assert_eq!(computed_image.pivots(), &vec![0, 1, -1, -1, -1]);
840    /// ```
841    pub fn compute_image(&self, last_target_col: usize, first_source_col: usize) -> Subspace {
842        let p = self.prime();
843        let first_kernel_row = self.find_first_row_in_block(first_source_col);
844        let mut image_matrix = Self::new(p, first_kernel_row, last_target_col);
845        for i in 0..first_kernel_row {
846            image_matrix
847                .row_mut(i)
848                .assign(self.row(i).restrict(0, last_target_col));
849        }
850        image_matrix.pivots = self.pivots()[..last_target_col].to_vec();
851        Subspace::from_matrix(image_matrix)
852    }
853
854    /// Computes the kernel from an augmented matrix in rref. To compute the kernel of a matrix
855    /// A, produce an augmented matrix of the form
856    /// ```text
857    /// [A | I]
858    /// ```
859    /// An important thing to note is that the number of columns of `A` should be a multiple of the
860    /// number of entries per limb in an FpVector, and this is often achieved by padding columns
861    /// with 0. The padded length can be obtained from `FpVector::padded_dimension`.
862    ///
863    /// After this matrix is set up, perform row reduction with `Matrix::row_reduce`, and then
864    /// apply `compute_kernel`.
865    ///
866    /// # Arguments
867    ///  * `column_to_pivot_row` - This is the list of pivots `row_reduce` gave you.
868    ///  * `first_source_column` - The column where the `I` part of the augmented matrix starts.
869    ///
870    /// # Example
871    /// ```
872    /// # use fp::prime::ValidPrime;
873    /// let p = ValidPrime::new(3);
874    /// # use fp::matrix::Matrix;
875    /// # use fp::vector::FpVector;
876    /// let input = [
877    ///     vec![1, 2, 1, 1, 0],
878    ///     vec![1, 0, 2, 1, 1],
879    ///     vec![2, 2, 0, 2, 1],
880    /// ];
881    ///
882    /// let (padded_cols, mut m) = Matrix::augmented_from_vec(p, &input);
883    /// m.row_reduce();
884    /// let ker = m.compute_kernel(padded_cols);
885    ///
886    /// let mut target = vec![0; 3];
887    /// assert_eq!(ker.row(0).iter().collect::<Vec<u32>>(), vec![1, 1, 2]);
888    /// ```
889    pub fn compute_kernel(&self, first_source_column: usize) -> Subspace {
890        let p = self.prime();
891        let rows = self.rows();
892        let columns = self.columns();
893        let source_dimension = columns - first_source_column;
894        let column_to_pivot_row = self.pivots();
895
896        // Find the first kernel row
897        let first_kernel_row = self.find_first_row_in_block(first_source_column);
898        // Every row after the first kernel row is also a kernel row, so now we know how big it is and can allocate space.
899        let kernel_dimension = rows - first_kernel_row;
900        let mut kernel = Self::new(p, kernel_dimension, source_dimension);
901        kernel.initialize_pivots();
902
903        if kernel_dimension == 0 {
904            return Subspace::from_matrix(kernel);
905        }
906
907        // Write pivots into kernel
908        for i in 0..source_dimension {
909            // Turns -1 into some negative number... make sure to check <0 for no pivot in column...
910            kernel.pivots_mut()[i] =
911                column_to_pivot_row[i + first_source_column] - first_kernel_row as isize;
912        }
913        // Copy kernel matrix into kernel
914        for (i, mut row) in kernel.iter_mut().enumerate() {
915            row.assign(
916                self.row(first_kernel_row + i)
917                    .restrict(first_source_column, first_source_column + source_dimension),
918            );
919        }
920        Subspace::from_matrix(kernel)
921    }
922
923    pub fn extend_column_dimension(&mut self, columns: usize) {
924        if columns > self.columns {
925            self.extend_column_capacity(columns);
926            self.columns = columns;
927            self.pivots.resize(columns, -1);
928        }
929    }
930
931    pub fn extend_column_capacity(&mut self, columns: usize) {
932        let new_stride = self.fp.number(columns);
933        if new_stride > self.stride {
934            self.data.resize(new_stride * self.physical_rows, 0);
935            // Shift row data backwards, starting from the end to avoid overwriting data.
936            for row_idx in (0..self.physical_rows).rev() {
937                let old_row_start = row_idx * self.stride;
938                let new_row_start = row_idx * new_stride;
939                let new_row_zero_part =
940                    row_idx * new_stride + self.stride..(row_idx + 1) * new_stride;
941                // Safety: we already resized the data, and limbs are always aligned.
942                unsafe {
943                    std::ptr::copy(
944                        &raw const self.data[old_row_start],
945                        &raw mut self.data[new_row_start],
946                        self.stride,
947                    );
948                }
949                for limb in &mut self.data[new_row_zero_part] {
950                    *limb = 0;
951                }
952            }
953            self.stride = new_stride;
954        }
955    }
956
957    /// Add a row to the matrix and return a mutable reference to it.
958    pub fn add_row(&mut self) -> FpSliceMut<'_> {
959        // Check if we need to expand physical capacity
960        if self.rows + 1 > self.physical_rows {
961            let new_physical_rows = get_physical_rows(self.prime(), self.rows + 1);
962            self.data.resize(new_physical_rows * self.stride, 0);
963            self.physical_rows = new_physical_rows;
964        }
965        self.rows += 1;
966        self.row_mut(self.rows - 1)
967    }
968
969    /// Given a matrix M in rref, add rows to make the matrix surjective when restricted to the
970    /// columns between `start_column` and `end_column`. That is, if M = [*|B|*] where B is between
971    /// columns `start_column` and `end_column`, we want the new B to be surjective. This doesn't
972    /// change the size of the matrix. Rather, it adds the new row to the next empty row in the
973    /// matrix. This will panic if there are not enough empty rows.
974    ///
975    /// The rows added are all zero except in a single column, where it is 1. The function returns
976    /// the list of such columns.
977    ///
978    /// # Arguments
979    ///  * `first_empty_row` - The first row in the matrix that is empty. This is where we will add
980    ///    our new rows. This is a mutable borrow and by the end of the function, `first_empty_row`
981    ///    will be updated to the new first empty row.
982    ///  * `current_pivots` - The current pivots of the matrix.
983    ///
984    /// # Panics
985    /// The function panics if there are not enough empty rows.
986    pub fn extend_to_surjection(
987        &mut self,
988        start_column: usize,
989        end_column: usize,
990        extra_column_capacity: usize,
991    ) -> Vec<usize> {
992        let mut added_pivots = Vec::new();
993        self.extend_column_capacity(self.columns + extra_column_capacity);
994
995        for (i, &pivot) in self.pivots.clone()[start_column..end_column]
996            .iter()
997            .enumerate()
998        {
999            if pivot >= 0 {
1000                continue;
1001            }
1002            let mut new_row = self.add_row();
1003            new_row.set_entry(i, 1);
1004            added_pivots.push(i);
1005        }
1006        added_pivots
1007    }
1008
1009    /// Given a matrix in rref, say [A|B|C], where B lies between columns `start_column` and
1010    /// `end_columns`, and a superspace of the image of B, add rows to the matrix such that the
1011    /// image of B becomes this superspace.
1012    ///
1013    /// The rows added are basis vectors of the desired image as specified in the Subspace object.
1014    /// The function returns the list of new pivot columns.
1015    ///
1016    /// # Panics
1017    /// It *may* panic if the current image is not contained in `desired_image`, but is not
1018    /// guaranteed to do so.
1019    pub fn extend_image(
1020        &mut self,
1021        start_column: usize,
1022        end_column: usize,
1023        desired_image: &Subspace,
1024        extra_column_capacity: usize,
1025    ) -> Vec<usize> {
1026        let mut added_pivots = Vec::new();
1027        let desired_pivots = desired_image.pivots();
1028        let early_end_column = std::cmp::min(end_column, desired_pivots.len() + start_column);
1029
1030        self.extend_column_capacity(self.columns + extra_column_capacity);
1031
1032        for i in start_column..early_end_column {
1033            debug_assert!(
1034                self.pivots()[i] < 0 || desired_pivots[i - start_column] >= 0,
1035                "current_pivots : {:?}, desired_pivots : {:?}",
1036                self.pivots(),
1037                desired_pivots
1038            );
1039            if self.pivots()[i] >= 0 || desired_pivots[i - start_column] < 0 {
1040                continue;
1041            }
1042            // Look up the cycle that we're missing and add a generator hitting it.
1043            let kernel_vector_row = desired_pivots[i - start_column] as usize;
1044            let new_image = desired_image.row(kernel_vector_row);
1045
1046            let mut new_row = self.add_row();
1047            new_row
1048                .slice_mut(
1049                    start_column,
1050                    start_column + desired_image.ambient_dimension(),
1051                )
1052                .assign(new_image);
1053
1054            added_pivots.push(i);
1055        }
1056        added_pivots
1057    }
1058
1059    /// Applies a matrix to a vector.
1060    ///
1061    /// # Example
1062    /// ```
1063    /// # use fp::prime::ValidPrime;
1064    /// let p = ValidPrime::new(7);
1065    /// # use fp::matrix::Matrix;
1066    /// # use fp::vector::FpVector;
1067    /// let input = [vec![1, 3, 6], vec![0, 3, 4]];
1068    ///
1069    /// let m = Matrix::from_vec(p, &input);
1070    /// let v = FpVector::from_slice(p, &vec![3, 1]);
1071    /// let mut result = FpVector::new(p, 3);
1072    /// let desired_result = FpVector::from_slice(p, &vec![3, 5, 1]);
1073    /// m.apply(result.as_slice_mut(), 1, v.as_slice());
1074    /// assert_eq!(result, desired_result);
1075    /// ```
1076    pub fn apply(&self, mut result: FpSliceMut, coeff: u32, input: FpSlice) {
1077        debug_assert_eq!(input.len(), self.rows());
1078        for i in 0..input.len() {
1079            result.add(self.row(i), (coeff * input.entry(i)) % self.prime());
1080        }
1081    }
1082
1083    pub fn trim(&mut self, row_start: usize, row_end: usize, col_start: usize) {
1084        let mut new = Self::new(self.prime(), row_end - row_start, self.columns - col_start);
1085        for (i, mut row) in new.iter_mut().enumerate() {
1086            row.assign(self.row(row_start + i).restrict(col_start, self.columns));
1087        }
1088        std::mem::swap(self, &mut new);
1089    }
1090
1091    /// Rotate the rows downwards in the range `range`.
1092    pub fn rotate_down(&mut self, range: Range<usize>, shift: usize) {
1093        let limb_range = row_range_to_limb_range(&range, self.stride);
1094        self.data[limb_range].rotate_right(shift * self.stride)
1095    }
1096}
1097
1098impl std::ops::MulAssign<u32> for Matrix {
1099    fn mul_assign(&mut self, rhs: u32) {
1100        #[allow(clippy::suspicious_op_assign_impl)]
1101        let rhs = rhs % self.prime();
1102        for mut row in self.iter_mut() {
1103            row.scale(rhs);
1104        }
1105    }
1106}
1107
1108impl std::ops::AddAssign<&Self> for Matrix {
1109    fn add_assign(&mut self, rhs: &Self) {
1110        assert_eq!(self.prime(), rhs.prime());
1111        assert_eq!(self.columns(), rhs.columns());
1112        assert_eq!(self.rows(), rhs.rows());
1113
1114        for (i, mut row) in self.iter_mut().enumerate() {
1115            row.add(rhs.row(i), 1);
1116        }
1117    }
1118}
1119
#[cfg(feature = "proptest")]
pub mod arbitrary {
    use proptest::prelude::*;

    use super::*;
    use crate::{
        field::Fp,
        vector::{FqVector, arbitrary::FqVectorArbParams},
    };

    /// Largest number of rows produced by the default parameters.
    pub const MAX_ROWS: usize = 100;
    /// Largest number of columns produced by the default parameters.
    pub const MAX_COLUMNS: usize = 100;

    /// Parameters controlling the matrices generated by `Matrix`'s `Arbitrary` implementation.
    #[derive(Debug, Clone)]
    pub struct MatrixArbParams {
        /// The prime to use, or `None` to let the strategy pick an arbitrary valid prime.
        pub p: Option<ValidPrime>,
        /// Strategy producing the number of rows.
        pub rows: BoxedStrategy<usize>,
        /// Strategy producing the number of columns.
        pub columns: BoxedStrategy<usize>,
    }

    impl Default for MatrixArbParams {
        /// Any prime, with between 1 and `MAX_ROWS` rows and between 1 and `MAX_COLUMNS` columns.
        fn default() -> Self {
            let rows = (1..=MAX_ROWS).boxed();
            let columns = (1..=MAX_COLUMNS).boxed();
            Self {
                p: None,
                rows,
                columns,
            }
        }
    }

    impl Arbitrary for Matrix {
        type Parameters = MatrixArbParams;
        type Strategy = BoxedStrategy<Self>;

        /// Picks a prime and dimensions according to `args`, then fills every row with an
        /// arbitrary vector of the chosen length.
        fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
            let MatrixArbParams { p, rows, columns } = args;
            let prime_strategy = if let Some(fixed_prime) = p {
                Just(fixed_prime).boxed()
            } else {
                any::<ValidPrime>().boxed()
            };

            (prime_strategy, rows, columns)
                .prop_flat_map(|(p, row_count, column_count)| {
                    let row_params = FqVectorArbParams {
                        fq: Some(Fp::new(p)),
                        len: Just(column_count).boxed(),
                    };
                    let row_strategy = any_with::<FqVector<_>>(row_params)
                        .prop_map(|vector| -> FpVector { vector.into() });

                    let row_vectors = proptest::collection::vec(row_strategy, row_count);
                    (Just(p), row_vectors, Just(column_count))
                })
                .prop_map(|(p, row_vectors, column_count)| {
                    Self::from_rows(p, row_vectors, column_count)
                })
                .boxed()
        }
    }

    impl Matrix {
        /// Generate an arbitrary row-reduced matrix.
        ///
        /// Simply generating an arbitrary matrix and row-reducing it gives poor coverage. A
        /// matrix drawn uniformly from all $n \times m$ matrices almost always has full rank
        /// with its pivots in the first $n$ columns, so its row-reduced form is almost always an
        /// identity matrix augmented by a random matrix. When $m$ is much larger than $n$, such
        /// matrices make up only a tiny part of the space of all row-reduced matrices.
        ///
        /// An exhaustive search through all $n \times m$ matrices would of course reach every
        /// row-reduced matrix, but the space is far too large for that; tests only ever see a
        /// vanishingly small fraction of it. A method that is sensitive to the pivot structure
        /// of its input is therefore unlikely to meet many interesting pivot patterns when it is
        /// proptested through `arbitrary_with`, even though those patterns are the most likely
        /// to expose bugs. This function instead tries to choose uniformly at random directly in
        /// the space of all row-reduced matrices.
        ///
        /// That description is not entirely accurate. Nothing here is random; we build a
        /// `Strategy` that samples only from row-reduced matrices. Moreover, depending on the
        /// parameters, the generated matrices need not share a size or even a ground field, so
        /// calling this a "space" is a slight abuse of mathematical language.
        pub fn arbitrary_rref_with(args: MatrixArbParams) -> impl Strategy<Value = Self> {
            Self::arbitrary_with(args)
                .prop_flat_map(|m| {
                    // Choose the pivot columns: an increasing selection of at most
                    // min(rows, columns) column indices.
                    let max_rank = std::cmp::min(m.rows(), m.columns());
                    let all_columns = (0..m.columns()).collect::<Vec<_>>();
                    let pivot_cols = proptest::sample::subsequence(all_columns, 0..=max_rank);
                    (Just(m), pivot_cols)
                })
                .prop_map(|(mut m, pivot_cols)| {
                    // Row `k` gets zeros up to the `k`th pivot column and a 1 in it; rows
                    // without a pivot become zero rows.
                    for (row_idx, mut row) in m.iter_mut().enumerate() {
                        match pivot_cols.get(row_idx) {
                            Some(&col_idx) => {
                                row.slice_mut(0, col_idx).set_to_zero();
                                row.set_entry(col_idx, 1);
                            }
                            None => {
                                row.set_to_zero();
                            }
                        }
                    }
                    // Clear each pivot column in the rows above its pivot row; the rows below
                    // are already zero there because their own pivots lie further right.
                    for (row_idx, &col_idx) in pivot_cols.iter().enumerate() {
                        for mut row in m.iter_mut().take(row_idx) {
                            row.set_entry(col_idx, 0);
                        }
                    }
                    m
                })
        }

        /// Shorthand for [`Matrix::arbitrary_rref_with`] with the default parameters.
        pub fn arbitrary_rref() -> impl Strategy<Value = Self> {
            Self::arbitrary_rref_with(MatrixArbParams::default())
        }
    }
}
1233
1234fn row_to_limb_range(row: usize, stride: usize) -> Range<usize> {
1235    row_range_to_limb_range(&(row..row + 1), stride)
1236}
1237
/// Converts a range of row indices into the range of limb indices those rows occupy, given
/// that every row takes `stride` limbs.
fn row_range_to_limb_range(row_range: &Range<usize>, stride: usize) -> Range<usize> {
    let first_limb = row_range.start * stride;
    let end_limb = row_range.end * stride;
    first_limb..end_limb
}
1241
1242fn get_physical_rows(p: ValidPrime, rows: usize) -> usize {
1243    if p == 2 && rows >= 32 {
1244        // For 32+ rows, pad to next multiple of 64 for BLAS optimization
1245        // This bounds the memory overhead to at most 2x (for 32 rows → 64 rows)
1246        rows.next_multiple_of(64)
1247    } else {
1248        // For < 32 rows, don't pad (would waste too much memory)
1249        // These small matrices will use scalar multiplication
1250        rows
1251    }
1252}
1253
/// This models an augmented matrix.
///
/// The matrix is split horizontally into `N` segments. Segment `i` occupies the columns
/// `start[i]..end[i]` of the inner matrix.
///
/// In an ideal world, this will have no public fields. The inner matrix
/// can be accessed via deref, and there are functions that expose `end`
/// and `start`. However, in the real world, the borrow checker exists, and there are
/// cases where directly accessing these fields is what it takes to let you pass the
/// borrow checker.
///
/// In particular, if `m` is an augmented matrix and `f` is a function
/// that takes in `&mut Matrix`, trying to run `m.f(m.start[0])` produces an error
/// because it is not clear if we first do the `deref_mut` then retrieve `start[0]`.
/// (since `deref_mut` takes in a mutable borrow, it could in theory modify `m`
/// non-trivially)
#[derive(Clone)]
pub struct AugmentedMatrix<const N: usize> {
    /// `end[i]` is one past the last column of segment `i`, i.e. `start[i]` plus the segment's
    /// width.
    pub end: [usize; N],
    /// `start[i]` is the first column of segment `i`. Segment 0 starts at column 0 and each later
    /// segment starts after the padded length (`FpVector::padded_len`) of the previous one.
    pub start: [usize; N],
    /// The underlying matrix holding all segments side by side.
    pub inner: Matrix,
}
1273
1274impl<const N: usize> AugmentedMatrix<N> {
1275    pub fn new(p: ValidPrime, rows: usize, columns: [usize; N]) -> Self {
1276        let mut start = [0; N];
1277        let mut end = [0; N];
1278        for i in 1..N {
1279            start[i] = start[i - 1] + FpVector::padded_len(p, columns[i - 1]);
1280        }
1281        for i in 0..N {
1282            end[i] = start[i] + columns[i];
1283        }
1284
1285        Self {
1286            inner: Matrix::new(p, rows, end[N - 1]),
1287            start,
1288            end,
1289        }
1290    }
1291
1292    pub fn new_with_capacity(
1293        p: ValidPrime,
1294        rows: usize,
1295        columns: &[usize],
1296        row_capacity: usize,
1297        extra_column_capacity: usize,
1298    ) -> Self {
1299        let mut start = [0; N];
1300        let mut end = [0; N];
1301        for i in 1..N {
1302            start[i] = start[i - 1] + FpVector::padded_len(p, columns[i - 1]);
1303        }
1304        for i in 0..N {
1305            end[i] = start[i] + columns[i];
1306        }
1307
1308        Self {
1309            inner: Matrix::new_with_capacity(
1310                p,
1311                rows,
1312                end[N - 1],
1313                row_capacity,
1314                end[N - 1] + extra_column_capacity,
1315            ),
1316            start,
1317            end,
1318        }
1319    }
1320
1321    pub fn segment(&mut self, start: usize, end: usize) -> MatrixSliceMut<'_> {
1322        let rows = self.inner.rows();
1323        let start = self.start[start];
1324        let end = self.end[end];
1325        self.slice_mut(0, rows, start, end)
1326    }
1327
1328    pub fn row_segment_mut(&mut self, i: usize, start: usize, end: usize) -> FpSliceMut<'_> {
1329        let start_idx = self.start[start];
1330        let end_idx = self.end[end];
1331        let limb_range = row_to_limb_range(i, self.stride);
1332        FpSliceMut::new(self.prime(), &mut self.data[limb_range], start_idx, end_idx)
1333    }
1334
1335    pub fn row_segment(&self, i: usize, start: usize, end: usize) -> FpSlice<'_> {
1336        let start_idx = self.start[start];
1337        let end_idx = self.end[end];
1338        self.row(i).restrict(start_idx, end_idx)
1339    }
1340
1341    pub fn into_matrix(self) -> Matrix {
1342        self.inner
1343    }
1344
1345    pub fn into_tail_segment(
1346        mut self,
1347        row_start: usize,
1348        row_end: usize,
1349        segment_start: usize,
1350    ) -> Matrix {
1351        self.inner
1352            .trim(row_start, row_end, self.start[segment_start]);
1353        self.inner
1354    }
1355
1356    pub fn compute_kernel(&self) -> Subspace {
1357        self.inner.compute_kernel(self.start[N - 1])
1358    }
1359
1360    pub fn extend_column_dimension(&mut self, columns: usize) {
1361        if columns > self.columns {
1362            self.end[N - 1] += columns - self.columns;
1363            self.inner.extend_column_dimension(columns);
1364        }
1365    }
1366}
1367
1368impl<const N: usize> std::ops::Deref for AugmentedMatrix<N> {
1369    type Target = Matrix;
1370
1371    fn deref(&self) -> &Matrix {
1372        &self.inner
1373    }
1374}
1375
1376impl<const N: usize> std::ops::DerefMut for AugmentedMatrix<N> {
1377    fn deref_mut(&mut self) -> &mut Matrix {
1378        &mut self.inner
1379    }
1380}
1381
1382impl AugmentedMatrix<2> {
1383    pub fn compute_image(&self) -> Subspace {
1384        self.inner.compute_image(self.end[0], self.start[1])
1385    }
1386
1387    pub fn compute_quasi_inverse(&self) -> QuasiInverse {
1388        self.inner.compute_quasi_inverse(self.end[0], self.start[1])
1389    }
1390}
1391
1392impl AugmentedMatrix<3> {
1393    pub fn drop_first(mut self) -> AugmentedMatrix<2> {
1394        let offset = self.start[1];
1395        for mut row in self.inner.iter_mut() {
1396            row.shl_assign(offset);
1397        }
1398        self.inner.columns -= offset;
1399        AugmentedMatrix::<2> {
1400            inner: self.inner,
1401            start: [self.start[1] - offset, self.start[2] - offset],
1402            end: [self.end[1] - offset, self.end[2] - offset],
1403        }
1404    }
1405
1406    /// This function computes quasi-inverses for matrices A, B given a reduced row echelon form of
1407    /// [A|0|B|0|I] such that A is surjective. Moreover, if Q is the quasi-inverse of A, it is
1408    /// guaranteed that the image of QB and B|_{ker A} are disjoint.
1409    ///
1410    /// This takes ownership of the matrix since it heavily modifies the matrix. This is not
1411    /// strictly necessary but is fine in most applications.
1412    pub fn compute_quasi_inverses(mut self) -> (QuasiInverse, QuasiInverse) {
1413        let p = self.prime();
1414        let stride = self.stride;
1415
1416        let source_columns = self.end[2] - self.start[2];
1417
1418        if self.end[0] == 0 {
1419            let cc_qi = QuasiInverse::new(None, Matrix::new(p, 0, source_columns));
1420            let res_qi = Matrix::compute_quasi_inverse(&self, self.end[1], self.start[2]);
1421            (cc_qi, res_qi)
1422        } else {
1423            let mut cc_preimage = Matrix::new(p, self.end[0], source_columns);
1424            for i in 0..self.end[0] {
1425                cc_preimage
1426                    .row_mut(i)
1427                    .assign(self.row(i).restrict(self.start[2], self.end[2]));
1428            }
1429            let cm_qi = QuasiInverse::new(None, cc_preimage);
1430
1431            let first_kernel_row = self.find_first_row_in_block(self.start[2]);
1432            self.rows = first_kernel_row;
1433            self.data.truncate(first_kernel_row * stride);
1434
1435            let mut res_matrix = self.drop_first();
1436            res_matrix.row_reduce();
1437            let res_qi = res_matrix.compute_quasi_inverse();
1438
1439            (cm_qi, res_qi)
1440        }
1441    }
1442}
1443
/// A mutable view of a rectangular part of a matrix: a run of rows, restricted to the columns
/// `col_start..col_end`.
pub struct MatrixSliceMut<'a> {
    /// The field the entries live in.
    fp: Fp<ValidPrime>,
    /// The limbs backing the viewed rows; row `i` occupies limbs `i * stride..(i + 1) * stride`.
    data: &'a mut [Limb],
    /// Number of rows in the view.
    rows: usize,
    /// First column of each row that belongs to the view.
    col_start: usize,
    /// One past the last column of each row that belongs to the view.
    col_end: usize,
    /// Number of limbs per row in `data`.
    stride: usize,
}
1452
1453impl<'a> MatrixSliceMut<'a> {
1454    pub fn prime(&self) -> ValidPrime {
1455        self.fp.characteristic()
1456    }
1457
1458    pub fn columns(&self) -> usize {
1459        self.col_end - self.col_start
1460    }
1461
1462    pub fn rows(&self) -> usize {
1463        self.rows
1464    }
1465
1466    pub fn row_slice<'b: 'a>(&'b mut self, row_start: usize, row_end: usize) -> MatrixSliceMut<'b> {
1467        let limb_range = row_range_to_limb_range(&(row_start..row_end), self.stride);
1468        Self {
1469            fp: self.fp,
1470            data: &mut self.data[limb_range],
1471            rows: row_end - row_start,
1472            col_start: self.col_start,
1473            col_end: self.col_end,
1474            stride: self.stride,
1475        }
1476    }
1477
1478    pub fn iter(&self) -> impl Iterator<Item = FpSlice<'_>> + '_ {
1479        let start = self.col_start;
1480        let end = self.col_end;
1481        (0..self.rows).map(move |row_idx| {
1482            let limb_range = row_to_limb_range(row_idx, self.stride);
1483            FpSlice::new(self.prime(), &self.data[limb_range], start, end)
1484        })
1485    }
1486
1487    pub fn iter_mut(&mut self) -> impl Iterator<Item = FpSliceMut<'_>> + '_ {
1488        let p = self.prime();
1489        let start = self.col_start;
1490        let end = self.col_end;
1491
1492        if self.stride == 0 {
1493            Either::Left(std::iter::empty())
1494        } else {
1495            let rows = self
1496                .data
1497                .chunks_mut(self.stride)
1498                .map(move |row| FpSliceMut::new(p, row, start, end));
1499            Either::Right(rows)
1500        }
1501    }
1502
1503    pub fn maybe_par_iter_mut(
1504        &mut self,
1505    ) -> impl MaybeIndexedParallelIterator<Item = FpSliceMut<'_>> + '_ {
1506        let p = self.prime();
1507        let start = self.col_start;
1508        let end = self.col_end;
1509
1510        if self.stride == 0 {
1511            Either::Left(maybe_rayon::empty())
1512        } else {
1513            let rows = self
1514                .data
1515                .maybe_par_chunks_mut(self.stride)
1516                .map(move |row| FpSliceMut::new(p, row, start, end));
1517            Either::Right(rows)
1518        }
1519    }
1520
1521    pub fn row(&mut self, row: usize) -> FpSlice<'_> {
1522        let limb_range = row_to_limb_range(row, self.stride);
1523        FpSlice::new(
1524            self.prime(),
1525            &self.data[limb_range],
1526            self.col_start,
1527            self.col_end,
1528        )
1529    }
1530
1531    pub fn row_mut(&mut self, row: usize) -> FpSliceMut<'_> {
1532        let limb_range = row_to_limb_range(row, self.stride);
1533        FpSliceMut::new(
1534            self.prime(),
1535            &mut self.data[limb_range],
1536            self.col_start,
1537            self.col_end,
1538        )
1539    }
1540
1541    pub fn add_identity(&mut self) {
1542        debug_assert_eq!(self.rows(), self.columns());
1543        for row_idx in 0..self.rows {
1544            self.row_mut(row_idx).add_basis_element(row_idx, 1);
1545        }
1546    }
1547
1548    /// For each row, add the `v[i]`th entry of `other` to `self`.
1549    pub fn add_masked(&mut self, other: &Matrix, mask: &[usize]) {
1550        assert_eq!(self.rows(), other.rows());
1551
1552        for (mut l, r) in self.iter_mut().zip(other.iter()) {
1553            l.add_masked(r, 1, mask);
1554        }
1555    }
1556}
1557
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Segment widths must be reported correctly for a variety of layouts, including empty
    /// segments and widths around limb boundaries.
    #[test]
    fn test_augmented_matrix() {
        for cols in [[1, 0, 5], [4, 6, 2], [129, 4, 64], [64, 64, 102]] {
            test_augmented_matrix_inner(cols);
        }
    }

    /// Builds a three-segment augmented matrix over F_2 and checks the width of each segment.
    fn test_augmented_matrix_inner(cols: [usize; 3]) {
        let mut aug = AugmentedMatrix::<3>::new(ValidPrime::new(2), 3, cols);
        for (segment, &expected_width) in cols.iter().enumerate() {
            assert_eq!(aug.segment(segment, segment).columns(), expected_width);
        }
    }

    /// Row reduces a fixed matrix over F_2 and compares both the result and the pivots with
    /// known-good values.
    #[test]
    fn test_row_reduce_2() {
        let p = ValidPrime::new(2);
        let input = [
            vec![0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1],
            vec![0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            vec![1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            vec![1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0],
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
            vec![0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1],
        ];
        let goal_output = [
            [1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1],
            [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
        ];
        let goal_pivots = [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 6, -1, -1];

        let mut m = Matrix::from_vec(p, &input);
        println!("{m}");
        m.row_reduce();

        for (row_idx, goal_row) in goal_output.iter().enumerate() {
            assert_eq!(m.row(row_idx).iter().collect::<Vec<_>>(), *goal_row);
        }
        assert_eq!(m.pivots(), &goal_pivots);
    }

    proptest! {
        // A matrix produced by `arbitrary_rref` must already be in rref, i.e. row reducing it
        // again must not change it.
        #[test]
        fn test_arbitrary_rref(matrix in Matrix::arbitrary_rref()) {
            let mut reduced = matrix.clone();
            reduced.row_reduce();
            prop_assert_eq!(matrix, reduced);
        }
    }
}