fp/blas/tile/
tiles.rs

1use std::num::NonZeroUsize;
2
3use super::block::{MatrixBlockSlice, MatrixBlockSliceMut};
4use crate::{limb::Limb, matrix::Matrix};
5
/// An immutable view of a tile within a matrix.
///
/// A tile is a rectangular region composed of multiple 64 x 64 blocks. Tiles enable hierarchical
/// parallelization: large matrices are divided into tiles which are processed in parallel, and each
/// tile is further divided into blocks for vectorization.
///
/// Within a tile, a block row spans 64 consecutive matrix rows and a block column is one limb
/// wide, so block `(r, c)` starts `64 * r * stride + c` limbs after the tile's base pointer.
///
/// # Safety
///
/// The `limbs` pointer must remain valid for the lifetime `'a`, and must point to a region large
/// enough for `dimensions[0] * 64` rows and `dimensions[1]` blocks with the given stride.
#[derive(Debug, Clone, Copy)]
pub struct MatrixTileSlice<'a> {
    /// Base pointer of the tile: the first limb of its top-left block. For a tile with zero
    /// block rows or zero block columns, `block_at` panics before this pointer is used, so it
    /// need not be dereferenceable (e.g. it is the data pointer of an empty matrix).
    limbs: *const Limb,
    /// Dimensions of the tile in units of 64 x 64 blocks: [block_rows, block_cols]
    dimensions: [usize; 2],
    /// Number of limbs between consecutive rows in the parent matrix
    stride: NonZeroUsize,
    /// Ties the view to the lifetime `'a` of the borrowed matrix data.
    _marker: std::marker::PhantomData<&'a ()>,
}
25
/// A mutable view of a tile within a matrix.
///
/// Uses the same block layout as the immutable tile view: block `(r, c)` starts
/// `64 * r * stride + c` limbs after the tile's base pointer.
///
/// # Safety
///
/// The `limbs` pointer must remain valid and exclusively accessible for the lifetime `'a`, and must
/// point to a region large enough for `dimensions[0] * 64` rows and `dimensions[1]` blocks with the
/// given stride.
///
/// NOTE(review): this type derives `Clone` and `Copy`, so safe code can duplicate a mutable view
/// and end up with two handles that write the same limbs (and, because the type is `Send`, hand
/// them to different threads). The exclusivity described above is therefore only upheld if
/// callers never copy a view and use both copies — confirm this is intended.
#[derive(Debug, Clone, Copy)]
pub struct MatrixTileSliceMut<'a> {
    /// Base pointer of the tile: the first limb of its top-left block. For a tile with zero
    /// block rows or zero block columns, `block_mut_at` panics before this pointer is used, so it
    /// need not be dereferenceable.
    limbs: *mut Limb,
    /// Dimensions of the tile in units of 64 x 64 blocks: [block_rows, block_cols]
    dimensions: [usize; 2],
    /// Number of limbs between consecutive rows in the parent matrix
    stride: NonZeroUsize,
    /// Ties the view to the lifetime `'a` of the mutably borrowed matrix data.
    _marker: std::marker::PhantomData<&'a ()>,
}
42
43impl<'a> MatrixTileSlice<'a> {
44    pub fn from_matrix(m: &'a Matrix) -> Self {
45        let stride = m.stride().try_into().expect("Can't tile empty matrix");
46        Self {
47            limbs: m.data().as_ptr(),
48            dimensions: [m.physical_rows() / 64, m.columns().div_ceil(64)],
49            stride,
50            _marker: std::marker::PhantomData,
51        }
52    }
53
54    /// Returns the number of 64 x 64 block rows in this tile.
55    #[inline]
56    pub fn block_rows(&self) -> usize {
57        self.dimensions[0]
58    }
59
60    /// Returns the number of 64 x 64 block columns in this tile.
61    #[inline]
62    pub fn block_columns(&self) -> usize {
63        self.dimensions[1]
64    }
65
66    /// Returns a view of the block at the given block coordinates.
67    ///
68    /// # Panics
69    ///
70    /// Panics if the block coordinates are out of bounds.
71    #[inline]
72    pub fn block_at(&self, block_row: usize, block_col: usize) -> MatrixBlockSlice<'_> {
73        assert!(
74            block_row < self.dimensions[0],
75            "block_row {block_row} out of bounds (max {})",
76            self.dimensions[0]
77        );
78        assert!(
79            block_col < self.dimensions[1],
80            "block_col {block_col} out of bounds (max {})",
81            self.dimensions[1]
82        );
83
84        let start_limb = 64 * block_row * self.stride.get() + block_col;
85        let stride = self.stride;
86
87        MatrixBlockSlice::new(
88            unsafe {
89                // SAFETY: block coordinates are in bounds, and the parent tile guarantees
90                // sufficient memory is allocated
91                self.limbs.add(start_limb)
92            },
93            stride,
94        )
95    }
96
97    pub fn split_rows_at(&self, block_rows: usize) -> (MatrixTileSlice<'_>, MatrixTileSlice<'_>) {
98        assert!(
99            block_rows <= self.block_rows(),
100            "split point {block_rows} exceeds block_rows {}",
101            self.block_rows()
102        );
103        let (first_rows, second_rows) = (block_rows, self.block_rows() - block_rows);
104
105        let first = MatrixTileSlice {
106            limbs: self.limbs,
107            dimensions: [first_rows, self.dimensions[1]],
108            stride: self.stride,
109            _marker: std::marker::PhantomData,
110        };
111        let second = MatrixTileSlice {
112            limbs: unsafe { self.limbs.add(64 * first_rows * self.stride.get()) },
113            dimensions: [second_rows, self.dimensions[1]],
114            stride: self.stride,
115            _marker: std::marker::PhantomData,
116        };
117        (first, second)
118    }
119
120    pub fn split_columns_at(
121        &self,
122        block_columns: usize,
123    ) -> (MatrixTileSlice<'_>, MatrixTileSlice<'_>) {
124        assert!(
125            block_columns <= self.block_columns(),
126            "split point {block_columns} exceeds block_columns {}",
127            self.block_columns()
128        );
129        let (first_cols, second_cols) = (block_columns, self.block_columns() - block_columns);
130
131        let first = MatrixTileSlice {
132            limbs: self.limbs,
133            dimensions: [self.dimensions[0], first_cols],
134            stride: self.stride,
135            _marker: std::marker::PhantomData,
136        };
137        let second = MatrixTileSlice {
138            limbs: unsafe { self.limbs.add(first_cols) },
139            dimensions: [self.dimensions[0], second_cols],
140            stride: self.stride,
141            _marker: std::marker::PhantomData,
142        };
143        (first, second)
144    }
145}
146
147impl<'a> MatrixTileSliceMut<'a> {
148    pub fn from_matrix(m: &'a mut Matrix) -> Self {
149        let stride = m.stride().try_into().expect("Can't tile empty matrix");
150        Self {
151            limbs: m.data_mut().as_mut_ptr(),
152            dimensions: [m.physical_rows() / 64, m.columns().div_ceil(64)],
153            stride,
154            _marker: std::marker::PhantomData,
155        }
156    }
157
158    pub fn block_rows(&self) -> usize {
159        self.dimensions[0]
160    }
161
162    pub fn block_columns(&self) -> usize {
163        self.dimensions[1]
164    }
165
166    pub fn block_mut_at(&mut self, block_row: usize, block_col: usize) -> MatrixBlockSliceMut<'_> {
167        assert!(
168            block_row < self.dimensions[0],
169            "block_row {block_row} out of bounds (max {})",
170            self.dimensions[0]
171        );
172        assert!(
173            block_col < self.dimensions[1],
174            "block_col {block_col} out of bounds (max {})",
175            self.dimensions[1]
176        );
177
178        let start_limb = 64 * block_row * self.stride.get() + block_col;
179        let stride = self.stride;
180
181        MatrixBlockSliceMut::new(
182            unsafe {
183                // SAFETY: block coordinates are in bounds, and the parent tile guarantees
184                // sufficient memory is allocated
185                self.limbs.add(start_limb)
186            },
187            stride,
188        )
189    }
190
191    pub fn split_rows_at_mut(
192        &mut self,
193        block_rows: usize,
194    ) -> (MatrixTileSliceMut<'_>, MatrixTileSliceMut<'_>) {
195        assert!(
196            block_rows <= self.block_rows(),
197            "split point {block_rows} exceeds block_rows {}",
198            self.block_rows()
199        );
200        let (first_rows, second_rows) = (block_rows, self.block_rows() - block_rows);
201
202        let first = MatrixTileSliceMut {
203            limbs: self.limbs,
204            dimensions: [first_rows, self.dimensions[1]],
205            stride: self.stride,
206            _marker: std::marker::PhantomData,
207        };
208        let second = MatrixTileSliceMut {
209            limbs: unsafe { self.limbs.add(64 * first_rows * self.stride.get()) },
210            dimensions: [second_rows, self.dimensions[1]],
211            stride: self.stride,
212            _marker: std::marker::PhantomData,
213        };
214        (first, second)
215    }
216
217    pub fn split_columns_at_mut(
218        &mut self,
219        block_columns: usize,
220    ) -> (MatrixTileSliceMut<'_>, MatrixTileSliceMut<'_>) {
221        assert!(
222            block_columns <= self.block_columns(),
223            "split point {block_columns} exceeds block_columns {}",
224            self.block_columns()
225        );
226        let (first_cols, second_cols) = (block_columns, self.block_columns() - block_columns);
227
228        let first = MatrixTileSliceMut {
229            limbs: self.limbs,
230            dimensions: [self.dimensions[0], first_cols],
231            stride: self.stride,
232            _marker: std::marker::PhantomData,
233        };
234        let second = MatrixTileSliceMut {
235            limbs: unsafe { self.limbs.add(first_cols) },
236            dimensions: [self.dimensions[0], second_cols],
237            stride: self.stride,
238            _marker: std::marker::PhantomData,
239        };
240        (first, second)
241    }
242
243    pub fn zero_out(&mut self) {
244        for block_row in 0..self.block_rows() {
245            for block_col in 0..self.block_columns() {
246                self.block_mut_at(block_row, block_col).zero_out();
247            }
248        }
249    }
250}
251
// SAFETY: `MatrixTileSlice` only reads the limbs it points to, which are borrowed for `'a`, so it
// behaves like a shared slice `&'a [Limb]` and may be sent to / shared between threads under the
// same conditions (assumes `Limb` is a plain `Sync` integer type — TODO confirm).

unsafe impl Send for MatrixTileSlice<'_> {}
unsafe impl Sync for MatrixTileSlice<'_> {}

// SAFETY: `MatrixTileSliceMut` is intended to behave like `&'a mut [Limb]`: sending it to another
// thread transfers exclusive access to its limbs.
// NOTE(review): the type derives `Clone`/`Copy`, so safe code can duplicate a mutable view and
// send both copies to different threads, which would let them write the same limbs concurrently.
// Soundness of this impl therefore relies on callers never doing that — confirm.
unsafe impl Send for MatrixTileSliceMut<'_> {}