fp/matrix/
m4ri.rs

1use itertools::Itertools;
2
3use crate::{
4    field::{field_internal::FieldInternal, fp::F2},
5    limb::{Limb, LimbBitIndexPair},
6    matrix::Matrix,
7    simd,
8};
9
10#[derive(Debug, Default)]
11/// M4RI works as follows --- first row reduce k rows using the naive algorithm. We then construct
12/// a table of all 2^k linear combinations of these rows. This can be done in O(2^k) time. We then
13/// use this table to reduce the remaining rows, so that each row takes O(`num_columns`) time,
14/// which reduces the time taken by a factor of k x density.
15///
16/// Since we are likely to run into empty rows when doing row reduction, what we do in practice is
17/// that we keep reducing rows until we collect k of them. Whenever we find a row, we record it
18/// with [`M4riTable::add`]. We only record the row number and pivot column, as the values of these
19/// rows will change as we go on due to the desire to land in a RREF.
20///
21/// Once we have recorded enough rows, we generate the table using [`M4riTable::generate`], and
22/// then reduce limbs using [`M4riTable::reduce`]. When we are done we clear it using
23/// [`M4riTable::clear`] and proceed to the next k rows.
24pub(crate) struct M4riTable {
25    /// The indices of new rows in the table
26    rows: Vec<usize>,
27    /// The list of pivot columns of the rows
28    columns: Vec<LimbBitIndexPair>,
29    /// The 2^k linear combinations of the k rows, apart from the first one which is identically
30    /// zero.
31    data: Vec<Limb>,
32    /// The smallest non-zero limb in this table. We use this when row reducing to save a few
33    /// operations.
34    min_limb: usize,
35}
36
37impl M4riTable {
38    /// Create a table with space for `k` vectors, each with `cols` columns.
39    pub fn new(k: usize, cols: usize) -> Self {
40        let num_limbs = F2.number(cols);
41        Self {
42            rows: Vec::with_capacity(k),
43            columns: Vec::with_capacity(k),
44            min_limb: 0,
45            // There are 2^k rows but the first one is zero which we omit
46            data: Vec::with_capacity(((1 << k) - 1) * num_limbs),
47        }
48    }
49
50    /// Number of rows in the M4riTable
51    pub fn len(&self) -> usize {
52        self.columns.len()
53    }
54
55    /// Whether the table has no rows
56    pub fn is_empty(&self) -> bool {
57        self.columns.is_empty()
58    }
59
60    /// Get the list of pivot rows
61    pub fn rows(&self) -> &[usize] {
62        &self.rows
63    }
64
65    /// Add a row to the table.
66    ///
67    /// # Arguments
68    ///  - `column`: pivot column of the row
69    ///  - `row`: index of the row
70    pub fn add(&mut self, column: usize, row: usize) {
71        self.columns.push(F2.limb_bit_index_pair(column));
72        self.rows.push(row);
73    }
74
75    /// Clear the contents of the table
76    pub fn clear(&mut self) {
77        self.columns.clear();
78        self.rows.clear();
79        self.data.clear();
80    }
81
82    /// Generates the table from the known data
83    /// `num` is the number of the vector being added.
84    pub fn generate(&mut self, matrix: &Matrix) {
85        let num_limbs = matrix.row(0).limbs().len();
86        self.min_limb = usize::MAX;
87        for (n, (c, &r)) in self.columns.iter().zip_eq(&self.rows).enumerate() {
88            let old_len = self.data.len();
89            self.data.extend_from_slice(matrix.row(r).limbs());
90            self.data.extend_from_within(..old_len);
91            for i in 1 << n..2 * (1 << n) - 1 {
92                simd::add_simd(
93                    &mut self.data[i * num_limbs..(i + 1) * num_limbs],
94                    matrix.row(r).limbs(),
95                    c.limb,
96                );
97            }
98            self.min_limb = std::cmp::min(self.min_limb, c.limb);
99        }
100    }
101
102    pub fn reduce_naive(&self, matrix: &mut Matrix, target: usize) {
103        for (&row, col) in self.rows.iter().zip_eq(&self.columns) {
104            assert!(target != row);
105            unsafe {
106                let coef = (matrix.row(target).limbs()[col.limb] >> col.bit_index) & 1;
107                if coef != 0 {
108                    let (mut target, source) = matrix.split_borrow(target, row);
109                    simd::add_simd(target.limbs_mut(), source.limbs(), col.limb)
110                }
111            }
112        }
113    }
114
115    pub fn reduce(&self, v: &mut [Limb]) {
116        let num_limbs = v.len();
117        let mut index: usize = 0;
118        for &col in self.columns.iter().rev() {
119            index <<= 1;
120            index += ((v[col.limb] >> col.bit_index) & 1) as usize;
121        }
122        if index != 0 {
123            simd::add_simd(v, &self.data[(index - 1) * num_limbs..], self.min_limb);
124        }
125    }
126}