use std::arch::x86_64;

use crate::{
    blas::block::{MatrixBlock, MatrixBlockSlice},
    limb::Limb,
};

type SimdLimb = x86_64::__m512i;

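/// Loads eight 64-bit limbs into one ZMM register (unaligned).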
#[target_feature(enable = "avx512f")]
fn load(limb: *const Limb) -> SimdLimb {
    unsafe { x86_64::_mm512_loadu_si512(limb as *const SimdLimb) }
}

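/// Stores one ZMM register back to eight limbs (unaligned).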
#[target_feature(enable = "avx512f")]
fn store(limb: *mut Limb, val: SimdLimb) {
    unsafe { x86_64::_mm512_storeu_si512(limb as *mut SimdLimb, val) }
}

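/// Lane-wise XOR of two 512-bit vectors, i.e. addition over GF(2).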
#[target_feature(enable = "avx512f")]
fn xor(left: SimdLimb, right: SimdLimb) -> SimdLimb {
    x86_64::_mm512_xor_si512(left, right)
}

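// Instantiate the shared SIMD add kernel (see `add_simd_arch!` in the parent
// module) on top of the helpers above.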
super::add_simd_arch!("avx512f");

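// Per-lane row offsets 0..8; `gather_simd` scales these by the row stride to
// build its gather indices.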
const UNIT_OFFSETS: [i64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];

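/// Accumulates the GF(2) product of two 64x64-bit blocks into `c`, i.e.
/// `c ^= a * b`, with each 64-bit `Limb` holding one row of a block.
///
/// Every loop iteration processes eight rows of `a` and `c` (one ZMM
/// register): the bits of each `a` row select, through mask registers, which
/// rows of `b` are XORed into that row's accumulator; the accumulator's eight
/// lanes are then XOR-reduced and merged into the matching row of `c`.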
#[target_feature(enable = "avx512f")]
pub fn gemm_block_simd(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
    unsafe {
        std::arch::asm!(
            "mov {zmm_idx}, 0",
            "mov {limb_idx}, 0",
            "2:",

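            // zmm10 accumulates the eight C rows updated this iteration.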
38 "vmovdqa64 zmm10, zmmword ptr [{c_data_ptr} + {zmm_idx}]",
39
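            // B rows 0-31, eight 64-bit rows per ZMM register.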
40 "vmovdqa64 zmm12, zmmword ptr [{b_data_ptr}]",
41 "vmovdqa64 zmm13, zmmword ptr [{b_data_ptr} + 64]",
42 "vmovdqa64 zmm14, zmmword ptr [{b_data_ptr} + 64*2]",
43 "vmovdqa64 zmm15, zmmword ptr [{b_data_ptr} + 64*3]",
44
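            // Eight rows of A; bit j of a row means "XOR B row j into this C row".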
45 "mov {limb0}, [{a_data_ptr} + 8*{limb_idx} + 0]",
46 "mov {limb1}, [{a_data_ptr} + 8*{limb_idx} + 8]",
47 "mov {limb2}, [{a_data_ptr} + 8*{limb_idx} + 16]",
48 "mov {limb3}, [{a_data_ptr} + 8*{limb_idx} + 24]",
49 "mov {limb4}, [{a_data_ptr} + 8*{limb_idx} + 32]",
50 "mov {limb5}, [{a_data_ptr} + 8*{limb_idx} + 40]",
51 "mov {limb6}, [{a_data_ptr} + 8*{limb_idx} + 48]",
52 "mov {limb7}, [{a_data_ptr} + 8*{limb_idx} + 56]",
53
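            // Bits 0-31 of A rows 0-3 feed the mask registers, eight bits per step.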
54 "kmovq k1, {limb0}",
55 "kmovq k2, {limb1}",
56 "kmovq k3, {limb2}",
57 "kmovq k4, {limb3}",
58
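            // Clear the eight per-row accumulators.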
59 "vpxorq zmm0, zmm0, zmm0",
60 "vpxorq zmm1, zmm1, zmm1",
61 "vpxorq zmm2, zmm2, zmm2",
62 "vpxorq zmm3, zmm3, zmm3",
63 "vpxorq zmm4, zmm4, zmm4",
64 "vpxorq zmm5, zmm5, zmm5",
65 "vpxorq zmm6, zmm6, zmm6",
66 "vpxorq zmm7, zmm7, zmm7",
67
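            // Masked XOR-accumulate: lane l of an accumulator picks up B row
            // 8*g + l (g = B register group) wherever the A bit is set;
            // kshiftrq advances through the A row eight bits at a time.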
68 "vpxorq zmm0 {{k1}}, zmm0, zmm12",
69 "vpxorq zmm1 {{k2}}, zmm1, zmm12",
70 "vpxorq zmm2 {{k3}}, zmm2, zmm12",
71 "vpxorq zmm3 {{k4}}, zmm3, zmm12",
72
73 "kshiftrq k1, k1, 8",
74 "kshiftrq k2, k2, 8",
75 "kshiftrq k3, k3, 8",
76 "kshiftrq k4, k4, 8",
77 "vpxorq zmm0 {{k1}}, zmm0, zmm13",
78 "vpxorq zmm1 {{k2}}, zmm1, zmm13",
79 "vpxorq zmm2 {{k3}}, zmm2, zmm13",
80 "vpxorq zmm3 {{k4}}, zmm3, zmm13",
81
82 "kshiftrq k1, k1, 8",
83 "kshiftrq k2, k2, 8",
84 "kshiftrq k3, k3, 8",
85 "kshiftrq k4, k4, 8",
86 "vpxorq zmm0 {{k1}}, zmm0, zmm14",
87 "vpxorq zmm1 {{k2}}, zmm1, zmm14",
88 "vpxorq zmm2 {{k3}}, zmm2, zmm14",
89 "vpxorq zmm3 {{k4}}, zmm3, zmm14",
90
91 "kshiftrq k1, k1, 8",
92 "kshiftrq k2, k2, 8",
93 "kshiftrq k3, k3, 8",
94 "kshiftrq k4, k4, 8",
95 "vpxorq zmm0 {{k1}}, zmm0, zmm15",
96 "vpxorq zmm1 {{k2}}, zmm1, zmm15",
97 "vpxorq zmm2 {{k3}}, zmm2, zmm15",
98 "vpxorq zmm3 {{k4}}, zmm3, zmm15",
99
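            // The same accumulation for A rows 4-7.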
100 "kmovq k1, {limb4}",
101 "kmovq k2, {limb5}",
102 "kmovq k3, {limb6}",
103 "kmovq k4, {limb7}",
104
105 "vpxorq zmm4 {{k1}}, zmm4, zmm12",
106 "vpxorq zmm5 {{k2}}, zmm5, zmm12",
107 "vpxorq zmm6 {{k3}}, zmm6, zmm12",
108 "vpxorq zmm7 {{k4}}, zmm7, zmm12",
109
110 "kshiftrq k1, k1, 8",
111 "kshiftrq k2, k2, 8",
112 "kshiftrq k3, k3, 8",
113 "kshiftrq k4, k4, 8",
114 "vpxorq zmm4 {{k1}}, zmm4, zmm13",
115 "vpxorq zmm5 {{k2}}, zmm5, zmm13",
116 "vpxorq zmm6 {{k3}}, zmm6, zmm13",
117 "vpxorq zmm7 {{k4}}, zmm7, zmm13",
118
119 "kshiftrq k1, k1, 8",
120 "kshiftrq k2, k2, 8",
121 "kshiftrq k3, k3, 8",
122 "kshiftrq k4, k4, 8",
123 "vpxorq zmm4 {{k1}}, zmm4, zmm14",
124 "vpxorq zmm5 {{k2}}, zmm5, zmm14",
125 "vpxorq zmm6 {{k3}}, zmm6, zmm14",
126 "vpxorq zmm7 {{k4}}, zmm7, zmm14",
127
128 "kshiftrq k1, k1, 8",
129 "kshiftrq k2, k2, 8",
130 "kshiftrq k3, k3, 8",
131 "kshiftrq k4, k4, 8",
132 "vpxorq zmm4 {{k1}}, zmm4, zmm15",
133 "vpxorq zmm5 {{k2}}, zmm5, zmm15",
134 "vpxorq zmm6 {{k3}}, zmm6, zmm15",
135 "vpxorq zmm7 {{k4}}, zmm7, zmm15",
136
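            // Move on to bits 32-63 of each A row, which select B rows 32-63.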
137 "shr {limb0}, 32",
138 "shr {limb1}, 32",
139 "shr {limb2}, 32",
140 "shr {limb3}, 32",
141 "shr {limb4}, 32",
142 "shr {limb5}, 32",
143 "shr {limb6}, 32",
144 "shr {limb7}, 32",
145
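            // B rows 32-63.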
146 "vmovdqa64 zmm12, zmmword ptr [{b_data_ptr} + 64*4]",
147 "vmovdqa64 zmm13, zmmword ptr [{b_data_ptr} + 64*5]",
148 "vmovdqa64 zmm14, zmmword ptr [{b_data_ptr} + 64*6]",
149 "vmovdqa64 zmm15, zmmword ptr [{b_data_ptr} + 64*7]",
150
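            // Second half of the accumulation, now against B rows 32-63.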
151 "kmovq k1, {limb0}",
152 "kmovq k2, {limb1}",
153 "kmovq k3, {limb2}",
154 "kmovq k4, {limb3}",
155
156 "vpxorq zmm0 {{k1}}, zmm0, zmm12",
157 "vpxorq zmm1 {{k2}}, zmm1, zmm12",
158 "vpxorq zmm2 {{k3}}, zmm2, zmm12",
159 "vpxorq zmm3 {{k4}}, zmm3, zmm12",
160
161 "kshiftrq k1, k1, 8",
162 "kshiftrq k2, k2, 8",
163 "kshiftrq k3, k3, 8",
164 "kshiftrq k4, k4, 8",
165 "vpxorq zmm0 {{k1}}, zmm0, zmm13",
166 "vpxorq zmm1 {{k2}}, zmm1, zmm13",
167 "vpxorq zmm2 {{k3}}, zmm2, zmm13",
168 "vpxorq zmm3 {{k4}}, zmm3, zmm13",
169
170 "kshiftrq k1, k1, 8",
171 "kshiftrq k2, k2, 8",
172 "kshiftrq k3, k3, 8",
173 "kshiftrq k4, k4, 8",
174 "vpxorq zmm0 {{k1}}, zmm0, zmm14",
175 "vpxorq zmm1 {{k2}}, zmm1, zmm14",
176 "vpxorq zmm2 {{k3}}, zmm2, zmm14",
177 "vpxorq zmm3 {{k4}}, zmm3, zmm14",
178
179 "kshiftrq k1, k1, 8",
180 "kshiftrq k2, k2, 8",
181 "kshiftrq k3, k3, 8",
182 "kshiftrq k4, k4, 8",
183 "vpxorq zmm0 {{k1}}, zmm0, zmm15",
184 "vpxorq zmm1 {{k2}}, zmm1, zmm15",
185 "vpxorq zmm2 {{k3}}, zmm2, zmm15",
186 "vpxorq zmm3 {{k4}}, zmm3, zmm15",
187
188 "kmovq k1, {limb4}",
189 "kmovq k2, {limb5}",
190 "kmovq k3, {limb6}",
191 "kmovq k4, {limb7}",
192
193 "vpxorq zmm4 {{k1}}, zmm4, zmm12",
194 "vpxorq zmm5 {{k2}}, zmm5, zmm12",
195 "vpxorq zmm6 {{k3}}, zmm6, zmm12",
196 "vpxorq zmm7 {{k4}}, zmm7, zmm12",
197
198 "kshiftrq k1, k1, 8",
199 "kshiftrq k2, k2, 8",
200 "kshiftrq k3, k3, 8",
201 "kshiftrq k4, k4, 8",
202 "vpxorq zmm4 {{k1}}, zmm4, zmm13",
203 "vpxorq zmm5 {{k2}}, zmm5, zmm13",
204 "vpxorq zmm6 {{k3}}, zmm6, zmm13",
205 "vpxorq zmm7 {{k4}}, zmm7, zmm13",
206
207 "kshiftrq k1, k1, 8",
208 "kshiftrq k2, k2, 8",
209 "kshiftrq k3, k3, 8",
210 "kshiftrq k4, k4, 8",
211 "vpxorq zmm4 {{k1}}, zmm4, zmm14",
212 "vpxorq zmm5 {{k2}}, zmm5, zmm14",
213 "vpxorq zmm6 {{k3}}, zmm6, zmm14",
214 "vpxorq zmm7 {{k4}}, zmm7, zmm14",
215
216 "kshiftrq k1, k1, 8",
217 "kshiftrq k2, k2, 8",
218 "kshiftrq k3, k3, 8",
219 "kshiftrq k4, k4, 8",
220 "vpxorq zmm4 {{k1}}, zmm4, zmm15",
221 "vpxorq zmm5 {{k2}}, zmm5, zmm15",
222 "vpxorq zmm6 {{k3}}, zmm6, zmm15",
223 "vpxorq zmm7 {{k4}}, zmm7, zmm15",
224
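            // XOR-reduce the eight lanes of each accumulator into every lane
            // (swap adjacent qwords, then 128-bit halves, then 256-bit halves),
            // then merge a single lane into its C row; k1 walks one bit per row.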
225 "kmovq k1, {one}",
226
227 "vpermq zmm8, zmm0, {permute1}",
228 "vpermq zmm9, zmm1, {permute1}",
229 "vpxorq zmm0, zmm0, zmm8",
230 "vpxorq zmm1, zmm1, zmm9",
231 "vpermq zmm8, zmm0, {permute2}",
232 "vpermq zmm9, zmm1, {permute2}",
233 "vpxorq zmm0, zmm0, zmm8",
234 "vpxorq zmm1, zmm1, zmm9",
235 "vshufi64x2 zmm8, zmm0, zmm0, {permute2}",
236 "vshufi64x2 zmm9, zmm1, zmm1, {permute2}",
237 "vpxorq zmm0, zmm0, zmm8",
238 "vpxorq zmm1, zmm1, zmm9",
239
240 "vpxorq zmm10 {{k1}}, zmm10, zmm0",
241 "kshiftlq k1, k1, 1",
242 "vpxorq zmm10 {{k1}}, zmm10, zmm1",
243 "kshiftlq k1, k1, 1",
244
245 "vpermq zmm8, zmm2, {permute1}",
246 "vpermq zmm9, zmm3, {permute1}",
247 "vpxorq zmm2, zmm2, zmm8",
248 "vpxorq zmm3, zmm3, zmm9",
249 "vpermq zmm8, zmm2, {permute2}",
250 "vpermq zmm9, zmm3, {permute2}",
251 "vpxorq zmm2, zmm2, zmm8",
252 "vpxorq zmm3, zmm3, zmm9",
253 "vshufi64x2 zmm8, zmm2, zmm2, {permute2}",
254 "vshufi64x2 zmm9, zmm3, zmm3, {permute2}",
255 "vpxorq zmm2, zmm2, zmm8",
256 "vpxorq zmm3, zmm3, zmm9",
257
258 "vpxorq zmm10 {{k1}}, zmm10, zmm2",
259 "kshiftlq k1, k1, 1",
260 "vpxorq zmm10 {{k1}}, zmm10, zmm3",
261 "kshiftlq k1, k1, 1",
262
263 "vpermq zmm8, zmm4, {permute1}",
264 "vpermq zmm9, zmm5, {permute1}",
265 "vpxorq zmm4, zmm4, zmm8",
266 "vpxorq zmm5, zmm5, zmm9",
267 "vpermq zmm8, zmm4, {permute2}",
268 "vpermq zmm9, zmm5, {permute2}",
269 "vpxorq zmm4, zmm4, zmm8",
270 "vpxorq zmm5, zmm5, zmm9",
271 "vshufi64x2 zmm8, zmm4, zmm4, {permute2}",
272 "vshufi64x2 zmm9, zmm5, zmm5, {permute2}",
273 "vpxorq zmm4, zmm4, zmm8",
274 "vpxorq zmm5, zmm5, zmm9",
275
276 "vpxorq zmm10 {{k1}}, zmm10, zmm4",
277 "kshiftlq k1, k1, 1",
278 "vpxorq zmm10 {{k1}}, zmm10, zmm5",
279 "kshiftlq k1, k1, 1",
280
281 "vpermq zmm8, zmm6, {permute1}",
282 "vpermq zmm9, zmm7, {permute1}",
283 "vpxorq zmm6, zmm6, zmm8",
284 "vpxorq zmm7, zmm7, zmm9",
285 "vpermq zmm8, zmm6, {permute2}",
286 "vpermq zmm9, zmm7, {permute2}",
287 "vpxorq zmm6, zmm6, zmm8",
288 "vpxorq zmm7, zmm7, zmm9",
289 "vshufi64x2 zmm8, zmm6, zmm6, {permute2}",
290 "vshufi64x2 zmm9, zmm7, zmm7, {permute2}",
291 "vpxorq zmm6, zmm6, zmm8",
292 "vpxorq zmm7, zmm7, zmm9",
293
294 "vpxorq zmm10 {{k1}}, zmm10, zmm6",
295 "kshiftlq k1, k1, 1",
296 "vpxorq zmm10 {{k1}}, zmm10, zmm7",
297
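            // Write the updated eight rows of C back.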
298 "vmovdqa64 zmmword ptr [{c_data_ptr} + {zmm_idx}], zmm10",
299
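            // Advance by eight rows; eight iterations cover all 64.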
300 "add {limb_idx}, 8",
301 "add {zmm_idx}, 64",
302 "cmp {limb_idx}, 64",
303 "jl 2b",
304
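            // permute1 = [1,0,3,2] swaps adjacent qwords; permute2 = [2,3,0,1]
            // swaps qword pairs (and, with vshufi64x2, the 256-bit halves).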
            permute1 = const 0b10110001, permute2 = const 0b01001110,
            a_data_ptr = in(reg) a.limbs_ptr(),
            b_data_ptr = in(reg) b.limbs_ptr(),
            c_data_ptr = in(reg) c.limbs_mut_ptr(),
            one = in(reg) 1u64,

            limb_idx = out(reg) _,
            zmm_idx = out(reg) _,

            limb0 = out(reg) _, limb1 = out(reg) _, limb2 = out(reg) _, limb3 = out(reg) _,
            limb4 = out(reg) _, limb5 = out(reg) _, limb6 = out(reg) _, limb7 = out(reg) _,

            out("k1") _, out("k2") _, out("k3") _, out("k4") _,

            out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
            out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
            out("zmm8") _, out("zmm9") _, out("zmm10") _,
            out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
            options(nostack),
        )
    }
}

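/// Eight ZMM registers' worth of limbs; over-aligned so aligned AVX-512
/// loads and stores on the block are always valid.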
#[derive(Clone, Copy)]
#[repr(align(128))]
struct SimdBlock([SimdLimb; 8]);

impl SimdBlock {
    #[target_feature(enable = "avx512f")]
    fn zero() -> Self {
        Self([x86_64::_mm512_setzero_si512(); 8])
    }

    #[target_feature(enable = "avx512f")]
    fn as_matrix_block(&self) -> MatrixBlock {
        // The transmute is size-checked; both types are plain limb data.
        unsafe { std::mem::transmute::<Self, MatrixBlock>(*self) }
    }
}

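/// Gathers a strided block slice into a contiguous `MatrixBlock`, eight rows
/// per gather: the index vector holds row offsets `[0..8) * stride`, and the
/// base pointer advances past eight strided rows each iteration.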
#[target_feature(enable = "avx512f,avx512dq")]
pub unsafe fn gather_simd(slice: MatrixBlockSlice) -> MatrixBlock {
    let mut result = SimdBlock::zero();
    let offsets = unsafe { x86_64::_mm512_loadu_epi64(UNIT_OFFSETS.as_ptr()) };
    let stride = x86_64::_mm512_set1_epi64(slice.stride().get() as i64);
    // `vpmullq` is an AVX512DQ instruction, hence the extra target feature.
    let offsets = x86_64::_mm512_mullo_epi64(offsets, stride);

    for i in 0..8 {
        // Base of the i-th group of eight strided rows; the gather applies
        // `offsets` at a scale of 8 bytes per limb.
        let ptr = unsafe { slice.limbs().add(8 * i * slice.stride().get()) as *const i64 };
        result.0[i] = unsafe { x86_64::_mm512_i64gather_epi64::<8>(offsets, ptr) };
    }
    result.as_matrix_block()
}