Highly Efficient FFT for Exascale: HeFFTe v2.4
Loading...
Searching...
No Matches
heffte_stock_vec_types.h
1/*
2 -- heFFTe --
3 Univ. of Tennessee, Knoxville
4 @date
5*/
6
7#ifndef HEFFTE_STOCK_VEC_TYPES_H
8#define HEFFTE_STOCK_VEC_TYPES_H
9
10#include "heffte_config.h"
11
12#ifdef Heffte_ENABLE_AVX
13#include <immintrin.h>
14#endif
15#include <complex>
16
17namespace heffte {
18namespace stock {
//! \brief Holds true when T, with const/volatile removed, is exactly float.
template<class T>
using is_float = std::is_same<float, typename std::remove_cv<T>::type>;

//! \brief Holds true when T, with const/volatile removed, is exactly double.
template<class T>
using is_double = std::is_same<double, typename std::remove_cv<T>::type>;

//! \brief Holds true when T, with const/volatile removed, is exactly std::complex<float>.
template<class T>
using is_fcomplex = std::is_same<std::complex<float>, typename std::remove_cv<T>::type>;

//! \brief Holds true when T, with const/volatile removed, is exactly std::complex<double>.
template<class T>
using is_dcomplex = std::is_same<std::complex<double>, typename std::remove_cv<T>::type>;
31
//! \brief Struct determining whether a type is a real number, i.e., float or double (cv-qualifiers are ignored).
template<typename T>
struct is_real {
    //! \brief True when T is float or double after removing const/volatile.
    static constexpr bool value =
        std::is_same<float, typename std::remove_cv<T>::type>::value ||
        std::is_same<double, typename std::remove_cv<T>::type>::value;
};
36
//! \brief Struct determining whether a type is a complex number, i.e., std::complex of float or double (cv-qualifiers are ignored).
template<typename T>
struct is_complex {
    //! \brief True when T is std::complex<float> or std::complex<double> after removing const/volatile.
    static constexpr bool value =
        std::is_same<std::complex<float>, typename std::remove_cv<T>::type>::value ||
        std::is_same<std::complex<double>, typename std::remove_cv<T>::type>::value;
};
41
//! \brief Struct to retrieve the vector type associated with a precision T and a pack size N; the primary template defines no type.
template<typename T, int N>
struct pack {};

//! \brief Single-precision pack of size 1: stored as one std::complex<float>.
template<>
struct pack<float, 1> {
    using type = std::complex<float>;
};

//! \brief Double-precision pack of size 1: stored as one std::complex<double>.
template<>
struct pack<double, 1> {
    using type = std::complex<double>;
};
48
49// Some simple operations that will be useful for vectorized types.
50
55template<typename F, int L>
56inline typename pack<F,L>::type mm_zero(){return 0.0;}
61template<typename F, int L>
62inline typename pack<F,L>::type mm_load(F const *src) { return typename pack<F,L>::type {src[0], src[1]}; }
67template<typename F, int L>
68inline void mm_store(F *dest, typename pack<F,L>::type const &src) {
69 dest[0] = src.real(); dest[1] = src.imag();
70}
75template<typename F, int L>
76inline typename pack<F,L>::type mm_pair_set(F x, F y) { return typename pack<F,L>::type(x, y); }
81template<typename F, int L>
82inline typename pack<F,L>::type mm_set1(F src) { return typename pack<F,L>::type(src,src); }
87template<typename F, int L>
88inline typename pack<F,L>::type mm_complex_load(std::complex<F> const *src) { return *src; }
93template<typename F, int L>
94inline typename pack<F,L>::type mm_complex_load(std::complex<F> const *src, int) { return *src; }
95
96// Real basic arithmetic for the "none" case
97
99inline typename pack<float, 1>::type mm_add(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a + b; }
101inline typename pack<double, 1>::type mm_add(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a + b; }
103inline typename pack<float, 1>::type mm_sub(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a - b; }
105inline typename pack<double, 1>::type mm_sub(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a - b; }
107inline typename pack<float, 1>::type mm_div(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a / b.real(); }
109inline typename pack<double, 1>::type mm_div(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a / b.real(); }
111inline typename pack<float, 1>::type mm_neg(typename pack<float, 1>::type const &a){ return -a; }
113inline typename pack<double, 1>::type mm_neg(typename pack<double, 1>::type const &a){ return -a; }
115inline typename pack<float, 1>::type mm_mul(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a * b.real(); }
117inline typename pack<double, 1>::type mm_mul(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a * b.real(); }
119inline typename pack<float, 1>::type mm_complex_mul(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a * b; }
121inline typename pack<double, 1>::type mm_complex_mul(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a * b; }
123inline typename pack<float, 1>::type mm_complex_fmadd(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b, typename pack<float, 1>::type const &c){ return a * b + c; }
125inline typename pack<double, 1>::type mm_complex_fmadd(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b, typename pack<double, 1>::type const &c){ return a * b + c; }
127inline typename pack<float, 1>::type mm_complex_fmsub(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b, typename pack<float, 1>::type const &c){ return a * b - c; }
129inline typename pack<double, 1>::type mm_complex_fmsub(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b, typename pack<double, 1>::type const &c){ return a * b - c; }
131inline typename pack<float, 1>::type mm_complex_mul_i(typename pack<float, 1>::type const &a){return a * std::complex<float>{0.f,1.f}; }
133inline typename pack<double, 1>::type mm_complex_mul_i(typename pack<double, 1>::type const &a){ return a * std::complex<double>{0.,1.}; }
135inline typename pack<float, 1>::type mm_complex_mul_neg_i(typename pack<float, 1>::type const &a){return a * std::complex<float>{0.f,-1.f}; }
137inline typename pack<double, 1>::type mm_complex_mul_neg_i(typename pack<double, 1>::type const &a){ return a * std::complex<double>{0.,-1.}; }
139inline typename pack<float, 1>::type mm_complex_sq_mod(typename pack<float,1>::type const &a){ return std::complex<float>{norm(a), norm(a)}; }
141inline typename pack<double, 1>::type mm_complex_sq_mod(typename pack<double,1>::type const &a){ return std::complex<double>{norm(a), norm(a)}; }
143inline typename pack<float, 1>::type mm_complex_mod(typename pack<float,1>::type const &a){ return std::complex<float>{std::abs(a), std::abs(a)}; }
145inline typename pack<double, 1>::type mm_complex_mod(typename pack<double,1>::type const &a){ return std::complex<double>{std::abs(a), std::abs(a)}; }
147inline typename pack<float, 1>::type mm_complex_conj(typename pack<float,1>::type const &a){ return conj(a); }
149inline typename pack<double, 1>::type mm_complex_conj(typename pack<double,1>::type const &a){ return conj(a); }
151inline typename pack<float, 1>::type mm_complex_div(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a / b; }
153inline typename pack<double, 1>::type mm_complex_div(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a / b; }
154
156/* Below is functionality for vector packs */
158
159#ifdef Heffte_ENABLE_AVX
160
162template<> struct pack<double, 2> { using type = __m128d; };
164template<> struct pack<float, 4> { using type = __m128; };
166template<> struct pack<double, 4> { using type = __m256d; };
168template<> struct pack<float, 8> { using type = __m256; };
169
/* Below are the function specializations for pack<float, 4> */
173
175template<>
176inline typename pack<float, 4>::type mm_zero<float, 4>(){ return _mm_setzero_ps(); }
177
179template<>
180inline typename pack<float, 4>::type mm_load<float, 4>(float const *src) { return _mm_loadu_ps(src); }
181
183template<>
184inline void mm_store<float, 4>(float *dest, pack<float, 4>::type const &src) { _mm_storeu_ps(dest, src); }
185
187template<>
188inline typename pack<float, 4>::type mm_pair_set<float, 4>(float x, float y) { return _mm_setr_ps(x, y, x, y); }
189
191template<>
192inline typename pack<float, 4>::type mm_set1<float,4> (float x) { return _mm_set1_ps(x); }
193
195template<>
196inline typename pack<float, 4>::type mm_complex_load<float,4>(std::complex<float> const *src, int stride) {
197 return _mm_setr_ps(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
198}
200template<>
201inline typename pack<float, 4>::type mm_complex_load<float,4>(std::complex<float> const *src) {
202 return mm_complex_load<float,4>(src, 1);
203}
204
/* Below are the function specializations for pack<float, 8> */
208
210template<>
211inline typename pack<float, 8>::type mm_zero<float, 8>(){ return _mm256_setzero_ps(); }
212
214template<>
215inline typename pack<float, 8>::type mm_load<float, 8>(float const *src) { return _mm256_loadu_ps(src); }
216
218template<>
219inline void mm_store<float, 8>(float *dest, pack<float, 8>::type const &src) { _mm256_storeu_ps(dest, src); }
220
222template<>
223inline typename pack<float, 8>::type mm_pair_set<float, 8>(float x, float y) { return _mm256_setr_ps(x, y, x, y, x, y, x, y); }
224
226template<>
227inline typename pack<float, 8>::type mm_set1<float,8> (float x) { return _mm256_set1_ps(x); }
228
230template<>
231inline typename pack<float, 8>::type mm_complex_load<float, 8>(std::complex<float> const *src, int stride) {
232 return _mm256_setr_ps(src[0*stride].real(), src[0*stride].imag(),
233 src[1*stride].real(), src[1*stride].imag(),
234 src[2*stride].real(), src[2*stride].imag(),
235 src[3*stride].real(), src[3*stride].imag());
236}
238template<>
239inline typename pack<float, 8>::type mm_complex_load<float,8>(std::complex<float> const *src) {
240 return mm_complex_load<float,8>(src, 1);
241}
242
/* Below are the function specializations for pack<double, 2> */
246
248template<>
249inline typename pack<double, 2>::type mm_zero<double, 2>(){ return _mm_setzero_pd(); }
250
252template<>
253inline typename pack<double, 2>::type mm_load<double, 2>(double const *src) { return _mm_loadu_pd(src); }
254
256template<>
257inline void mm_store<double, 2>(double *dest, pack<double, 2>::type const &src) { _mm_storeu_pd(dest, src); }
258
260template<>
261inline typename pack<double, 2>::type mm_pair_set<double, 2>(double x, double y) { return _mm_setr_pd(x, y); }
262
264template<>
265inline typename pack<double, 2>::type mm_set1<double, 2>(double x) { return _mm_set1_pd(x); }
266
268template<>
269inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double> const *src, int) {
270 return _mm_setr_pd(src[0].real(), src[0].imag());
271}
272template<>
273inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double> const *src) {
274 return mm_complex_load<double,2>(src, 1);
275}
276
/* Below are the function specializations for pack<double, 4> */
280
282template<>
283inline typename pack<double, 4>::type mm_zero<double, 4>(){ return _mm256_setzero_pd(); }
284
286template<>
287inline typename pack<double, 4>::type mm_load<double, 4>(double const *src) { return _mm256_loadu_pd(src); }
288
290template<>
291inline void mm_store<double, 4>(double *dest, pack<double, 4>::type const &src) { _mm256_storeu_pd(dest, src); }
292
294template<>
295inline typename pack<double, 4>::type mm_pair_set<double, 4>(double x, double y) { return _mm256_setr_pd(x, y, x, y); }
296
298template<>
299inline typename pack<double, 4>::type mm_set1<double, 4>(double x) { return _mm256_set1_pd(x); }
300
302template<>
303inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double> const *src, int stride) {
304 return _mm256_setr_pd(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
305}
307template<>
308inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double> const *src) {
309 return mm_complex_load<double,4>(src, 1);
310}
311
313/* Elementary operations for vector packs */
315
316/* Addition */
317
319inline pack<float, 4>::type mm_add(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
320 return _mm_add_ps(x, y);
321}
322
324inline pack<float, 8>::type mm_add(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
325 return _mm256_add_ps(x, y);
326}
327
329inline pack<double, 2>::type mm_add(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
330 return _mm_add_pd(x, y);
331}
332
334inline pack<double, 4>::type mm_add(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
335 return _mm256_add_pd(x, y);
336}
337
338/* Subtraction */
339
341inline pack<float, 4>::type mm_sub(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
342 return _mm_sub_ps(x, y);
343}
344
346inline pack<float, 8>::type mm_sub(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
347 return _mm256_sub_ps(x, y);
348}
349
351inline pack<double, 2>::type mm_sub(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
352 return _mm_sub_pd(x, y);
353}
354
356inline pack<double, 4>::type mm_sub(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
357 return _mm256_sub_pd(x, y);
358}
359
360/* Multiplication */
361
363inline pack<float, 4>::type mm_mul(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
364 return _mm_mul_ps(x, y);
365}
366
368inline pack<float, 8>::type mm_mul(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
369 return _mm256_mul_ps(x, y);
370}
371
373inline pack<double, 2>::type mm_mul(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
374 return _mm_mul_pd(x, y);
375}
376
378inline pack<double, 4>::type mm_mul(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
379 return _mm256_mul_pd(x, y);
380}
381
382/* Division */
383
385inline pack<float, 4>::type mm_div(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
386 return _mm_div_ps(x, y);
387}
388
390inline pack<float, 8>::type mm_div(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
391 return _mm256_div_ps(x, y);
392}
393
395inline pack<double, 2>::type mm_div(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
396 return _mm_div_pd(x, y);
397}
398
400inline pack<double, 4>::type mm_div(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
401 return _mm256_div_pd(x, y);
402}
403
404
405/* Negation */
407inline pack<float, 4>::type mm_neg(pack<float, 4>::type const &x) {
408 return _mm_xor_ps(x, (mm_set1<float, 4>(-0.f)));
409}
410
412inline pack<float, 8>::type mm_neg(pack<float, 8>::type const &x) {
413 return _mm256_xor_ps(x, (mm_set1<float, 8>(-0.f)));
414}
415
417inline pack<double, 2>::type mm_neg(pack<double, 2>::type const &x) {
418 return _mm_xor_pd(x, (mm_set1<double, 2>(-0.)));
419}
420
422inline pack<double, 4>::type mm_neg(pack<double, 4>::type const &x) {
423 return _mm256_xor_pd(x, (mm_set1<double, 4>(-0.)));
424}
425
427/* Complex operations using vector packs */
429
430// Complex Multiplication
431
433inline pack<float,4>::type mm_complex_mul(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
434 typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
435 typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
436 typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
437 typename pack<float,4>::type dba = _mm_mul_ps(ba, dd);
438 typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
439 return mult;
440}
441
443inline pack<float, 8>::type mm_complex_mul(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
444 typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
445 typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
446 typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
447 typename pack<float,8>::type dba = _mm256_mul_ps(ba, dd);
448 typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
449 return mult;
450}
451
453inline pack<double, 2>::type mm_complex_mul(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
454 typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
455 typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
456 typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
457 typename pack<double,2>::type dba = _mm_mul_pd(ba, dd);
458 typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
459 return mult;
460}
461
463inline pack<double, 4>::type mm_complex_mul(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
464 typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
465 typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
466 typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
467 typename pack<double,4>::type dba = _mm256_mul_pd(ba, dd);
468 typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
469 return mult;
470}
471
472// Fused multiply-add
473
475inline pack<float,4>::type mm_complex_fmadd(pack<float, 4>::type const &x, pack<float, 4>::type const &y, pack<float, 4>::type const &z) {
476 typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
477 typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
478 typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
479 typename pack<float,4>::type dba = _mm_fmaddsub_ps(ba, dd, z);
480 typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
481 return mult;
482}
483
485inline pack<float, 8>::type mm_complex_fmadd(pack<float, 8>::type const &x, pack<float, 8>::type const &y, pack<float, 8>::type const &z) {
486 typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
487 typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
488 typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
489 typename pack<float,8>::type dba = _mm256_fmaddsub_ps(ba, dd, z);
490 typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
491 return mult;
492}
493
495inline pack<double, 2>::type mm_complex_fmadd(pack<double, 2>::type const &x, pack<double, 2>::type const &y, pack<double, 2>::type const &z) {
496 typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
497 typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
498 typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
499 typename pack<double,2>::type dba = _mm_fmaddsub_pd(ba, dd, z);
500 typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
501 return mult;
502}
503
505inline pack<double, 4>::type mm_complex_fmadd(pack<double, 4>::type const &x, pack<double, 4>::type const &y, pack<double, 4>::type const &z) {
506 typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
507 typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
508 typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
509 typename pack<double,4>::type dba = _mm256_fmaddsub_pd(ba, dd, z);
510 typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
511 return mult;
512}
513
515inline pack<float,4>::type mm_complex_fmsub(pack<float, 4>::type const &x, pack<float, 4>::type const &y, pack<float, 4>::type const &z) {
516 typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
517 typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
518 typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
519 typename pack<float,4>::type dba = _mm_fmsubadd_ps(ba, dd, z);
520 typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
521 return mult;
522}
523
525inline pack<float, 8>::type mm_complex_fmsub(pack<float, 8>::type const &x, pack<float, 8>::type const &y, pack<float, 8>::type const &z) {
526 typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
527 typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
528 typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
529 typename pack<float,8>::type dba = _mm256_fmsubadd_ps(ba, dd, z);
530 typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
531 return mult;
532}
533
535inline pack<double, 2>::type mm_complex_fmsub(pack<double, 2>::type const &x, pack<double, 2>::type const &y, pack<double, 2>::type const &z) {
536 typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
537 typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
538 typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
539 typename pack<double,2>::type dba = _mm_fmsubadd_pd(ba, dd, z);
540 typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
541 return mult;
542}
543
545inline pack<double, 4>::type mm_complex_fmsub(pack<double, 4>::type const &x, pack<double, 4>::type const &y, pack<double, 4>::type const &z) {
546 typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
547 typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
548 typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
549 typename pack<double,4>::type dba = _mm256_fmsubadd_pd(ba, dd, z);
550 typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
551 return mult;
552}
553
554// Squared modulus of the complex numbers in a pack
555
557inline pack<float, 4>::type mm_complex_sq_mod(pack<float, 4>::type const &x) {
558 return _mm_or_ps(_mm_dp_ps(x, x, 0b11001100), _mm_dp_ps(x, x, 0b00110011));
559}
560
562inline pack<float, 8>::type mm_complex_sq_mod(pack<float, 8>::type const &x) {
563 return _mm256_or_ps(_mm256_dp_ps(x, x, 0b11001100), _mm256_dp_ps(x, x, 0b00110011));
564}
565
567inline pack<double, 2>::type mm_complex_sq_mod(pack<double, 2>::type const &x) {
568 return _mm_dp_pd(x, x, 0b11111111);
569}
570
572inline pack<double, 4>::type mm_complex_sq_mod(pack<double, 4>::type const &x) {
573 typename pack<double,4>::type a = _mm256_mul_pd(x, x);
574 return _mm256_hadd_pd(a, a);
575}
576
577// Moduli (with square root) of complex numbers
578
580inline pack<float, 4>::type mm_complex_mod(pack<float, 4>::type const &x) {
581 return _mm_sqrt_ps(mm_complex_sq_mod(x));
582}
583
585inline pack<float, 8>::type mm_complex_mod(pack<float, 8>::type const &x) {
586 return _mm256_sqrt_ps(mm_complex_sq_mod(x));
587}
588
590inline pack<double, 2>::type mm_complex_mod(pack<double, 2>::type const &x) {
591 return _mm_sqrt_pd(mm_complex_sq_mod(x));
592}
593
595inline pack<double, 4>::type mm_complex_mod(pack<double, 4>::type const &x) {
596 return _mm256_sqrt_pd(mm_complex_sq_mod(x));
597}
598
600inline pack<float, 4>::type mm_complex_conj(pack<float, 4>::type const &x) {
601 return _mm_blend_ps(x, (mm_neg(x)), 0b1010);
602}
603
605inline pack<float, 8>::type mm_complex_conj(pack<float, 8>::type const &x) {
606 return _mm256_blend_ps(x, (mm_neg(x)), 0b10101010);
607}
608
610inline pack<double, 2>::type mm_complex_conj(pack<double, 2>::type const &x) {
611 return _mm_blend_pd(x, (mm_neg(x)), 0b10);
612}
613
615inline pack<double, 4>::type mm_complex_conj(pack<double, 4>::type const &x) {
616 return _mm256_blend_pd(x, (mm_neg(x)), 0b1010);
617}
618
619// Special operation when multiplying by i and -i
621inline pack<float, 4>::type mm_complex_mul_i(pack<float, 4>::type const &x) {
622 return _mm_permute_ps( (mm_complex_conj(x)), 0b10110001);
623}
624
626inline pack<float, 8>::type mm_complex_mul_i(pack<float, 8>::type const &x) {
627 return _mm256_permute_ps( (mm_complex_conj(x)), 0b10110001);
628}
629
631inline pack<double, 2>::type mm_complex_mul_i(pack<double, 2>::type const &x) {
632 return _mm_permute_pd( (mm_complex_conj(x)), 0b00000001);
633}
634
636inline pack<double, 4>::type mm_complex_mul_i(pack<double, 4>::type const &x) {
637 return _mm256_permute_pd( (mm_complex_conj(x)), 0b00000101);
638}
639
641inline pack<float, 4>::type mm_complex_mul_neg_i(pack<float, 4>::type const &x) {
642 return mm_complex_conj(_mm_permute_ps(x, 0b10110001));
643}
644
646inline pack<float, 8>::type mm_complex_mul_neg_i(pack<float, 8>::type const &x) {
647 return mm_complex_conj(_mm256_permute_ps(x, 0b10110001));
648}
649
651inline pack<double, 2>::type mm_complex_mul_neg_i(pack<double, 2>::type const &x) {
652 return mm_complex_conj(_mm_permute_pd(x, 0b0000001));
653}
654
656inline pack<double, 4>::type mm_complex_mul_neg_i(pack<double, 4>::type const &x) {
657 return mm_complex_conj(_mm256_permute_pd(x, 0b00000101));
658}
659
660// Complex division
661
663inline pack<float, 4>::type mm_complex_div(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
664 return _mm_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
665}
666
668inline pack<float, 8>::type mm_complex_div(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
669 return _mm256_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
670}
671
673inline pack<double, 2>::type mm_complex_div(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
674 return _mm_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
675}
676
678inline pack<double, 4>::type mm_complex_div(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
679 return _mm256_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
680}
681
682// Now all the implementations for types in AVX512 headers
683#ifdef Heffte_ENABLE_AVX512
684
686template<> struct pack<double, 8> { using type = __m512d; };
688template<> struct pack<float, 16> { using type = __m512; };
689
/* Below are the function specializations for pack<float, 16> */
693
695template<>
696inline typename pack<float, 16>::type mm_zero<float, 16>(){ return _mm512_setzero_ps(); }
697
699template<>
700inline typename pack<float, 16>::type mm_load<float, 16>(float const *src) { return _mm512_loadu_ps(src); }
701
703template<>
704inline void mm_store<float, 16>(float *dest, pack<float, 16>::type const &src) { _mm512_storeu_ps(dest, src); }
705
707template<>
708inline typename pack<float, 16>::type mm_pair_set<float, 16>(float x, float y) { return _mm512_setr_ps(x, y, x, y, x, y, x, y, x, y, x, y, x, y, x, y); }
709
711template<>
712inline typename pack<float, 16>::type mm_set1<float, 16> (float x) { return _mm512_set1_ps(x); }
713
715template<>
716inline typename pack<float, 16>::type mm_complex_load<float, 16>(std::complex<float> const *src, int stride) {
717 return _mm512_setr_ps(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
718 src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag(),
719 src[4*stride].real(), src[4*stride].imag(), src[5*stride].real(), src[5*stride].imag(),
720 src[6*stride].real(), src[6*stride].imag(), src[7*stride].real(), src[7*stride].imag());
721}
722
724template<>
725inline typename pack<float, 16>::type mm_complex_load<float, 16>(std::complex<float> const *src) {
726 return mm_complex_load<float, 16>(src, 1);
727}
728
/* Below are the function specializations for pack<double, 8> */
732
734template<>
735inline typename pack<double, 8>::type mm_zero<double, 8>(){ return _mm512_setzero_pd(); }
736
738template<>
739inline typename pack<double, 8>::type mm_load<double, 8>(double const *src) { return _mm512_loadu_pd(src); }
740
742template<>
743inline void mm_store<double, 8>(double *dest, pack<double, 8>::type const &src) { _mm512_storeu_pd(dest, src); }
744
746template<>
747inline typename pack<double, 8>::type mm_pair_set<double, 8>(double x, double y) { return _mm512_setr_pd(x, y, x, y, x, y, x, y); }
748
750template<>
751inline typename pack<double, 8>::type mm_set1<double, 8>(double x) { return _mm512_set1_pd(x); }
752
754template<>
755inline typename pack<double, 8>::type mm_complex_load<double, 8>(std::complex<double> const *src, int stride) {
756 return _mm512_setr_pd(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
757 src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag());
758}
760template<>
761inline typename pack<double, 8>::type mm_complex_load<double, 8>(std::complex<double> const *src) {
762 return mm_complex_load<double, 8>(src, 1);
763}
764
766/* Elementary binary operations for vector packs */
768
769/* Addition */
770
772inline pack<float, 16>::type mm_add(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
773 return _mm512_add_ps(x, y);
774}
775
777inline pack<double, 8>::type mm_add(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
778 return _mm512_add_pd(x, y);
779}
780
781/* Subtraction */
782
784inline pack<float, 16>::type mm_sub(pack<float, 16>::type const &x,pack<float, 16>::type const &y) {
785 return _mm512_sub_ps(x, y);
786}
787
789inline pack<double, 8>::type mm_sub(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
790 return _mm512_sub_pd(x, y);
791}
792
793/* Multiplication */
794
796inline pack<float, 16>::type mm_mul(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
797 return _mm512_mul_ps(x, y);
798}
799
801inline pack<double, 8>::type mm_mul(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
802 return _mm512_mul_pd(x, y);
803}
804
805/* Division */
806
808inline pack<float, 16>::type mm_div(pack<float, 16>::type const &x,pack<float, 16>::type const &y) {
809 return _mm512_div_ps(x, y);
810}
811
813inline pack<double, 8>::type mm_div(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
814 return _mm512_div_pd(x, y);
815}
816
817/* Negation */
819inline pack<float, 16>::type mm_neg(pack<float, 16>::type const &x) {
820 return _mm512_xor_ps(x, (mm_set1<float, 16>(-0.f)));
821}
822
824inline pack<double, 8>::type mm_neg(pack<double, 8>::type const &x) {
825 return _mm512_xor_pd(x, (mm_set1<double, 8>(-0.f)));
826}
827
829/* Complex operations using AVX512 vector packs */
831
832// Complex Multiplication
833
835inline pack<float,16>::type mm_complex_mul(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
836 typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
837 typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
838 typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
839 typename pack<float, 16>::type dba = _mm512_mul_ps(ba, dd);
840 typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
841 return mult;
842}
843
845inline pack<double, 8>::type mm_complex_mul(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
846 typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
847 typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
848 typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
849 typename pack<double, 8>::type dba = _mm512_mul_pd(ba, dd);
850 typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
851 return mult;
852}
853
854// Complex fused-multiply add
855
857inline pack<float,16>::type mm_complex_fmadd(pack<float, 16>::type const &x, pack<float, 16>::type const &y, pack<float, 16>::type const &alpha) {
858 typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
859 typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
860 typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
861 typename pack<float, 16>::type dba = _mm512_fmaddsub_ps(ba, dd, alpha);
862 typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
863 return mult;
864}
865
867inline pack<double, 8>::type mm_complex_fmadd(pack<double, 8>::type const &x, pack<double, 8>::type const &y, pack<double, 8>::type const &alpha) {
868 typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
869 typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
870 typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
871 typename pack<double, 8>::type dba = _mm512_fmaddsub_pd(ba, dd, alpha);
872 typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
873 return mult;
874}
875
877inline pack<float,16>::type mm_complex_fmsub(pack<float, 16>::type const &x, pack<float, 16>::type const &y, pack<float, 16>::type const &alpha) {
878 typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
879 typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
880 typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
881 typename pack<float, 16>::type dba = _mm512_fmsubadd_ps(ba, dd, alpha);
882 typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
883 return mult;
884}
885
887inline pack<double, 8>::type mm_complex_fmsub(pack<double, 8>::type const &x, pack<double, 8>::type const &y, pack<double, 8>::type const &alpha) {
888 typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
889 typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
890 typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
891 typename pack<double, 8>::type dba = _mm512_fmsubadd_pd(ba, dd, alpha);
892 typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
893 return mult;
894}
895
896// Squared modulus of the complex numbers in a pack
897
899inline pack<float, 16>::type mm_complex_sq_mod(pack<float, 16>::type const &x) {
900 typename pack<float, 16>::type sq = mm_mul(x, x);
901 typename pack<float, 16>::type sq_perm = _mm512_permute_ps(sq, 0b10110001);
902 typename pack<float, 16>::type mod = mm_add(sq, sq_perm);
903 return mod;
904}
905
907inline pack<double, 8>::type mm_complex_sq_mod(pack<double, 8>::type const &x) {
908 typename pack<double, 8>::type sq = mm_mul(x, x);
909 typename pack<double, 8>::type sq_perm = _mm512_permute_pd(sq, 0b01010101);
910 typename pack<double, 8>::type mod = mm_add(sq, sq_perm);
911 return mod;
912}
913
914// Moduli (with square root) of complex numbers
915
917inline pack<float, 16>::type mm_complex_mod(pack<float, 16>::type const &x) {
918 return _mm512_sqrt_ps(mm_complex_sq_mod(x));
919}
920
922inline pack<double, 8>::type mm_complex_mod(pack<double, 8>::type const &x) {
923 return _mm512_sqrt_pd(mm_complex_sq_mod(x));
924}
925
926// Conjugate complex numbers
927
929inline pack<float, 16>::type mm_complex_conj(pack<float, 16>::type const &x) {
930 return _mm512_mask_blend_ps(0b1010101010101010, x, mm_neg(x));
931}
932
934inline pack<double, 8>::type mm_complex_conj(pack<double, 8>::type const &x) {
935 return _mm512_mask_blend_pd(0b10101010, x, mm_neg(x));
936}
937
938// Special operation when multiplying by i and -i
940inline pack<float, 16>::type mm_complex_mul_i(pack<float, 16>::type const &x) {
941 return _mm512_permute_ps( (mm_complex_conj(x)), 0b10110001);
942}
943
945inline pack<double, 8>::type mm_complex_mul_i(pack<double, 8>::type const &x) {
946 return _mm512_permute_pd( (mm_complex_conj(x)), 0b01010101);
947}
948
950inline pack<float, 16>::type mm_complex_mul_neg_i(pack<float, 16>::type const &x) {
951 return mm_complex_conj(_mm512_permute_ps(x, 0b10110001));
952}
953
955inline pack<double, 8>::type mm_complex_mul_neg_i(pack<double, 8>::type const &x) {
956 return mm_complex_conj(_mm512_permute_pd(x, 0b01010101));
957}
958
959// Complex division
960
962inline pack<float, 16>::type mm_complex_div(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
963 return _mm512_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
964}
965
967inline pack<double, 8>::type mm_complex_div(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
968 return _mm512_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
969}
970
971#endif // Heffte_ENABLE_AVX512
972#endif // Heffte_ENABLE_AVX
973
974}
975}
976
977#endif // HEFFTE_STOCK_VEC_TYPES_H
Namespace containing all HeFFTe methods and classes.
Definition heffte_backend_cuda.h:38
Wrapper around cufftHandle plans, set for float or double complex.
Definition heffte_backend_cuda.h:346
Struct determining whether a type is a complex number.
Definition heffte_stock_vec_types.h:38
Struct determining whether a type is a real number.
Definition heffte_stock_vec_types.h:33
Struct to retrieve the vector type associated with the number of elements stored "per unit".
Definition heffte_stock_vec_types.h:43