1use tile::{LoopOrder, orders::*};
28
29use crate::matrix::Matrix;
30
31pub mod block;
32pub mod tile;
33
34impl std::ops::Mul for &Matrix {
35 type Output = Matrix;
36
37 fn mul(self, rhs: Self) -> Matrix {
38 assert_eq!(self.prime(), rhs.prime());
39 assert_eq!(self.columns(), rhs.rows());
40
41 if self.prime() == 2
42 && self.physical_rows().is_multiple_of(64)
43 && rhs.physical_rows().is_multiple_of(64)
44 {
45 self.fast_mul_concurrent(rhs)
48 } else {
49 self.naive_mul(rhs)
53 }
54 }
55}
56
57impl Matrix {
58 pub fn naive_mul(&self, rhs: &Self) -> Self {
59 assert_eq!(self.prime(), rhs.prime());
60 assert_eq!(self.columns(), rhs.rows());
61
62 let mut result = Self::new(self.prime(), self.rows(), rhs.columns());
63 for i in 0..self.rows() {
64 for j in 0..rhs.columns() {
65 for k in 0..self.columns() {
66 result
67 .row_mut(i)
68 .add_basis_element(j, self.row(i).entry(k) * rhs.row(k).entry(j));
69 }
70 }
71 }
72 result
73 }
74
75 pub fn fast_mul_sequential(&self, other: &Self) -> Self {
76 self.fast_mul_sequential_order::<RCI>(other)
78 }
79
80 pub fn fast_mul_sequential_order<L: LoopOrder>(&self, other: &Self) -> Self {
81 assert_eq!(self.prime(), 2);
82 assert_eq!(self.prime(), other.prime());
83 assert_eq!(self.columns(), other.rows());
84
85 let mut result = Self::new(self.prime(), self.rows(), other.columns());
86 tile::gemm::<L>(
87 true,
88 self.as_tile(),
89 other.as_tile(),
90 true,
91 result.as_tile_mut(),
92 );
93
94 result
95 }
96
97 pub fn fast_mul_concurrent(&self, other: &Self) -> Self {
98 self.fast_mul_concurrent_blocksize::<1, 16>(other)
101 }
102
103 pub fn fast_mul_concurrent_blocksize<const M: usize, const N: usize>(
104 &self,
105 other: &Self,
106 ) -> Self {
107 self.fast_mul_concurrent_blocksize_order::<M, N, RCI>(other)
109 }
110
111 pub fn fast_mul_concurrent_blocksize_order<const M: usize, const N: usize, L: LoopOrder>(
112 &self,
113 other: &Self,
114 ) -> Self {
115 assert_eq!(self.prime(), 2);
116 assert_eq!(self.prime(), other.prime());
117 assert_eq!(self.columns(), other.rows());
118
119 let mut result = Self::new(self.prime(), self.rows(), other.columns());
120 tile::gemm_concurrent::<M, N, L>(
121 true,
122 self.as_tile(),
123 other.as_tile(),
124 true,
125 result.as_tile_mut(),
126 );
127
128 result
129 }
130}
131
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;
    use crate::{matrix::arbitrary::MatrixArbParams, prime::TWO};

    /// Matrix dimensions exercised by the property tests.
    ///
    /// Several values sit just below, at, and just above multiples of 64. The list must stay
    /// sorted in ascending order, because `arb_multipliable_matrices` truncates it by position.
    const DIMS: [usize; 11] = [32, 63, 64, 65, 128, 129, 192, 193, 256, 320, 449];

    /// Generates pairs `(m, n)` of matrices over F_2 whose product `m * n` is defined.
    ///
    /// Every dimension is drawn from `DIMS`. When `max` is `Some(bound)`, only dimensions
    /// `<= bound` are used.
    ///
    /// # Panics
    ///
    /// Panics if `max` excludes every entry of `DIMS`. Otherwise the empty slice would be handed
    /// to `proptest::sample::select`.
    fn arb_multipliable_matrices(max: Option<usize>) -> impl Strategy<Value = (Matrix, Matrix)> {
        let max_idx = max
            .and_then(|max| DIMS.iter().position(|&size| size > max))
            .unwrap_or(DIMS.len());
        assert!(max_idx > 0, "no entry of DIMS is at most {max:?}");

        let arb_dim = proptest::sample::select(&DIMS[0..max_idx]);
        // First pick the shared inner dimension, then pick the outer dimensions independently.
        arb_dim.clone().prop_flat_map(move |size| {
            (
                Matrix::arbitrary_with(MatrixArbParams {
                    p: Some(TWO),
                    rows: arb_dim.clone().boxed(),
                    columns: Just(size).boxed(),
                }),
                Matrix::arbitrary_with(MatrixArbParams {
                    p: Some(TWO),
                    rows: Just(size).boxed(),
                    columns: arb_dim.clone().boxed(),
                }),
            )
        })
    }

    /// Generates a proptest comparing `fast_mul_concurrent_blocksize_order` with
    /// `fast_mul_sequential`.
    ///
    /// It expands over every combination of:
    /// - block parameter `M` in {1, 2, 4};
    /// - block parameter `N` in {1, 2, 4};
    /// - all six loop orders.
    macro_rules! test_fast_mul {
        () => {
            test_fast_mul!(1);
            test_fast_mul!(2);
            test_fast_mul!(4);
        };
        ($m:literal) => {
            test_fast_mul!($m, 1);
            test_fast_mul!($m, 2);
            test_fast_mul!($m, 4);
        };
        ($m:literal, $n:literal) => {
            test_fast_mul!($m, $n, CIR);
            test_fast_mul!($m, $n, CRI);
            test_fast_mul!($m, $n, ICR);
            test_fast_mul!($m, $n, IRC);
            test_fast_mul!($m, $n, RCI);
            test_fast_mul!($m, $n, RIC);
        };
        ($m:literal, $n:literal, $loop_order:ty) => {
            paste::paste! {
                proptest! {
                    #[test]
                    fn [<test_fast_mul_concurrent_ $m _ $n _ $loop_order:lower _ is_mul>]((
                        m, n
                    ) in arb_multipliable_matrices(None)) {
                        let prod1 = m.fast_mul_sequential(&n);
                        let prod2 = m.fast_mul_concurrent_blocksize_order::<$m, $n, $loop_order>(&n);
                        prop_assert_eq!(prod1, prod2);
                    }
                }
            }
        };
    }

    test_fast_mul!();

    proptest! {
        // `naive_mul` is slow, so these tests are restricted to small dimensions.
        #[test]
        fn test_fast_mul_sequential_is_mul((m, n) in arb_multipliable_matrices(Some(64))) {
            let prod1 = m.naive_mul(&n);
            let prod2 = m.fast_mul_sequential(&n);
            prop_assert_eq!(prod1, prod2);
        }

        // The `*` operator chooses between the fast and naive paths. Either choice must agree
        // with the reference implementation.
        #[test]
        fn test_mul_operator_is_mul((m, n) in arb_multipliable_matrices(Some(64))) {
            let prod1 = m.naive_mul(&n);
            let prod2 = &m * &n;
            prop_assert_eq!(prod1, prod2);
        }
    }
}