fp/blas/block.rs

1use std::num::NonZeroUsize;
2
3use crate::limb::Limb;
4
5/// A contiguous 64 x 64 block of bits stored in row-major order.
6///
7/// Each limb represents one row of 64 bits. The 128-byte alignment ensures efficient SIMD
8/// operations and cache line alignment.
9#[repr(align(128))]
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11pub struct MatrixBlock([Limb; 64]);
12
13impl MatrixBlock {
14    #[inline]
15    pub fn new(limbs: [Limb; 64]) -> Self {
16        Self(limbs)
17    }
18
19    /// Creates a zero-initialized block.
20    #[inline]
21    pub fn zero() -> Self {
22        Self([0; 64])
23    }
24
25    #[inline]
26    pub fn iter(&self) -> impl Iterator<Item = &Limb> {
27        self.0.iter()
28    }
29
30    /// Returns a mutable iterator over the limbs (rows) of this block.
31    #[inline]
32    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Limb> {
33        self.0.iter_mut()
34    }
35
36    #[cfg_attr(not(target_feature = "avx512f"), allow(dead_code))]
37    pub(crate) fn limbs_ptr(&self) -> *const Limb {
38        self.0.as_ptr()
39    }
40
41    #[cfg_attr(not(target_feature = "avx512f"), allow(dead_code))]
42    pub(crate) fn limbs_mut_ptr(&mut self) -> *mut Limb {
43        self.0.as_mut_ptr()
44    }
45}
46
47/// A non-contiguous view of a 64 x 64 block within a larger matrix.
48///
49/// The block is stored in row-major order with a configurable stride between rows. This allows
50/// efficient access to sub-blocks within a matrix without copying data.
51///
52/// # Safety
53///
54/// The `limbs` pointer must remain valid for the lifetime `'a`, and must point to at least 64 valid
55/// rows spaced `stride` limbs apart.
56pub struct MatrixBlockSlice<'a> {
57    limbs: *const Limb,
58    /// Number of limbs between consecutive rows
59    stride: NonZeroUsize,
60    _marker: std::marker::PhantomData<&'a ()>,
61}
62
63/// A mutable non-contiguous view of a 64 x 64 block within a larger matrix.
64///
65/// # Safety
66///
67/// The `limbs` pointer must remain valid and exclusively accessible for the lifetime `'a`, and must
68/// point to at least 64 valid rows spaced `stride` limbs apart.
69pub struct MatrixBlockSliceMut<'a> {
70    limbs: *mut Limb,
71    /// Number of limbs between consecutive rows
72    stride: NonZeroUsize,
73    _marker: std::marker::PhantomData<&'a mut ()>,
74}
75
76impl<'a> MatrixBlockSlice<'a> {
77    pub(super) fn new(limbs: *const Limb, stride: NonZeroUsize) -> Self {
78        Self {
79            limbs,
80            stride,
81            _marker: std::marker::PhantomData,
82        }
83    }
84
85    #[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
86    pub(crate) fn limbs(&self) -> *const Limb {
87        self.limbs
88    }
89
90    #[cfg_attr(not(target_arch = "x86_64"), allow(dead_code))]
91    pub(crate) fn stride(&self) -> NonZeroUsize {
92        self.stride
93    }
94
95    /// Returns an iterator over the 64 rows of this block.
96    ///
97    /// # Safety
98    ///
99    /// Each element is obtained via `self.limbs.add(i * self.stride)`, which is safe because the
100    /// constructor guarantees 64 valid rows at the given stride.
101    pub fn iter(self) -> impl Iterator<Item = &'a Limb> {
102        (0..64).map(move |i| unsafe {
103            // SAFETY: Constructor guarantees 64 rows at stride intervals
104            &*self.limbs.add(i * self.stride.get())
105        })
106    }
107
108    /// Gathers the non-contiguous block into a contiguous `MatrixBlock`.
109    ///
110    /// This operation is necessary before performing block-level GEMM since the AVX-512 kernel
111    /// expects contiguous data.
112    #[inline]
113    pub fn gather(self) -> MatrixBlock {
114        // Delegate to SIMD specializations
115        crate::simd::gather_block_simd(self)
116    }
117}
118
119impl<'a> MatrixBlockSliceMut<'a> {
120    pub(super) fn new(limbs: *mut Limb, stride: NonZeroUsize) -> Self {
121        Self {
122            limbs,
123            stride,
124            _marker: std::marker::PhantomData,
125        }
126    }
127
128    /// Returns a mutable reference to the limb at the given row.
129    ///
130    /// # Safety
131    ///
132    /// The pointer arithmetic `self.limbs.add(row * self.stride)` is safe because the constructor
133    /// guarantees 64 valid rows, and this method will panic in debug mode if `row >= 64` (via debug
134    /// assertions in the caller).
135    #[inline]
136    pub fn get_mut(&mut self, row: usize) -> &mut Limb {
137        debug_assert!(row < 64, "row index {row} out of bounds for 64 x 64 block");
138        unsafe {
139            // SAFETY: Constructor guarantees 64 rows at stride intervals
140            &mut *self.limbs.add(row * self.stride.get())
141        }
142    }
143
144    /// Returns a mutable iterator over the 64 rows of this block.
145    #[inline]
146    pub fn iter_mut<'b>(&'b mut self) -> impl Iterator<Item = &'b mut Limb> + use<'a, 'b> {
147        (0..64).map(move |i| unsafe {
148            // SAFETY: Constructor guarantees 64 rows at stride intervals
149            &mut *self.limbs.add(i * self.stride.get())
150        })
151    }
152
153    /// Creates a copy of this mutable slice with a shorter lifetime.
154    ///
155    /// This is useful for splitting the lifetime when you need to pass the slice to a function that
156    /// doesn't need to hold it for the full `'a` lifetime.
157    #[inline]
158    pub fn copy(&mut self) -> MatrixBlockSliceMut<'_> {
159        MatrixBlockSliceMut {
160            limbs: self.limbs,
161            stride: self.stride,
162            _marker: std::marker::PhantomData,
163        }
164    }
165
166    /// Converts this mutable slice into an immutable slice.
167    #[inline]
168    pub fn as_slice(&self) -> MatrixBlockSlice<'_> {
169        MatrixBlockSlice {
170            limbs: self.limbs,
171            stride: self.stride,
172            _marker: std::marker::PhantomData,
173        }
174    }
175
176    /// Scatters a contiguous block into this non-contiguous slice.
177    ///
178    /// This is the inverse of `gather` and is used to write GEMM results back into the parent
179    /// matrix.
180    #[inline]
181    pub fn assign(&mut self, block: MatrixBlock) {
182        self.iter_mut()
183            .zip(block.iter())
184            .for_each(|(dst, &src)| *dst = src);
185    }
186
187    pub fn zero_out(&mut self) {
188        for limb in self.iter_mut() {
189            *limb = 0;
190        }
191    }
192}
193
194// SAFETY: The slices have &Limb / &mut Limb semantics, so inherit the same Send / Sync behavior.
195
196unsafe impl Send for MatrixBlockSlice<'_> {}
197unsafe impl Send for MatrixBlockSliceMut<'_> {}
198
199unsafe impl Sync for MatrixBlockSlice<'_> {}
200
201/// Performs block-level GEMM: `C = A * B + C` for 64 x 64 bit blocks.
202///
203/// # Arguments
204///
205/// * `a` - Left input block (64 x 64 bits)
206/// * `b` - Right input block (64 x 64 bits)
207/// * `c` - Accumulator block (64 x 64 bits)
208///
209/// For efficiency reasons, we mutate `C` in-place.
210///
211/// # Implementation Selection
212///
213/// - **x86_64 with AVX-512**: Uses optimized assembly kernel
214/// - **Other platforms**: Falls back to scalar implementation
215#[inline]
216pub fn gemm_block(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
217    // Delegate to SIMD specializations
218    crate::simd::gemm_block_simd(a, b, c)
219}
220
#[cfg(feature = "proptest")]
mod arbitrary {
    use proptest::prelude::*;

    use super::*;

    /// Generates blocks whose 64 rows are independently drawn arbitrary limbs.
    impl Arbitrary for MatrixBlock {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
            let row = any::<Limb>();
            proptest::array::uniform(row).prop_map(MatrixBlock::new).boxed()
        }
    }
}