1use std::num::NonZeroUsize;
2
3use super::block::{MatrixBlockSlice, MatrixBlockSliceMut};
4use crate::{limb::Limb, matrix::Matrix};
5
6#[derive(Debug, Clone, Copy)]
17pub struct MatrixTileSlice<'a> {
18 limbs: *const Limb,
19 dimensions: [usize; 2],
21 stride: NonZeroUsize,
23 _marker: std::marker::PhantomData<&'a ()>,
24}
25
26#[derive(Debug, Clone, Copy)]
34pub struct MatrixTileSliceMut<'a> {
35 limbs: *mut Limb,
36 dimensions: [usize; 2],
38 stride: NonZeroUsize,
40 _marker: std::marker::PhantomData<&'a ()>,
41}
42
43impl<'a> MatrixTileSlice<'a> {
44 pub fn from_matrix(m: &'a Matrix) -> Self {
45 let stride = m.stride().try_into().expect("Can't tile empty matrix");
46 Self {
47 limbs: m.data().as_ptr(),
48 dimensions: [m.physical_rows() / 64, m.columns().div_ceil(64)],
49 stride,
50 _marker: std::marker::PhantomData,
51 }
52 }
53
54 #[inline]
56 pub fn block_rows(&self) -> usize {
57 self.dimensions[0]
58 }
59
60 #[inline]
62 pub fn block_columns(&self) -> usize {
63 self.dimensions[1]
64 }
65
66 #[inline]
72 pub fn block_at(&self, block_row: usize, block_col: usize) -> MatrixBlockSlice<'_> {
73 assert!(
74 block_row < self.dimensions[0],
75 "block_row {block_row} out of bounds (max {})",
76 self.dimensions[0]
77 );
78 assert!(
79 block_col < self.dimensions[1],
80 "block_col {block_col} out of bounds (max {})",
81 self.dimensions[1]
82 );
83
84 let start_limb = 64 * block_row * self.stride.get() + block_col;
85 let stride = self.stride;
86
87 MatrixBlockSlice::new(
88 unsafe {
89 self.limbs.add(start_limb)
92 },
93 stride,
94 )
95 }
96
97 pub fn split_rows_at(&self, block_rows: usize) -> (MatrixTileSlice<'_>, MatrixTileSlice<'_>) {
98 assert!(
99 block_rows <= self.block_rows(),
100 "split point {block_rows} exceeds block_rows {}",
101 self.block_rows()
102 );
103 let (first_rows, second_rows) = (block_rows, self.block_rows() - block_rows);
104
105 let first = MatrixTileSlice {
106 limbs: self.limbs,
107 dimensions: [first_rows, self.dimensions[1]],
108 stride: self.stride,
109 _marker: std::marker::PhantomData,
110 };
111 let second = MatrixTileSlice {
112 limbs: unsafe { self.limbs.add(64 * first_rows * self.stride.get()) },
113 dimensions: [second_rows, self.dimensions[1]],
114 stride: self.stride,
115 _marker: std::marker::PhantomData,
116 };
117 (first, second)
118 }
119
120 pub fn split_columns_at(
121 &self,
122 block_columns: usize,
123 ) -> (MatrixTileSlice<'_>, MatrixTileSlice<'_>) {
124 assert!(
125 block_columns <= self.block_columns(),
126 "split point {block_columns} exceeds block_columns {}",
127 self.block_columns()
128 );
129 let (first_cols, second_cols) = (block_columns, self.block_columns() - block_columns);
130
131 let first = MatrixTileSlice {
132 limbs: self.limbs,
133 dimensions: [self.dimensions[0], first_cols],
134 stride: self.stride,
135 _marker: std::marker::PhantomData,
136 };
137 let second = MatrixTileSlice {
138 limbs: unsafe { self.limbs.add(first_cols) },
139 dimensions: [self.dimensions[0], second_cols],
140 stride: self.stride,
141 _marker: std::marker::PhantomData,
142 };
143 (first, second)
144 }
145}
146
147impl<'a> MatrixTileSliceMut<'a> {
148 pub fn from_matrix(m: &'a mut Matrix) -> Self {
149 let stride = m.stride().try_into().expect("Can't tile empty matrix");
150 Self {
151 limbs: m.data_mut().as_mut_ptr(),
152 dimensions: [m.physical_rows() / 64, m.columns().div_ceil(64)],
153 stride,
154 _marker: std::marker::PhantomData,
155 }
156 }
157
158 pub fn block_rows(&self) -> usize {
159 self.dimensions[0]
160 }
161
162 pub fn block_columns(&self) -> usize {
163 self.dimensions[1]
164 }
165
166 pub fn block_mut_at(&mut self, block_row: usize, block_col: usize) -> MatrixBlockSliceMut<'_> {
167 assert!(
168 block_row < self.dimensions[0],
169 "block_row {block_row} out of bounds (max {})",
170 self.dimensions[0]
171 );
172 assert!(
173 block_col < self.dimensions[1],
174 "block_col {block_col} out of bounds (max {})",
175 self.dimensions[1]
176 );
177
178 let start_limb = 64 * block_row * self.stride.get() + block_col;
179 let stride = self.stride;
180
181 MatrixBlockSliceMut::new(
182 unsafe {
183 self.limbs.add(start_limb)
186 },
187 stride,
188 )
189 }
190
191 pub fn split_rows_at_mut(
192 &mut self,
193 block_rows: usize,
194 ) -> (MatrixTileSliceMut<'_>, MatrixTileSliceMut<'_>) {
195 assert!(
196 block_rows <= self.block_rows(),
197 "split point {block_rows} exceeds block_rows {}",
198 self.block_rows()
199 );
200 let (first_rows, second_rows) = (block_rows, self.block_rows() - block_rows);
201
202 let first = MatrixTileSliceMut {
203 limbs: self.limbs,
204 dimensions: [first_rows, self.dimensions[1]],
205 stride: self.stride,
206 _marker: std::marker::PhantomData,
207 };
208 let second = MatrixTileSliceMut {
209 limbs: unsafe { self.limbs.add(64 * first_rows * self.stride.get()) },
210 dimensions: [second_rows, self.dimensions[1]],
211 stride: self.stride,
212 _marker: std::marker::PhantomData,
213 };
214 (first, second)
215 }
216
217 pub fn split_columns_at_mut(
218 &mut self,
219 block_columns: usize,
220 ) -> (MatrixTileSliceMut<'_>, MatrixTileSliceMut<'_>) {
221 assert!(
222 block_columns <= self.block_columns(),
223 "split point {block_columns} exceeds block_columns {}",
224 self.block_columns()
225 );
226 let (first_cols, second_cols) = (block_columns, self.block_columns() - block_columns);
227
228 let first = MatrixTileSliceMut {
229 limbs: self.limbs,
230 dimensions: [self.dimensions[0], first_cols],
231 stride: self.stride,
232 _marker: std::marker::PhantomData,
233 };
234 let second = MatrixTileSliceMut {
235 limbs: unsafe { self.limbs.add(first_cols) },
236 dimensions: [self.dimensions[0], second_cols],
237 stride: self.stride,
238 _marker: std::marker::PhantomData,
239 };
240 (first, second)
241 }
242
243 pub fn zero_out(&mut self) {
244 for block_row in 0..self.block_rows() {
245 for block_col in 0..self.block_columns() {
246 self.block_mut_at(block_row, block_col).zero_out();
247 }
248 }
249 }
250}
251
252unsafe impl Send for MatrixTileSlice<'_> {}
255unsafe impl Sync for MatrixTileSlice<'_> {}
256
257unsafe impl Send for MatrixTileSliceMut<'_> {}