1pub use tiles::{MatrixTileSlice, MatrixTileSliceMut};
2
3use super::block;
4use crate::matrix::Matrix;
5
6pub mod tiles;
7
8impl Matrix {
9 pub fn as_tile(&self) -> MatrixTileSlice<'_> {
10 MatrixTileSlice::from_matrix(self)
11 }
12
13 pub fn as_tile_mut(&mut self) -> MatrixTileSliceMut<'_> {
14 MatrixTileSliceMut::from_matrix(self)
15 }
16}
17
/// Loop order rows → columns → inner: block row `i` outermost, then block
/// column `j`, then the shared inner index `k` innermost.
///
/// Each `c(i, j)` block is gathered once, updated for every `k`, and written
/// back once.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RCI;

/// Loop order columns → rows → inner: block column `j` outermost, then block
/// row `i`, then the shared inner index `k` innermost.
///
/// Each `c(i, j)` block is gathered once, updated for every `k`, and written
/// back once.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CRI;

/// Loop order inner → columns → rows: shared inner index `k` outermost, then
/// block column `j`, then block row `i` innermost.
///
/// The `b(k, j)` block is gathered once per `(k, j)` and reused across rows;
/// `c(i, j)` is gathered and written back on every step.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ICR;

/// Loop order rows → inner → columns: block row `i` outermost, then the shared
/// inner index `k`, then block column `j` innermost.
///
/// The `a(i, k)` block is gathered once per `(i, k)` and reused across columns;
/// `c(i, j)` is gathered and written back on every step.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RIC;

/// Loop order inner → rows → columns: shared inner index `k` outermost, then
/// block row `i`, then block column `j` innermost.
///
/// The `a(i, k)` block is gathered once per `(k, i)` and reused across columns;
/// `c(i, j)` is gathered and written back on every step.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct IRC;

/// Loop order columns → inner → rows: block column `j` outermost, then the
/// shared inner index `k`, then block row `i` innermost.
///
/// The `b(k, j)` block is gathered once per `(j, k)` and reused across rows;
/// `c(i, j)` is gathered and written back on every step.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CIR;
31
32pub mod orders {
34 pub use super::{CIR, CRI, ICR, IRC, RCI, RIC};
35}
36
37#[inline]
47pub fn gemm<L: LoopOrder>(
48 alpha: bool,
49 a: MatrixTileSlice,
50 b: MatrixTileSlice,
51 beta: bool,
52 mut c: MatrixTileSliceMut,
53) {
54 if !beta {
55 c.zero_out();
56 }
57
58 if !alpha {
59 return;
60 }
61
62 assert_eq!(a.block_columns(), b.block_rows());
63
64 L::gemm(a, b, c);
65}
66
67#[inline]
83pub fn gemm_concurrent<const M: usize, const N: usize, L: LoopOrder>(
84 alpha: bool,
85 a: MatrixTileSlice,
86 b: MatrixTileSlice,
87 beta: bool,
88 mut c: MatrixTileSliceMut,
89) {
90 if c.block_rows() > M {
91 let (a_first, a_second) = a.split_rows_at(a.block_rows() / 2);
92 let (c_first, c_second) = c.split_rows_at_mut(c.block_rows() / 2);
93 maybe_rayon::join(
94 move || gemm_concurrent::<M, N, L>(alpha, a_first, b, beta, c_first),
95 move || gemm_concurrent::<M, N, L>(alpha, a_second, b, beta, c_second),
96 );
97 } else if c.block_columns() > N {
98 let (b_first, b_second) = b.split_columns_at(b.block_columns() / 2);
99 let (c_first, c_second) = c.split_columns_at_mut(c.block_columns() / 2);
100 maybe_rayon::join(
101 move || gemm_concurrent::<M, N, L>(alpha, a, b_first, beta, c_first),
102 move || gemm_concurrent::<M, N, L>(alpha, a, b_second, beta, c_second),
103 );
104 } else {
105 gemm::<L>(alpha, a, b, beta, c);
106 }
107}
108
109pub trait LoopOrder {
115 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, c: MatrixTileSliceMut);
117}
118
119impl LoopOrder for RCI {
120 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
121 for i in 0..a.block_rows() {
122 for j in 0..b.block_columns() {
123 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
124 for k in 0..b.block_rows() {
125 let a_block = a.block_at(i, k).gather();
126 let b_block = b.block_at(k, j).gather();
127 block::gemm_block(a_block, b_block, &mut c_block);
128 }
129 c.block_mut_at(i, j).assign(c_block);
130 }
131 }
132 }
133}
134
135impl LoopOrder for CRI {
136 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
137 for j in 0..b.block_columns() {
138 for i in 0..a.block_rows() {
139 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
140 for k in 0..b.block_rows() {
141 let a_block = a.block_at(i, k).gather();
142 let b_block = b.block_at(k, j).gather();
143 block::gemm_block(a_block, b_block, &mut c_block);
144 }
145 c.block_mut_at(i, j).assign(c_block);
146 }
147 }
148 }
149}
150
151impl LoopOrder for ICR {
152 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
153 for k in 0..b.block_rows() {
154 for j in 0..b.block_columns() {
155 let b_block = b.block_at(k, j).gather();
156 for i in 0..a.block_rows() {
157 let a_block = a.block_at(i, k).gather();
158 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
159 block::gemm_block(a_block, b_block, &mut c_block);
160 c.block_mut_at(i, j).assign(c_block);
161 }
162 }
163 }
164 }
165}
166
167impl LoopOrder for RIC {
168 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
169 for i in 0..a.block_rows() {
170 for k in 0..a.block_columns() {
171 let a_block = a.block_at(i, k).gather();
172 for j in 0..b.block_columns() {
173 let b_block = b.block_at(k, j).gather();
174 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
175 block::gemm_block(a_block, b_block, &mut c_block);
176 c.block_mut_at(i, j).assign(c_block);
177 }
178 }
179 }
180 }
181}
182
183impl LoopOrder for IRC {
184 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
185 for k in 0..b.block_rows() {
186 for i in 0..a.block_rows() {
187 let a_block = a.block_at(i, k).gather();
188 for j in 0..b.block_columns() {
189 let b_block = b.block_at(k, j).gather();
190 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
191 block::gemm_block(a_block, b_block, &mut c_block);
192 c.block_mut_at(i, j).assign(c_block);
193 }
194 }
195 }
196 }
197}
198
199impl LoopOrder for CIR {
200 fn gemm(a: MatrixTileSlice, b: MatrixTileSlice, mut c: MatrixTileSliceMut) {
201 for j in 0..b.block_columns() {
202 for k in 0..b.block_rows() {
203 let b_block = b.block_at(k, j).gather();
204 for i in 0..a.block_rows() {
205 let a_block = a.block_at(i, k).gather();
206 let mut c_block = c.block_mut_at(i, j).as_slice().gather();
207 block::gemm_block(a_block, b_block, &mut c_block);
208 c.block_mut_at(i, j).assign(c_block);
209 }
210 }
211 }
212 }
213}