// fp/blas/block.rs
1use std::num::NonZeroUsize;
2
3use crate::limb::Limb;
4
/// A contiguous 64 x 64 block of bits stored in row-major order.
///
/// Each limb represents one row of 64 bits: limb `i` of the inner array is row `i`. The 128-byte
/// alignment ensures efficient SIMD operations and cache line alignment.
#[repr(align(128))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixBlock([Limb; 64]);
12
13impl MatrixBlock {
14 #[inline]
15 pub fn new(limbs: [Limb; 64]) -> Self {
16 Self(limbs)
17 }
18
19 /// Creates a zero-initialized block.
20 #[inline]
21 pub fn zero() -> Self {
22 Self([0; 64])
23 }
24
25 #[inline]
26 pub fn iter(&self) -> impl Iterator<Item = &Limb> {
27 self.0.iter()
28 }
29
30 /// Returns a mutable iterator over the limbs (rows) of this block.
31 #[inline]
32 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Limb> {
33 self.0.iter_mut()
34 }
35
36 #[cfg_attr(not(target_feature = "avx512f"), allow(dead_code))]
37 pub(crate) fn limbs_ptr(&self) -> *const Limb {
38 self.0.as_ptr()
39 }
40
41 #[cfg_attr(not(target_feature = "avx512f"), allow(dead_code))]
42 pub(crate) fn limbs_mut_ptr(&mut self) -> *mut Limb {
43 self.0.as_mut_ptr()
44 }
45}
46
/// A non-contiguous view of a 64 x 64 block within a larger matrix.
///
/// The block is stored in row-major order with a configurable stride between rows. This allows
/// efficient access to sub-blocks within a matrix without copying data.
///
/// # Safety
///
/// The `limbs` pointer must remain valid for the lifetime `'a`, and must point to at least 64 valid
/// rows spaced `stride` limbs apart.
pub struct MatrixBlockSlice<'a> {
    /// First row of the viewed block; row `i` lives at `limbs + i * stride` limbs.
    limbs: *const Limb,
    /// Number of limbs between consecutive rows
    stride: NonZeroUsize,
    /// Ties the raw pointer to the borrow it was derived from.
    _marker: std::marker::PhantomData<&'a ()>,
}
62
/// A mutable non-contiguous view of a 64 x 64 block within a larger matrix.
///
/// # Safety
///
/// The `limbs` pointer must remain valid and exclusively accessible for the lifetime `'a`, and must
/// point to at least 64 valid rows spaced `stride` limbs apart.
pub struct MatrixBlockSliceMut<'a> {
    /// First row of the viewed block; row `i` lives at `limbs + i * stride` limbs.
    limbs: *mut Limb,
    /// Number of limbs between consecutive rows
    stride: NonZeroUsize,
    /// Ties the raw pointer to the exclusive borrow it was derived from.
    _marker: std::marker::PhantomData<&'a mut ()>,
}
75
76impl<'a> MatrixBlockSlice<'a> {
77 pub(super) fn new(limbs: *const Limb, stride: NonZeroUsize) -> Self {
78 Self {
79 limbs,
80 stride,
81 _marker: std::marker::PhantomData,
82 }
83 }
84
85 #[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
86 pub(crate) fn limbs(&self) -> *const Limb {
87 self.limbs
88 }
89
90 #[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
91 pub(crate) fn stride(&self) -> NonZeroUsize {
92 self.stride
93 }
94
95 /// Returns an iterator over the 64 rows of this block.
96 ///
97 /// # Safety
98 ///
99 /// Each element is obtained via `self.limbs.add(i * self.stride)`, which is safe because the
100 /// constructor guarantees 64 valid rows at the given stride.
101 pub fn iter(self) -> impl Iterator<Item = &'a Limb> {
102 (0..64).map(move |i| unsafe {
103 // SAFETY: Constructor guarantees 64 rows at stride intervals
104 &*self.limbs.add(i * self.stride.get())
105 })
106 }
107
108 /// Gathers the non-contiguous block into a contiguous `MatrixBlock`.
109 ///
110 /// This operation is necessary before performing block-level GEMM since the AVX-512 kernel
111 /// expects contiguous data.
112 #[inline]
113 pub fn gather(self) -> MatrixBlock {
114 // Delegate to SIMD specializations
115 crate::simd::gather_block_simd(self)
116 }
117}
118
119impl<'a> MatrixBlockSliceMut<'a> {
120 pub(super) fn new(limbs: *mut Limb, stride: NonZeroUsize) -> Self {
121 Self {
122 limbs,
123 stride,
124 _marker: std::marker::PhantomData,
125 }
126 }
127
128 /// Returns a mutable reference to the limb at the given row.
129 ///
130 /// # Safety
131 ///
132 /// The pointer arithmetic `self.limbs.add(row * self.stride)` is safe because the constructor
133 /// guarantees 64 valid rows, and this method will panic in debug mode if `row >= 64` (via debug
134 /// assertions in the caller).
135 #[inline]
136 pub fn get_mut(&mut self, row: usize) -> &mut Limb {
137 debug_assert!(row < 64, "row index {row} out of bounds for 64 x 64 block");
138 unsafe {
139 // SAFETY: Constructor guarantees 64 rows at stride intervals
140 &mut *self.limbs.add(row * self.stride.get())
141 }
142 }
143
144 /// Returns a mutable iterator over the 64 rows of this block.
145 #[inline]
146 pub fn iter_mut<'b>(&'b mut self) -> impl Iterator<Item = &'b mut Limb> + use<'a, 'b> {
147 (0..64).map(move |i| unsafe {
148 // SAFETY: Constructor guarantees 64 rows at stride intervals
149 &mut *self.limbs.add(i * self.stride.get())
150 })
151 }
152
153 /// Creates a copy of this mutable slice with a shorter lifetime.
154 ///
155 /// This is useful for splitting the lifetime when you need to pass the slice to a function that
156 /// doesn't need to hold it for the full `'a` lifetime.
157 #[inline]
158 pub fn copy(&mut self) -> MatrixBlockSliceMut<'_> {
159 MatrixBlockSliceMut {
160 limbs: self.limbs,
161 stride: self.stride,
162 _marker: std::marker::PhantomData,
163 }
164 }
165
166 /// Converts this mutable slice into an immutable slice.
167 #[inline]
168 pub fn as_slice(&self) -> MatrixBlockSlice<'_> {
169 MatrixBlockSlice {
170 limbs: self.limbs,
171 stride: self.stride,
172 _marker: std::marker::PhantomData,
173 }
174 }
175
176 /// Scatters a contiguous block into this non-contiguous slice.
177 ///
178 /// This is the inverse of `gather` and is used to write GEMM results back into the parent
179 /// matrix.
180 #[inline]
181 pub fn assign(&mut self, block: MatrixBlock) {
182 self.iter_mut()
183 .zip(block.iter())
184 .for_each(|(dst, &src)| *dst = src);
185 }
186
187 pub fn zero_out(&mut self) {
188 for limb in self.iter_mut() {
189 *limb = 0;
190 }
191 }
192}
193
// SAFETY: The slices have &Limb / &mut Limb semantics, so inherit the same Send / Sync behavior:
// both kinds may be sent across threads, and the shared (read-only) slice may additionally be
// referenced from multiple threads at once. Note that `Sync` is not implemented for
// `MatrixBlockSliceMut`, mirroring the exclusivity of `&mut Limb`.

unsafe impl Send for MatrixBlockSlice<'_> {}
unsafe impl Send for MatrixBlockSliceMut<'_> {}

unsafe impl Sync for MatrixBlockSlice<'_> {}
200
/// Performs block-level GEMM: `C = A * B + C` for 64 x 64 bit blocks.
///
/// # Arguments
///
/// * `a` - Left input block (64 x 64 bits)
/// * `b` - Right input block (64 x 64 bits)
/// * `c` - Accumulator block (64 x 64 bits)
///
/// For efficiency reasons, we mutate `C` in-place.
///
/// NOTE(review): the exact bit-level semantics of `*` and `+` (e.g. AND/XOR over GF(2)
/// versus AND/OR) are defined by `crate::simd::gemm_block_simd`, not visible here — confirm
/// against that kernel before relying on either interpretation.
///
/// # Implementation Selection
///
/// - **x86_64 with AVX-512**: Uses optimized assembly kernel
/// - **Other platforms**: Falls back to scalar implementation
#[inline]
pub fn gemm_block(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
    // Delegate to SIMD specializations
    crate::simd::gemm_block_simd(a, b, c)
}
220
#[cfg(feature = "proptest")]
mod arbitrary {
    //! `proptest` integration: allows property tests to generate random `MatrixBlock`s.

    use proptest::prelude::*;

    use super::*;

    impl Arbitrary for MatrixBlock {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Builds a strategy producing blocks of 64 independently-drawn random limbs.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // `uniform` lifts a per-element strategy to a `[Limb; 64]` strategy;
            // `prop_map(Self)` then wraps the array in the newtype.
            proptest::array::uniform(any::<Limb>())
                .prop_map(Self)
                .boxed()
        }
    }
}