fp/matrix/m4ri.rs
1use itertools::Itertools;
2
3use crate::{
4 field::{field_internal::FieldInternal, fp::F2},
5 limb::{Limb, LimbBitIndexPair},
6 matrix::Matrix,
7 simd,
8};
9
10#[derive(Debug, Default)]
11/// M4RI works as follows --- first row reduce k rows using the naive algorithm. We then construct
12/// a table of all 2^k linear combinations of these rows. This can be done in O(2^k) time. We then
13/// use this table to reduce the remaining rows, so that each row takes O(`num_columns`) time,
14/// which reduces the time taken by a factor of k x density.
15///
16/// Since we are likely to run into empty rows when doing row reduction, what we do in practice is
17/// that we keep reducing rows until we collect k of them. Whenever we find a row, we record it
18/// with [`M4riTable::add`]. We only record the row number and pivot column, as the values of these
19/// rows will change as we go on due to the desire to land in a RREF.
20///
21/// Once we have recorded enough rows, we generate the table using [`M4riTable::generate`], and
22/// then reduce limbs using [`M4riTable::reduce`]. When we are done we clear it using
23/// [`M4riTable::clear`] and proceed to the next k rows.
24pub(crate) struct M4riTable {
25 /// The indices of new rows in the table
26 rows: Vec<usize>,
27 /// The list of pivot columns of the rows
28 columns: Vec<LimbBitIndexPair>,
29 /// The 2^k linear combinations of the k rows, apart from the first one which is identically
30 /// zero.
31 data: Vec<Limb>,
32 /// The smallest non-zero limb in this table. We use this when row reducing to save a few
33 /// operations.
34 min_limb: usize,
35}
36
37impl M4riTable {
38 /// Create a table with space for `k` vectors, each with `cols` columns.
39 pub fn new(k: usize, cols: usize) -> Self {
40 let num_limbs = F2.number(cols);
41 Self {
42 rows: Vec::with_capacity(k),
43 columns: Vec::with_capacity(k),
44 min_limb: 0,
45 // There are 2^k rows but the first one is zero which we omit
46 data: Vec::with_capacity(((1 << k) - 1) * num_limbs),
47 }
48 }
49
50 /// Number of rows in the M4riTable
51 pub fn len(&self) -> usize {
52 self.columns.len()
53 }
54
55 /// Whether the table has no rows
56 pub fn is_empty(&self) -> bool {
57 self.columns.is_empty()
58 }
59
60 /// Get the list of pivot rows
61 pub fn rows(&self) -> &[usize] {
62 &self.rows
63 }
64
65 /// Add a row to the table.
66 ///
67 /// # Arguments
68 /// - `column`: pivot column of the row
69 /// - `row`: index of the row
70 pub fn add(&mut self, column: usize, row: usize) {
71 self.columns.push(F2.limb_bit_index_pair(column));
72 self.rows.push(row);
73 }
74
75 /// Clear the contents of the table
76 pub fn clear(&mut self) {
77 self.columns.clear();
78 self.rows.clear();
79 self.data.clear();
80 }
81
82 /// Generates the table from the known data
83 /// `num` is the number of the vector being added.
84 pub fn generate(&mut self, matrix: &Matrix) {
85 let num_limbs = matrix.row(0).limbs().len();
86 self.min_limb = usize::MAX;
87 for (n, (c, &r)) in self.columns.iter().zip_eq(&self.rows).enumerate() {
88 let old_len = self.data.len();
89 self.data.extend_from_slice(matrix.row(r).limbs());
90 self.data.extend_from_within(..old_len);
91 for i in 1 << n..2 * (1 << n) - 1 {
92 simd::add_simd(
93 &mut self.data[i * num_limbs..(i + 1) * num_limbs],
94 matrix.row(r).limbs(),
95 c.limb,
96 );
97 }
98 self.min_limb = std::cmp::min(self.min_limb, c.limb);
99 }
100 }
101
102 pub fn reduce_naive(&self, matrix: &mut Matrix, target: usize) {
103 for (&row, col) in self.rows.iter().zip_eq(&self.columns) {
104 assert!(target != row);
105 unsafe {
106 let coef = (matrix.row(target).limbs()[col.limb] >> col.bit_index) & 1;
107 if coef != 0 {
108 let (mut target, source) = matrix.split_borrow(target, row);
109 simd::add_simd(target.limbs_mut(), source.limbs(), col.limb)
110 }
111 }
112 }
113 }
114
115 pub fn reduce(&self, v: &mut [Limb]) {
116 let num_limbs = v.len();
117 let mut index: usize = 0;
118 for &col in self.columns.iter().rev() {
119 index <<= 1;
120 index += ((v[col.limb] >> col.bit_index) & 1) as usize;
121 }
122 if index != 0 {
123 simd::add_simd(v, &self.data[(index - 1) * num_limbs..], self.min_limb);
124 }
125 }
126}