// fp/blas/tile/mod.rs

1pub use tiles::{MatrixTileSlice, MatrixTileSliceMut};
2
3use super::block;
4use crate::matrix::Matrix;
5
6pub mod tiles;
7
8impl Matrix {
9    pub fn as_tile(&self) -> MatrixTileSlice<'_> {
10        MatrixTileSlice::from_matrix(self)
11    }
12
13    pub fn as_tile_mut(&mut self) -> MatrixTileSliceMut<'_> {
14        MatrixTileSliceMut::from_matrix(self)
15    }
16}
17
// Zero-cost loop ordering marker types.
//
// Each marker is a unit struct selecting one permutation of the GEMM triple
// loop (R = output row `i`, C = output column `j`, I = inner index `k`).
// The common traits are derived so the markers can be stored, copied,
// compared, and debug-printed by generic code; as zero-sized types this
// costs nothing at runtime.
/// Row-Column-Inner loop order: `for i { for j { for k { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RCI;
/// Column-Row-Inner loop order: `for j { for i { for k { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CRI;
/// Inner-Column-Row loop order: `for k { for j { for i { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ICR;
/// Row-Inner-Column loop order: `for i { for k { for j { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RIC;
/// Inner-Row-Column loop order: `for k { for i { for j { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct IRC;
/// Column-Inner-Row loop order: `for j { for k { for i { ... } } }`
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CIR;
31
32/// Re-exports of loop ordering types for convenience.
33pub mod orders {
34    pub use super::{CIR, CRI, ICR, IRC, RCI, RIC};
35}
36
37/// Performs tile-level GEMM with a specified loop ordering.
38///
39/// This is the sequential (non-parallel) version. For large matrices, use [`gemm_concurrent`]
40/// instead.
41///
42/// # Loop Ordering
43///
44/// The choice of loop order affects cache locality and performance. Benchmarking suggests RIC is
45/// optimal for most cases, but this depends on matrix dimensions.
46#[inline]
47pub fn gemm<L: LoopOrder>(
48    alpha: bool,
49    a: MatrixTileSlice,
50    b: MatrixTileSlice,
51    beta: bool,
52    mut c: MatrixTileSliceMut,
53) {
54    if !beta {
55        c.zero_out();
56    }
57
58    if !alpha {
59        return;
60    }
61
62    assert_eq!(a.block_columns(), b.block_rows());
63
64    L::gemm(a, b, c);
65}
66
67/// Performs tile-level GEMM with recursive parallelization.
68///
69/// The matrix is recursively split along rows (if rows > M blocks) or columns (if cols > N blocks)
70/// until tiles are small enough, then all tiles are processed in parallel using rayon.
71///
72/// # Type Parameters
73///
74/// * `M` - Minimum block rows before parallelization stops
75/// * `N` - Minimum block columns before parallelization stops
76/// * `L` - Loop ordering strategy (see [`LoopOrder`])
77///
78/// # Performance
79///
80/// For best performance, choose M and N based on your matrix sizes. The defaults used in the
81/// codebase are M=1, N=16, which work well for many workloads.
82#[inline]
83pub fn gemm_concurrent<const M: usize, const N: usize, L: LoopOrder>(
84    alpha: bool,
85    a: MatrixTileSlice,
86    b: MatrixTileSlice,
87    beta: bool,
88    mut c: MatrixTileSliceMut,
89) {
90    if c.block_rows() > M {
91        let (a_first, a_second) = a.split_rows_at(a.block_rows() / 2);
92        let (c_first, c_second) = c.split_rows_at_mut(c.block_rows() / 2);
93        maybe_rayon::join(
94            move || gemm_concurrent::<M, N, L>(alpha, a_first, b, beta, c_first),
95            move || gemm_concurrent::<M, N, L>(alpha, a_second, b, beta, c_second),
96        );
97    } else if c.block_columns() > N {
98        let (b_first, b_second) = b.split_columns_at(b.block_columns() / 2);
99        let (c_first, c_second) = c.split_columns_at_mut(c.block_columns() / 2);
100        maybe_rayon::join(
101            move || gemm_concurrent::<M, N, L>(alpha, a, b_first, beta, c_first),
102            move || gemm_concurrent::<M, N, L>(alpha, a, b_second, beta, c_second),
103        );
104    } else {
105        gemm::<L>(alpha, a, b, beta, c);
106    }
107}
108
/// Trait for zero-cost loop ordering strategies.
///
/// Different loop orders have different cache access patterns, which can significantly impact
/// performance. All six permutations are provided: RCI, CRI, ICR, RIC, IRC, CIR (where R=row,
/// C=column, I=inner).
pub trait LoopOrder {
    /// Performs GEMM with this loop ordering strategy.
    ///
    /// Accumulates `a * b` into `c` block by block. Callers are expected to
    /// have checked conformability beforehand (see [`gemm`], which asserts
    /// the inner dimension before dispatching here); implementations index
    /// blocks directly without re-validating.
    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, c: MatrixTileSliceMut);
}
118
119impl LoopOrder for RCI {
120    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
121        for i in 0..a.block_rows() {
122            for j in 0..b.block_columns() {
123                let mut c_block = c.block_mut_at(i, j).as_slice().gather();
124                for k in 0..b.block_rows() {
125                    let a_block = a.block_at(i, k).gather();
126                    let b_block = b.block_at(k, j).gather();
127                    block::gemm_block(a_block, b_block, &mut c_block);
128                }
129                c.block_mut_at(i, j).assign(c_block);
130            }
131        }
132    }
133}
134
135impl LoopOrder for CRI {
136    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
137        for j in 0..b.block_columns() {
138            for i in 0..a.block_rows() {
139                let mut c_block = c.block_mut_at(i, j).as_slice().gather();
140                for k in 0..b.block_rows() {
141                    let a_block = a.block_at(i, k).gather();
142                    let b_block = b.block_at(k, j).gather();
143                    block::gemm_block(a_block, b_block, &mut c_block);
144                }
145                c.block_mut_at(i, j).assign(c_block);
146            }
147        }
148    }
149}
150
151impl LoopOrder for ICR {
152    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
153        for k in 0..b.block_rows() {
154            for j in 0..b.block_columns() {
155                let b_block = b.block_at(k, j).gather();
156                for i in 0..a.block_rows() {
157                    let a_block = a.block_at(i, k).gather();
158                    let mut c_block = c.block_mut_at(i, j).as_slice().gather();
159                    block::gemm_block(a_block, b_block, &mut c_block);
160                    c.block_mut_at(i, j).assign(c_block);
161                }
162            }
163        }
164    }
165}
166
167impl LoopOrder for RIC {
168    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
169        for i in 0..a.block_rows() {
170            for k in 0..a.block_columns() {
171                let a_block = a.block_at(i, k).gather();
172                for j in 0..b.block_columns() {
173                    let b_block = b.block_at(k, j).gather();
174                    let mut c_block = c.block_mut_at(i, j).as_slice().gather();
175                    block::gemm_block(a_block, b_block, &mut c_block);
176                    c.block_mut_at(i, j).assign(c_block);
177                }
178            }
179        }
180    }
181}
182
183impl LoopOrder for IRC {
184    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
185        for k in 0..b.block_rows() {
186            for i in 0..a.block_rows() {
187                let a_block = a.block_at(i, k).gather();
188                for j in 0..b.block_columns() {
189                    let b_block = b.block_at(k, j).gather();
190                    let mut c_block = c.block_mut_at(i, j).as_slice().gather();
191                    block::gemm_block(a_block, b_block, &mut c_block);
192                    c.block_mut_at(i, j).assign(c_block);
193                }
194            }
195        }
196    }
197}
198
199impl LoopOrder for CIR {
200    fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
201        for j in 0..b.block_columns() {
202            for k in 0..b.block_rows() {
203                let b_block = b.block_at(k, j).gather();
204                for i in 0..a.block_rows() {
205                    let a_block = a.block_at(i, k).gather();
206                    let mut c_block = c.block_mut_at(i, j).as_slice().gather();
207                    block::gemm_block(a_block, b_block, &mut c_block);
208                    c.block_mut_at(i, j).assign(c_block);
209                }
210            }
211        }
212    }
213}