// fp/simd/x86_64/avx512.rs

1use std::arch::x86_64;
2
3use crate::{
4    blas::block::{MatrixBlock, MatrixBlockSlice},
5    limb::Limb,
6};
7
8type SimdLimb = x86_64::__m512i;
9
/// Unaligned 512-bit load of eight consecutive limbs starting at `limb`.
///
/// NOTE(review): this is a *safe* fn that reads through a raw pointer — the
/// caller must guarantee `limb` points to at least 8 readable limbs
/// (64 bytes). Consider marking it `unsafe fn` if the `add_simd_arch!`
/// contract permits; confirm against the macro's expansion.
#[target_feature(enable = "avx512f")]
fn load(limb: *const Limb) -> SimdLimb {
    // SAFETY(review): requires `limb` valid for a 64-byte read; `loadu` has
    // no alignment requirement.
    unsafe { x86_64::_mm512_loadu_si512(limb as *const SimdLimb) }
}
14
/// Unaligned 512-bit store of `val` to the eight limbs starting at `limb`.
///
/// NOTE(review): this is a *safe* fn that writes through a raw pointer — the
/// caller must guarantee `limb` is valid for writing 8 limbs (64 bytes).
/// Consider marking it `unsafe fn` if the `add_simd_arch!` contract permits.
#[target_feature(enable = "avx512f")]
fn store(limb: *mut Limb, val: SimdLimb) {
    // SAFETY(review): requires `limb` valid for a 64-byte write; `storeu`
    // has no alignment requirement.
    unsafe { x86_64::_mm512_storeu_si512(limb as *mut SimdLimb, val) }
}
19
20#[target_feature(enable = "avx512f")]
21fn xor(left: SimdLimb, right: SimdLimb) -> SimdLimb {
22    x86_64::_mm512_xor_si512(left, right)
23}
24
// Instantiates the shared SIMD wrappers for this backend on top of the
// `load`/`store`/`xor` primitives above — TODO confirm the exact items the
// macro expands to against its definition in the parent module.
super::add_simd_arch!("avx512f");

/// Row indices 0..7; scaled by a matrix stride to form the qword offsets fed
/// to the `vpgatherqq` gathers in `gather_simd`.
const UNIT_OFFSETS: [i64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
28
/// Performs C = A * B + C where A, B, C are 64x64 matrices
///
/// Matrices are 64x64 bit-blocks over GF(2), one 64-bit limb per row. "*" is
/// the AND/XOR matrix product and "+" is XOR, so the update computed here is
/// `C[r] ^= XOR of B[j] for every j with bit j of A[r] set`.
///
/// Register plan for one loop iteration (8 rows of A and C at a time):
///   zmm10    - the 8 C rows being updated
///   zmm12-15 - 32 B rows (8 per register); loaded twice per iteration,
///              first rows 0-31, then rows 32-63
///   zmm0-7   - per-A-row accumulators: lane j of zmm{i} collects the XOR of
///              the B rows selected by bits of A row i
///   k1-k4    - A-row bits used as qword write masks, consumed 8 bits at a time
///   zmm8/9   - scratch for the horizontal XOR reduction
///
/// NOTE(review): `kmovq`/`kshiftrq`/`kshiftlq` on 64-bit masks are AVX512BW
/// instructions, but only "avx512f" is declared here — confirm that every
/// caller also guarantees AVX512BW at runtime.
/// NOTE(review): `vmovdqa64` performs *aligned* accesses — assumes
/// `MatrixBlock` storage is 64-byte aligned; TODO confirm against its
/// definition (`SimdBlock` below is `align(128)`).
#[target_feature(enable = "avx512f")]
pub fn gemm_block_simd(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
    unsafe {
        std::arch::asm!(
            // limb_idx indexes A rows (8 per iteration, 64 total); zmm_idx is
            // the matching byte offset into C (64 bytes per iteration).
            "mov {zmm_idx}, 0",
            "mov {limb_idx}, 0",
            "2:",

            // Load the 8 C rows updated this iteration.
            "vmovdqa64 zmm10, zmmword ptr [{c_data_ptr} + {zmm_idx}]",

            // B rows 0-7, 8-15, 16-23, 24-31 (8 rows per zmm).
            "vmovdqa64 zmm12, zmmword ptr [{b_data_ptr}]",
            "vmovdqa64 zmm13, zmmword ptr [{b_data_ptr} + 64]",
            "vmovdqa64 zmm14, zmmword ptr [{b_data_ptr} + 64*2]",
            "vmovdqa64 zmm15, zmmword ptr [{b_data_ptr} + 64*3]",

            // The 8 A rows handled this iteration, one limb per register.
            "mov {limb0}, [{a_data_ptr} + 8*{limb_idx} + 0]",
            "mov {limb1}, [{a_data_ptr} + 8*{limb_idx} + 8]",
            "mov {limb2}, [{a_data_ptr} + 8*{limb_idx} + 16]",
            "mov {limb3}, [{a_data_ptr} + 8*{limb_idx} + 24]",
            "mov {limb4}, [{a_data_ptr} + 8*{limb_idx} + 32]",
            "mov {limb5}, [{a_data_ptr} + 8*{limb_idx} + 40]",
            "mov {limb6}, [{a_data_ptr} + 8*{limb_idx} + 48]",
            "mov {limb7}, [{a_data_ptr} + 8*{limb_idx} + 56]",

            // A rows 0-3 become write masks. A qword-masked op consumes only
            // the low 8 mask bits; kshiftrq later advances to the next 8.
            "kmovq k1, {limb0}",
            "kmovq k2, {limb1}",
            "kmovq k3, {limb2}",
            "kmovq k4, {limb3}",

            // Clear all 8 per-row accumulators.
            "vpxorq zmm0, zmm0, zmm0",
            "vpxorq zmm1, zmm1, zmm1",
            "vpxorq zmm2, zmm2, zmm2",
            "vpxorq zmm3, zmm3, zmm3",
            "vpxorq zmm4, zmm4, zmm4",
            "vpxorq zmm5, zmm5, zmm5",
            "vpxorq zmm6, zmm6, zmm6",
            "vpxorq zmm7, zmm7, zmm7",

            // A rows 0-3 vs B rows 0-7: lane j of zmm{i} ^= B row j when
            // bit j of A row i is set.
            "vpxorq zmm0 {{k1}}, zmm0, zmm12",
            "vpxorq zmm1 {{k2}}, zmm1, zmm12",
            "vpxorq zmm2 {{k3}}, zmm2, zmm12",
            "vpxorq zmm3 {{k4}}, zmm3, zmm12",

            // ... vs B rows 8-15.
            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm13",
            "vpxorq zmm1 {{k2}}, zmm1, zmm13",
            "vpxorq zmm2 {{k3}}, zmm2, zmm13",
            "vpxorq zmm3 {{k4}}, zmm3, zmm13",

            // ... vs B rows 16-23.
            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm14",
            "vpxorq zmm1 {{k2}}, zmm1, zmm14",
            "vpxorq zmm2 {{k3}}, zmm2, zmm14",
            "vpxorq zmm3 {{k4}}, zmm3, zmm14",

            // ... vs B rows 24-31.
            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm15",
            "vpxorq zmm1 {{k2}}, zmm1, zmm15",
            "vpxorq zmm2 {{k3}}, zmm2, zmm15",
            "vpxorq zmm3 {{k4}}, zmm3, zmm15",

            // Same procedure for A rows 4-7 (accumulators zmm4-7) against
            // B rows 0-31.
            "kmovq k1, {limb4}",
            "kmovq k2, {limb5}",
            "kmovq k3, {limb6}",
            "kmovq k4, {limb7}",

            "vpxorq zmm4 {{k1}}, zmm4, zmm12",
            "vpxorq zmm5 {{k2}}, zmm5, zmm12",
            "vpxorq zmm6 {{k3}}, zmm6, zmm12",
            "vpxorq zmm7 {{k4}}, zmm7, zmm12",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm13",
            "vpxorq zmm5 {{k2}}, zmm5, zmm13",
            "vpxorq zmm6 {{k3}}, zmm6, zmm13",
            "vpxorq zmm7 {{k4}}, zmm7, zmm13",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm14",
            "vpxorq zmm5 {{k2}}, zmm5, zmm14",
            "vpxorq zmm6 {{k3}}, zmm6, zmm14",
            "vpxorq zmm7 {{k4}}, zmm7, zmm14",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm15",
            "vpxorq zmm5 {{k2}}, zmm5, zmm15",
            "vpxorq zmm6 {{k3}}, zmm6, zmm15",
            "vpxorq zmm7 {{k4}}, zmm7, zmm15",

            // Discard the consumed low 32 bits of each A row; the high 32
            // bits select among B rows 32-63.
            "shr {limb0}, 32",
            "shr {limb1}, 32",
            "shr {limb2}, 32",
            "shr {limb3}, 32",
            "shr {limb4}, 32",
            "shr {limb5}, 32",
            "shr {limb6}, 32",
            "shr {limb7}, 32",

            // B rows 32-39, 40-47, 48-55, 56-63 (reusing zmm12-15).
            "vmovdqa64 zmm12, zmmword ptr [{b_data_ptr} + 64*4]",
            "vmovdqa64 zmm13, zmmword ptr [{b_data_ptr} + 64*5]",
            "vmovdqa64 zmm14, zmmword ptr [{b_data_ptr} + 64*6]",
            "vmovdqa64 zmm15, zmmword ptr [{b_data_ptr} + 64*7]",

            // A rows 0-3 (upper bit halves) vs B rows 32-63, as above.
            "kmovq k1, {limb0}",
            "kmovq k2, {limb1}",
            "kmovq k3, {limb2}",
            "kmovq k4, {limb3}",

            "vpxorq zmm0 {{k1}}, zmm0, zmm12",
            "vpxorq zmm1 {{k2}}, zmm1, zmm12",
            "vpxorq zmm2 {{k3}}, zmm2, zmm12",
            "vpxorq zmm3 {{k4}}, zmm3, zmm12",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm13",
            "vpxorq zmm1 {{k2}}, zmm1, zmm13",
            "vpxorq zmm2 {{k3}}, zmm2, zmm13",
            "vpxorq zmm3 {{k4}}, zmm3, zmm13",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm14",
            "vpxorq zmm1 {{k2}}, zmm1, zmm14",
            "vpxorq zmm2 {{k3}}, zmm2, zmm14",
            "vpxorq zmm3 {{k4}}, zmm3, zmm14",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm0 {{k1}}, zmm0, zmm15",
            "vpxorq zmm1 {{k2}}, zmm1, zmm15",
            "vpxorq zmm2 {{k3}}, zmm2, zmm15",
            "vpxorq zmm3 {{k4}}, zmm3, zmm15",

            // A rows 4-7 (upper bit halves) vs B rows 32-63.
            "kmovq k1, {limb4}",
            "kmovq k2, {limb5}",
            "kmovq k3, {limb6}",
            "kmovq k4, {limb7}",

            "vpxorq zmm4 {{k1}}, zmm4, zmm12",
            "vpxorq zmm5 {{k2}}, zmm5, zmm12",
            "vpxorq zmm6 {{k3}}, zmm6, zmm12",
            "vpxorq zmm7 {{k4}}, zmm7, zmm12",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm13",
            "vpxorq zmm5 {{k2}}, zmm5, zmm13",
            "vpxorq zmm6 {{k3}}, zmm6, zmm13",
            "vpxorq zmm7 {{k4}}, zmm7, zmm13",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm14",
            "vpxorq zmm5 {{k2}}, zmm5, zmm14",
            "vpxorq zmm6 {{k3}}, zmm6, zmm14",
            "vpxorq zmm7 {{k4}}, zmm7, zmm14",

            "kshiftrq k1, k1, 8",
            "kshiftrq k2, k2, 8",
            "kshiftrq k3, k3, 8",
            "kshiftrq k4, k4, 8",
            "vpxorq zmm4 {{k1}}, zmm4, zmm15",
            "vpxorq zmm5 {{k2}}, zmm5, zmm15",
            "vpxorq zmm6 {{k3}}, zmm6, zmm15",
            "vpxorq zmm7 {{k4}}, zmm7, zmm15",

            // k1 becomes a one-hot lane selector for merging results into
            // zmm10, starting at lane 0 and shifted left after each merge.
            "kmovq k1, {one}",

            // Horizontal XOR of the 8 lanes of zmm0/zmm1: fold adjacent
            // qwords, then qword pairs, then the two 256-bit halves — after
            // which every lane holds the full 8-lane XOR. Merge the results
            // into lanes 0 and 1 of the C block via the one-hot mask.
            "vpermq zmm8, zmm0, {permute1}",
            "vpermq zmm9, zmm1, {permute1}",
            "vpxorq zmm0, zmm0, zmm8",
            "vpxorq zmm1, zmm1, zmm9",
            "vpermq zmm8, zmm0, {permute2}",
            "vpermq zmm9, zmm1, {permute2}",
            "vpxorq zmm0, zmm0, zmm8",
            "vpxorq zmm1, zmm1, zmm9",
            "vshufi64x2 zmm8, zmm0, zmm0, {permute2}",
            "vshufi64x2 zmm9, zmm1, zmm1, {permute2}",
            "vpxorq zmm0, zmm0, zmm8",
            "vpxorq zmm1, zmm1, zmm9",

            "vpxorq zmm10 {{k1}}, zmm10, zmm0",
            "kshiftlq k1, k1, 1",
            "vpxorq zmm10 {{k1}}, zmm10, zmm1",
            "kshiftlq k1, k1, 1",

            // Reduce zmm2/zmm3 and merge into lanes 2 and 3.
            "vpermq zmm8, zmm2, {permute1}",
            "vpermq zmm9, zmm3, {permute1}",
            "vpxorq zmm2, zmm2, zmm8",
            "vpxorq zmm3, zmm3, zmm9",
            "vpermq zmm8, zmm2, {permute2}",
            "vpermq zmm9, zmm3, {permute2}",
            "vpxorq zmm2, zmm2, zmm8",
            "vpxorq zmm3, zmm3, zmm9",
            "vshufi64x2 zmm8, zmm2, zmm2, {permute2}",
            "vshufi64x2 zmm9, zmm3, zmm3, {permute2}",
            "vpxorq zmm2, zmm2, zmm8",
            "vpxorq zmm3, zmm3, zmm9",

            "vpxorq zmm10 {{k1}}, zmm10, zmm2",
            "kshiftlq k1, k1, 1",
            "vpxorq zmm10 {{k1}}, zmm10, zmm3",
            "kshiftlq k1, k1, 1",

            // Reduce zmm4/zmm5 and merge into lanes 4 and 5.
            "vpermq zmm8, zmm4, {permute1}",
            "vpermq zmm9, zmm5, {permute1}",
            "vpxorq zmm4, zmm4, zmm8",
            "vpxorq zmm5, zmm5, zmm9",
            "vpermq zmm8, zmm4, {permute2}",
            "vpermq zmm9, zmm5, {permute2}",
            "vpxorq zmm4, zmm4, zmm8",
            "vpxorq zmm5, zmm5, zmm9",
            "vshufi64x2 zmm8, zmm4, zmm4, {permute2}",
            "vshufi64x2 zmm9, zmm5, zmm5, {permute2}",
            "vpxorq zmm4, zmm4, zmm8",
            "vpxorq zmm5, zmm5, zmm9",

            "vpxorq zmm10 {{k1}}, zmm10, zmm4",
            "kshiftlq k1, k1, 1",
            "vpxorq zmm10 {{k1}}, zmm10, zmm5",
            "kshiftlq k1, k1, 1",

            // Reduce zmm6/zmm7 and merge into lanes 6 and 7.
            "vpermq zmm8, zmm6, {permute1}",
            "vpermq zmm9, zmm7, {permute1}",
            "vpxorq zmm6, zmm6, zmm8",
            "vpxorq zmm7, zmm7, zmm9",
            "vpermq zmm8, zmm6, {permute2}",
            "vpermq zmm9, zmm7, {permute2}",
            "vpxorq zmm6, zmm6, zmm8",
            "vpxorq zmm7, zmm7, zmm9",
            "vshufi64x2 zmm8, zmm6, zmm6, {permute2}",
            "vshufi64x2 zmm9, zmm7, zmm7, {permute2}",
            "vpxorq zmm6, zmm6, zmm8",
            "vpxorq zmm7, zmm7, zmm9",

            "vpxorq zmm10 {{k1}}, zmm10, zmm6",
            "kshiftlq k1, k1, 1",
            "vpxorq zmm10 {{k1}}, zmm10, zmm7",

            // Write the 8 updated C rows back.
            "vmovdqa64 zmmword ptr [{c_data_ptr} + {zmm_idx}], zmm10",

            // Advance to the next 8 rows; 8 iterations cover all 64.
            "add {limb_idx}, 8",
            "add {zmm_idx}, 64",
            "cmp {limb_idx}, 64",
            "jl 2b",

            // Fold patterns for the horizontal XOR: permute1 swaps adjacent
            // qwords (per 256-bit lane); permute2 swaps qword pairs under
            // vpermq, and swaps the two 256-bit halves under vshufi64x2.
            permute1 = const 0b10110001, // Permutation for horizontal XOR
            permute2 = const 0b01001110, // Permutation for horizontal XOR

            // Constraints
            a_data_ptr = in(reg) a.limbs_ptr(),
            b_data_ptr = in(reg) b.limbs_ptr(),
            c_data_ptr = in(reg) c.limbs_mut_ptr(),
            one = in(reg) 1u64,

            // Counters
            limb_idx = out(reg) _,
            zmm_idx = out(reg) _,

            // Scratch registers
            limb0 = out(reg) _, limb1 = out(reg) _, limb2 = out(reg) _, limb3 = out(reg) _,
            limb4 = out(reg) _, limb5 = out(reg) _, limb6 = out(reg) _, limb7 = out(reg) _,

            // 4 k-registers for in-place rotation
            out("k1") _, out("k2") _, out("k3") _, out("k4") _,

            // ZMM registers (zmm11 is deliberately left untouched)
            out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,     // Results 0-3
            out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,     // Results 4-7
            out("zmm8") _, out("zmm9") _,                                   // Temps for horizontal XOR
            out("zmm10") _,                                                 // C[0], C[1], etc.
            out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, // B[0-3] and B[4-7]

            options(nostack)
        )
    }
}
336
/// A 64x64 bit-matrix block stored as eight 512-bit vectors (8 row-limbs per
/// vector). Over-aligned to 128 bytes, so 64-byte-aligned (`vmovdqa64`-style)
/// accesses to its storage are always valid.
#[derive(Clone, Copy)]
#[repr(align(128))]
struct SimdBlock([SimdLimb; 8]);
340
impl SimdBlock {
    /// Returns an all-zero block.
    #[target_feature(enable = "avx512f")]
    fn zero() -> Self {
        Self([x86_64::_mm512_setzero_si512(); 8])
    }

    /// Reinterprets this block as a `MatrixBlock` by value.
    #[target_feature(enable = "avx512f")]
    fn as_matrix_block(&self) -> MatrixBlock {
        // SAFETY(review): transmute is size-checked at compile time
        // ([__m512i; 8] is 512 bytes), but the bit-layout equivalence with
        // `MatrixBlock` must be confirmed against its definition — presumably
        // 64 contiguous row limbs; verify.
        unsafe { std::mem::transmute::<Self, MatrixBlock>(*self) }
    }
}
352
353#[target_feature(enable = "avx512f")]
354pub unsafe fn gather_simd(slice: MatrixBlockSlice) -> MatrixBlock {
355    let mut result = SimdBlock::zero();
356    let offsets = unsafe { x86_64::_mm512_loadu_epi64(&UNIT_OFFSETS as *const i64) };
357    let stride = x86_64::_mm512_set1_epi64(slice.stride().get() as i64);
358    let offsets = unsafe { x86_64::_mm512_mullo_epi64(offsets, stride) };
359
360    for i in 0..8 {
361        let ptr = unsafe { slice.limbs().add(8 * i * slice.stride().get()) as *const i64 };
362        result.0[i] = unsafe { x86_64::_mm512_i64gather_epi64::<8>(offsets, ptr) };
363    }
364    result.as_matrix_block()
365}