HILA
Loading...
Searching...
No Matches
fft_gpu_transform.h
1#ifndef FFT_GPU_TRANSFORM_H
2#define FFT_GPU_TRANSFORM_H
3
4#ifndef HILAPP
5
6#if defined(CUDA)
7
8#include <cufft.h>
9
10using gpufftComplex = cufftComplex;
11using gpufftDoubleComplex = cufftDoubleComplex;
12using gpufftHandle = cufftHandle;
13#define gpufftExecC2C cufftExecC2C
14#define gpufftExecZ2Z cufftExecZ2Z
15#define gpufftPlan1d cufftPlan1d
16#define gpufftDestroy cufftDestroy
17
18#define GPUFFT_FORWARD CUFFT_FORWARD
19#define GPUFFT_INVERSE CUFFT_INVERSE
20
21#define GPUFFT_C2C CUFFT_C2C
22#define GPUFFT_Z2Z CUFFT_Z2Z
23
24#else
25
26#include "hip/hip_runtime.h"
27#include <hipfft/hipfft.h>
28
29using gpufftComplex = hipfftComplex;
30using gpufftDoubleComplex = hipfftDoubleComplex;
31using gpufftHandle = hipfftHandle;
32#define gpufftExecC2C hipfftExecC2C
33#define gpufftExecZ2Z hipfftExecZ2Z
34#define gpufftPlan1d hipfftPlan1d
35#define gpufftDestroy hipfftDestroy
36
37#define GPUFFT_FORWARD HIPFFT_FORWARD
38#define GPUFFT_INVERSE HIPFFT_BACKWARD
39
40#define GPUFFT_C2C HIPFFT_C2C
41#define GPUFFT_Z2Z HIPFFT_Z2Z
42
43#endif
44
45
46/// Gather one element column from the mpi buffer
47template <typename cmplx_t>
48__global__ void hila_fft_gather_column(cmplx_t *RESTRICT data, cmplx_t *RESTRICT *d_ptr,
49 int *RESTRICT d_size, int n, int colsize, int columns) {
50
51 int ind = threadIdx.x + blockIdx.x * blockDim.x;
52 if (ind < columns) {
53 int s = colsize * ind;
54
55 int k = s;
56 for (int i = 0; i < n; i++) {
57 int offset = ind * d_size[i];
58 for (int j = 0; j < d_size[i]; j++, k++) {
59 data[k] = d_ptr[i][j + offset];
60 }
61 }
62 }
63}
64
65/// Gather one element column from the mpi buffer
66template <typename cmplx_t>
67__global__ void hila_fft_scatter_column(cmplx_t *RESTRICT data, cmplx_t *RESTRICT *d_ptr,
68 int *RESTRICT d_size, int n, int colsize, int columns) {
69
70 int ind = threadIdx.x + blockIdx.x * blockDim.x;
71 if (ind < columns) {
72 int s = colsize * ind;
73
74 int k = s;
75 for (int i = 0; i < n; i++) {
76 int offset = ind * d_size[i];
77 for (int j = 0; j < d_size[i]; j++, k++) {
78 d_ptr[i][j + offset] = data[k];
79 }
80 }
81 }
82}
83
84// Define datatype for saved plans
85
86#define N_PLANS NDIM // just for concreteness...
87
88class hila_saved_fftplan_t {
89 public:
90 struct plan_d {
91 gpufftHandle plan;
92 unsigned long seq; // plan use sequence - for clearing up
93 int size;
94 int batch;
95 bool is_float;
96 };
97
98 unsigned long seq;
99 std::vector<plan_d> plans;
100
101 hila_saved_fftplan_t() {
102 plans.reserve(N_PLANS);
103 seq = 0;
104 }
105
106 ~hila_saved_fftplan_t() {
107 delete_plans();
108 }
109
110 void delete_plans() {
111 for (auto &p : plans) {
112 gpufftDestroy(p.plan);
113 }
114 plans.clear();
115 seq = 0;
116 }
117
118 // get cached plan or create new. If the saved plan is incompatible with
119 // the one required, destroy plans
120 gpufftHandle get_plan(int size, int batch, bool is_float) {
121
122 extern hila::timer fft_plan_timer;
123
124 // do we have saved plan of the same type?
125
126 seq++;
127
128 for (auto &p : plans) {
129 if (p.size == size && p.batch == batch && p.is_float == is_float) {
130 // Now we got it!
131 p.seq = seq;
132 return p.plan;
133 }
134 }
135
136 // not cached, make new if there's room
137
138 fft_plan_timer.start();
139
140 plan_d *pp;
141 if (plans.size() == N_PLANS) {
142 // find and destroy oldest used plan
143 pp = &plans[0];
144 for (int i = 1; i < plans.size(); i++) {
145 if (pp->seq > plans[i].seq)
146 pp = &plans[i];
147 }
148 gpufftDestroy(pp->plan);
149 } else {
150 plan_d empty;
151 plans.push_back(empty);
152 pp = &plans.back();
153 }
154
155 // If we got here we need to make a plan
156
157 pp->size = size;
158 pp->batch = batch;
159 pp->is_float = is_float;
160 pp->seq = seq;
161
162 // HIPFFT_C2C for float transform, Z2Z for double
163
164 gpufftPlan1d(&(pp->plan), size, is_float ? GPUFFT_C2C : GPUFFT_Z2Z, batch);
165 check_device_error("FFT plan");
166
167 fft_plan_timer.stop();
168
169 return pp->plan;
170 }
171};
172
173/// Define appropriate cufft-type depending on the cmplx_t -type
174/// these types are 1-1 compatible anyway
175/// use as cufft_cmplx_t<T>::type
176template <typename cmplx_t>
177using fft_cmplx_t = typename std::conditional<sizeof(gpufftComplex) == sizeof(cmplx_t),
178 gpufftComplex, gpufftDoubleComplex>::type;
179
180/// Templates for cufftExec float and double complex
181
182template <typename cmplx_t, std::enable_if_t<sizeof(cmplx_t) == sizeof(gpufftComplex), int> = 0>
183inline void hila_gpufft_execute(gpufftHandle plan, cmplx_t *buf, int direction) {
184 gpufftExecC2C(plan, (gpufftComplex *)buf, (gpufftComplex *)buf, direction);
185}
186
187template <typename cmplx_t,
188 std::enable_if_t<sizeof(cmplx_t) == sizeof(gpufftDoubleComplex), int> = 0>
189inline void hila_gpufft_execute(gpufftHandle plan, cmplx_t *buf, int direction) {
190 gpufftExecZ2Z(plan, (gpufftDoubleComplex *)buf, (gpufftDoubleComplex *)buf, direction);
191}
192
193template <typename cmplx_t>
195
196 // these externs defined in fft.cpp
197 extern unsigned hila_fft_my_columns[NDIM];
198 extern hila::timer fft_execute_timer, fft_buffer_timer;
199 extern hila_saved_fftplan_t hila_saved_fftplan;
200
201 constexpr bool is_float = (sizeof(cmplx_t) == sizeof(Complex<float>));
202
203 int n_columns = hila_fft_my_columns[dir] * elements;
204
205 int direction = (fftdir == fft_direction::forward) ? GPUFFT_FORWARD : GPUFFT_INVERSE;
206
207 // allocate here fftw plans. TODO: perhaps store, if plans take appreciable time?
208 // Timer will tell the proportional timing
209
210 int batch = hila_fft_my_columns[dir];
211 int n_fft = elements;
212 // reduce very large batch to smaller, avoid large buffer space
213
214 bool is_divisible = true;
215 while (batch > GPUFFT_BATCH_SIZE && is_divisible) {
216 is_divisible = false;
217 for (int div : {2, 3, 5, 7}) {
218 if (batch % div == 0) {
219 batch /= div;
220 n_fft *= div;
221 is_divisible = true;
222 break;
223 }
224 }
225 }
226
227 gpufftHandle plan;
228 plan = hila_saved_fftplan.get_plan(lattice.size(dir), batch, is_float);
229 // hila::out0 << " Batch " << batch << " nfft " << n_fft << '\n';
230
231 // alloc work array
232 cmplx_t *fft_wrk = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
233
234 // Reorganize the data to form columns of a single element
235 // move from receive_buf to fft_wrk
236 // first need to copy index arrays to device
237
238 fft_buffer_timer.start();
239
240 cmplx_t **d_ptr = (cmplx_t **)d_malloc(sizeof(cmplx_t *) * rec_p.size());
241 int *d_size = (int *)d_malloc(sizeof(int) * rec_p.size());
242
243 gpuMemcpy(d_ptr, rec_p.data(), rec_p.size() * sizeof(cmplx_t *), gpuMemcpyHostToDevice);
244 gpuMemcpy(d_size, rec_size.data(), rec_size.size() * sizeof(int), gpuMemcpyHostToDevice);
245
246 int N_blocks = (n_columns + N_threads - 1) / N_threads;
247
248#if defined(CUDA)
249 hila_fft_gather_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
250 lattice.size(dir), n_columns);
251#else
252 hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_gather_column<cmplx_t>), dim3(N_blocks),
253 dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
254 lattice.size(dir), n_columns);
255#endif
256
257 fft_buffer_timer.stop();
258
259 // do the fft
260 fft_execute_timer.start();
261
262 for (int i = 0; i < n_fft; i++) {
263
264 cmplx_t *cp = fft_wrk + i * (batch * lattice.size(dir));
265
266 hila_gpufft_execute(plan, cp, direction);
267 check_device_error("FFT execute");
268 }
269
270 fft_execute_timer.stop();
271
272 fft_buffer_timer.start();
273
274#if defined(CUDA)
275 hila_fft_scatter_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
276 lattice.size(dir), n_columns);
277#else
278 hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_scatter_column<cmplx_t>), dim3(N_blocks),
279 dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
280 lattice.size(dir), n_columns);
281#endif
282
283 fft_buffer_timer.stop();
284
285 d_free(d_size);
286 d_free(d_ptr);
287 d_free(fft_wrk);
288}
289
290
291////////////////////////////////////////////////////////////////////
292/// send column data to nodes
293
294template <typename cmplx_t>
296
297
298 extern hila::timer pencil_MPI_timer;
299 pencil_MPI_timer.start();
300
301 // post receive and send
302 int n_comms = hila_pencil_comms[dir].size() - 1;
303
304 MPI_Request sendreq[n_comms], recreq[n_comms];
305 MPI_Status stat[n_comms];
306
307#ifndef GPU_AWARE_MPI
308 cmplx_t *send_p[n_comms];
309 cmplx_t *receive_p[n_comms];
310#endif
311
312 int i = 0;
313 int j = 0;
314
315 // this synchronization should be enough for all MPI's in the
316 gpuStreamSynchronize(0);
317
318 size_t mpi_type_size;
319 MPI_Datatype mpi_type = get_MPI_complex_type<cmplx_t>(mpi_type_size);
320
321 for (auto &fn : hila_pencil_comms[dir]) {
322 if (fn.node != hila::myrank()) {
323
324 size_t siz = fn.recv_buf_size * elements * sizeof(cmplx_t);
325 if (siz >= (1ULL << 31) * mpi_type_size) {
326 hila::out << "Too large MPI message in pencils! Size " << siz << " bytes ("
327 << siz / mpi_type_size << " elements)\n";
329 }
330
331#ifndef GPU_AWARE_MPI
332 cmplx_t *p = receive_p[i] = (cmplx_t *)memalloc(siz);
333#else
334 cmplx_t *p = rec_p[j];
335#endif
336
337 MPI_Irecv(p, (int)(siz / mpi_type_size), mpi_type, fn.node, WRK_GATHER_TAG,
338 lattice.mpi_comm_lat, &recreq[i]);
339
340 i++;
341 }
342 j++;
343 }
344
345 i = 0;
346 for (auto &fn : hila_pencil_comms[dir]) {
347 if (fn.node != hila::myrank()) {
348
349 cmplx_t *p = send_buf + fn.column_offset * elements;
350 size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
351
352#ifndef GPU_AWARE_MPI
353 // now not GPU_AWARE_MPI
354 send_p[i] = (cmplx_t *)memalloc(n);
355 gpuMemcpy(send_p[i], p, n, gpuMemcpyDeviceToHost);
356 p = send_p[i];
357#endif
358
359 MPI_Isend(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_GATHER_TAG,
360 lattice.mpi_comm_lat, &sendreq[i]);
361 i++;
362 }
363 }
364
365 // and wait for the send and receive to complete
366 if (n_comms > 0) {
367 MPI_Waitall(n_comms, recreq, stat);
368 MPI_Waitall(n_comms, sendreq, stat);
369
370#ifndef GPU_AWARE_MPI
371 i = j = 0;
372
373 for (auto &fn : hila_pencil_comms[dir]) {
374 if (fn.node != hila::myrank()) {
375
376 size_t siz = fn.recv_buf_size * elements;
377
378 gpuMemcpy(rec_p[j], receive_p[i], siz * sizeof(cmplx_t), gpuMemcpyHostToDevice);
379 i++;
380 }
381 j++;
382 }
383
384 for (i = 0; i < n_comms; i++) {
385 free(receive_p[i]);
386 free(send_p[i]);
387 }
388#endif
389 }
390
391 pencil_MPI_timer.stop();
392}
393
394//////////////////////////////////////////////////////////////////////////////////////
395/// inverse of gather_data
396
397template <typename cmplx_t>
399
400
401 extern hila::timer pencil_MPI_timer;
402 pencil_MPI_timer.start();
403
404 int n_comms = hila_pencil_comms[dir].size() - 1;
405
406 MPI_Request sendreq[n_comms], recreq[n_comms];
407 MPI_Status stat[n_comms];
408
409#ifndef GPU_AWARE_MPI
410 cmplx_t *send_p[n_comms];
411 cmplx_t *receive_p[n_comms];
412#endif
413
414 int i = 0;
415
416 size_t mpi_type_size;
417 MPI_Datatype mpi_type = get_MPI_complex_type<cmplx_t>(mpi_type_size);
418
419 gpuStreamSynchronize(0);
420
421 for (auto &fn : hila_pencil_comms[dir]) {
422 if (fn.node != hila::myrank()) {
423
424 size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
425#ifdef GPU_AWARE_MPI
426 cmplx_t *p = send_buf + fn.column_offset * elements;
427#else
428 cmplx_t *p = receive_p[i] = (cmplx_t *)memalloc(n);
429#endif
430
431 MPI_Irecv(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_SCATTER_TAG,
432 lattice.mpi_comm_lat, &recreq[i]);
433
434 i++;
435 }
436 }
437
438 i = 0;
439 int j = 0;
440 for (auto &fn : hila_pencil_comms[dir]) {
441 if (fn.node != hila::myrank()) {
442
443 size_t n = fn.recv_buf_size * elements * sizeof(cmplx_t);
444#ifdef GPU_AWARE_MPI
445 cmplx_t *p = rec_p[j];
446// gpuStreamSynchronize(0);
447#else
448 cmplx_t *p = send_p[i] = (cmplx_t *)memalloc(n);
449 gpuMemcpy(p, rec_p[j], n, gpuMemcpyDeviceToHost);
450#endif
451 MPI_Isend(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_SCATTER_TAG,
452 lattice.mpi_comm_lat, &sendreq[i]);
453
454 i++;
455 }
456 j++;
457 }
458
459 // and wait for the send and receive to complete
460 if (n_comms > 0) {
461 MPI_Waitall(n_comms, recreq, stat);
462 MPI_Waitall(n_comms, sendreq, stat);
463
464#ifndef GPU_AWARE_MPI
465 i = 0;
466 for (auto &fn : hila_pencil_comms[dir]) {
467 if (fn.node != hila::myrank()) {
468
469 size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
470 cmplx_t *p = send_buf + fn.column_offset * elements;
471
472 gpuMemcpy(p, receive_p[i], n, gpuMemcpyHostToDevice);
473 i++;
474 }
475 }
476
477 for (i = 0; i < n_comms; i++) {
478 free(receive_p[i]);
479 free(send_p[i]);
480 }
481#endif
482 }
483
484 pencil_MPI_timer.stop();
485}
486
487///////////////////////////////////////////////////////////////////////////////////
488/// Separate reflect operation
489/// Reflect flips the coordinates so that negative direction becomes positive,
490/// and x=0 plane remains,
491/// r(x) <- f(L - x) - note that x == 0 layer is as before
492/// r(0) = f(0), r(1) = f(L-1), r(2) = f(L-2) ...
493///////////////////////////////////////////////////////////////////////////////////
494
495/// Reflect data in the array
496template <typename cmplx_t>
497__global__ void hila_reflect_dir_kernel(cmplx_t *RESTRICT data, const int colsize,
498 const int columns) {
499
500 int ind = threadIdx.x + blockIdx.x * blockDim.x;
501 if (ind < columns) {
502 const int s = colsize * ind;
503
504 for (int i = 1; i < colsize / 2; i++) {
505 int i1 = s + i;
506 int i2 = s + colsize - i;
507 cmplx_t tmp = data[i1];
508 data[i1] = data[i2];
509 data[i2] = tmp;
510 }
511 }
512}
513
514
515template <typename cmplx_t>
517
518 // these externs defined in fft.cpp
519 extern unsigned hila_fft_my_columns[NDIM];
520
521 constexpr bool is_float = (sizeof(cmplx_t) == sizeof(Complex<float>));
522
523 int n_columns = hila_fft_my_columns[dir] * elements;
524
525 // reduce very large batch to smaller, avoid large buffer space
526
527 // alloc work array
528 cmplx_t *fft_wrk = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
529
530 // Reorganize the data to form columns of a single element
531 // move from receive_buf to fft_wrk
532 // first need to copy index arrays to device
533
534 cmplx_t **d_ptr = (cmplx_t **)d_malloc(sizeof(cmplx_t *) * rec_p.size());
535 int *d_size = (int *)d_malloc(sizeof(int) * rec_p.size());
536
537 gpuMemcpy(d_ptr, rec_p.data(), rec_p.size() * sizeof(cmplx_t *), gpuMemcpyHostToDevice);
538 gpuMemcpy(d_size, rec_size.data(), rec_size.size() * sizeof(int), gpuMemcpyHostToDevice);
539
540 int N_blocks = (n_columns + N_threads - 1) / N_threads;
541
542#if defined(CUDA)
543 hila_fft_gather_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
544 lattice.size(dir), n_columns);
545#else
546 hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_gather_column<cmplx_t>), dim3(N_blocks),
547 dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
548 lattice.size(dir), n_columns);
549#endif
550
551#if defined(CUDA)
552 hila_reflect_dir_kernel<cmplx_t>
553 <<<N_blocks, N_threads>>>(fft_wrk, lattice.size(dir), n_columns);
554#else
555 hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_reflect_dir_kernel<cmplx_t>), dim3(N_blocks),
556 dim3(N_threads), 0, 0, fft_wrk, lattice.size(dir), n_columns);
557#endif
558
559
560#if defined(CUDA)
561 hila_fft_scatter_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
562 lattice.size(dir), n_columns);
563#else
564 hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_scatter_column<cmplx_t>), dim3(N_blocks),
565 dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
566 lattice.size(dir), n_columns);
567#endif
568
569
570 d_free(d_size);
571 d_free(d_ptr);
572 d_free(fft_wrk);
573}
574
575
576#endif // HILAPP
577
578#endif
Complex definition.
Definition cmplx.h:50
void gather_data()
send column data to nodes
void scatter_data()
inverse of gather_data
void transform()
transform does the actual fft.
#define RESTRICT
Definition defs.h:50
int myrank()
rank of this node
Definition com_mpi.cpp:235
std::ostream out
this is our default output file stream
void terminate(int status)