#ifndef FFT_GPU_TRANSFORM_H
#define FFT_GPU_TRANSFORM_H

#if defined(CUDA)

#include <cufft.h>

using gpufftComplex = cufftComplex;
using gpufftDoubleComplex = cufftDoubleComplex;
using gpufftHandle = cufftHandle;
#define gpufftExecC2C cufftExecC2C
#define gpufftExecZ2Z cufftExecZ2Z
#define gpufftPlan1d cufftPlan1d
#define gpufftDestroy cufftDestroy

#define GPUFFT_FORWARD CUFFT_FORWARD
#define GPUFFT_INVERSE CUFFT_INVERSE

#define GPUFFT_C2C CUFFT_C2C
#define GPUFFT_Z2Z CUFFT_Z2Z
#else

#include "hip/hip_runtime.h"
#include <hipfft/hipfft.h>

using gpufftComplex = hipfftComplex;
using gpufftDoubleComplex = hipfftDoubleComplex;
using gpufftHandle = hipfftHandle;
#define gpufftExecC2C hipfftExecC2C
#define gpufftExecZ2Z hipfftExecZ2Z
#define gpufftPlan1d hipfftPlan1d
#define gpufftDestroy hipfftDestroy

#define GPUFFT_FORWARD HIPFFT_FORWARD
#define GPUFFT_INVERSE HIPFFT_BACKWARD

#define GPUFFT_C2C HIPFFT_C2C
#define GPUFFT_Z2Z HIPFFT_Z2Z

#endif
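// A minimal usage sketch of the alias layer above (illustration only, not part of
// the library): the same code drives cuFFT on CUDA and hipFFT on HIP. Here `buf`
// is assumed to be a device pointer to n single precision complex numbers.
//
//   gpufftHandle plan;
//   gpufftPlan1d(&plan, n, GPUFFT_C2C, 1);          // one length-n C2C transform
//   gpufftExecC2C(plan, buf, buf, GPUFFT_FORWARD);  // in-place forward FFT
//   gpufftDestroy(plan);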
/// Gather one element column from the mpi buffers
template <typename cmplx_t>
__global__ void hila_fft_gather_column(cmplx_t *RESTRICT data, cmplx_t *RESTRICT *d_ptr,
                                       int *RESTRICT d_size, int n, int colsize, int columns) {

    int ind = threadIdx.x + blockIdx.x * blockDim.x;
    if (ind < columns) {
        int s = colsize * ind;

        int k = s;
        for (int i = 0; i < n; i++) {
            int offset = ind * d_size[i];
            for (int j = 0; j < d_size[i]; j++, k++) {
                data[k] = d_ptr[i][j + offset];
            }
        }
    }
}
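// Host-side illustration of the gather indexing (a sketch, not used by the code):
// column `ind` is assembled from n per-node buffers, buffer i contributing its
// d_size[i] consecutive elements, so colsize = d_size[0] + ... + d_size[n-1]:
//
//   int k = colsize * ind;
//   for (int i = 0; i < n; i++)
//       for (int j = 0; j < d_size[i]; j++)
//           data[k++] = d_ptr[i][ind * d_size[i] + j];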
/// Scatter the columns back to the mpi buffers - inverse of the gather above
template <typename cmplx_t>
__global__ void hila_fft_scatter_column(cmplx_t *RESTRICT data, cmplx_t *RESTRICT *d_ptr,
                                        int *RESTRICT d_size, int n, int colsize, int columns) {

    int ind = threadIdx.x + blockIdx.x * blockDim.x;
    if (ind < columns) {
        int s = colsize * ind;

        int k = s;
        for (int i = 0; i < n; i++) {
            int offset = ind * d_size[i];
            for (int j = 0; j < d_size[i]; j++, k++) {
                d_ptr[i][j + offset] = data[k];
            }
        }
    }
}
// Datatype for cached fft plans
#define N_PLANS NDIM // max number of plans to keep alive

class hila_saved_fftplan_t {
  public:
    struct plan_d {
        gpufftHandle plan;
        unsigned long seq; // plan use sequence, for evicting the oldest
        int size, batch;
        bool is_float;
    };

    unsigned long seq;
    std::vector<plan_d> plans;

    hila_saved_fftplan_t() {
        plans.reserve(N_PLANS);
        seq = 0;
    }
    ~hila_saved_fftplan_t() {
        delete_plans();
    }
    void delete_plans() {
        for (auto &p : plans) {
            gpufftDestroy(p.plan);
        }
        plans.clear();
        seq = 0;
    }

    /// Get a cached plan for this size/batch/precision, creating one if needed
    gpufftHandle get_plan(int size, int batch, bool is_float) {
        extern hila::timer fft_plan_timer;

        // do we already have a saved plan of the same type?
        seq++;
        for (auto &p : plans) {
            if (p.size == size && p.batch == batch && p.is_float == is_float) {
                p.seq = seq;
                return p.plan;
            }
        }

        // not found, make a new plan
        fft_plan_timer.start();
        plan_d *pp;
        if (plans.size() == N_PLANS) {
            // cache full: evict the least recently used plan
            pp = &plans[0];
            for (int i = 1; i < plans.size(); i++) {
                if (pp->seq > plans[i].seq)
                    pp = &plans[i];
            }
            gpufftDestroy(pp->plan);
        } else {
            plan_d empty;
            plans.push_back(empty);
            pp = &plans.back();
        }
        pp->seq = seq;
        pp->size = size;
        pp->batch = batch;
        pp->is_float = is_float;

        gpufftPlan1d(&(pp->plan), size, is_float ? GPUFFT_C2C : GPUFFT_Z2Z, batch);
        check_device_error("FFT plan");

        fft_plan_timer.stop();
        return pp->plan;
    }
};
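// Usage sketch: transform() below obtains its plan through the cache instead of
// calling gpufftPlan1d() directly, so repeated FFTs of the same shape reuse plans:
//
//   extern hila_saved_fftplan_t hila_saved_fftplan;   // defined in fft.cpp
//   gpufftHandle plan = hila_saved_fftplan.get_plan(lattice.size(dir), batch, is_float);
//   // ... execute with the plan; no gpufftDestroy() here, the cache owns the plan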
// Select the matching gpufft complex type by size: Complex<float> -> gpufftComplex,
// Complex<double> -> gpufftDoubleComplex
template <typename cmplx_t>
using fft_cmplx_t = typename std::conditional<sizeof(gpufftComplex) == sizeof(cmplx_t),
                                              gpufftComplex, gpufftDoubleComplex>::type;
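// A compile-time sanity check of the mapping (a sketch, using hila's Complex<> type):
//
//   static_assert(std::is_same<fft_cmplx_t<Complex<float>>, gpufftComplex>::value,
//                 "single precision maps to gpufftComplex");
//   static_assert(std::is_same<fft_cmplx_t<Complex<double>>, gpufftDoubleComplex>::value,
//                 "double precision maps to gpufftDoubleComplex");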
// Execute the FFT in place - separate overloads for single and double precision,
// selected by matching sizeof(cmplx_t)
template <typename cmplx_t,
          std::enable_if_t<sizeof(cmplx_t) == sizeof(gpufftComplex), int> = 0>
inline void hila_gpufft_execute(gpufftHandle plan, cmplx_t *buf, int direction) {
    gpufftExecC2C(plan, (gpufftComplex *)buf, (gpufftComplex *)buf, direction);
}
template <typename cmplx_t,
          std::enable_if_t<sizeof(cmplx_t) == sizeof(gpufftDoubleComplex), int> = 0>
inline void hila_gpufft_execute(gpufftHandle plan, cmplx_t *buf, int direction) {
    gpufftExecZ2Z(plan, (gpufftDoubleComplex *)buf, (gpufftDoubleComplex *)buf, direction);
}
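// Dispatch is purely size-based: cmplx_t = Complex<float> matches the first
// overload and resolves to gpufftExecC2C, Complex<double> matches the second and
// resolves to gpufftExecZ2Z. For example, with a device buffer of Complex<double>:
//
//   hila_gpufft_execute(plan, buf, GPUFFT_FORWARD);   // calls gpufftExecZ2Z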
/// transform does the actual fft
template <typename cmplx_t>
void hila_fft<cmplx_t>::transform() {

    // these externs are defined in fft.cpp
    extern unsigned hila_fft_my_columns[NDIM];
    extern hila::timer fft_execute_timer, fft_buffer_timer;
    extern hila_saved_fftplan_t hila_saved_fftplan;

    constexpr bool is_float = (sizeof(cmplx_t) == sizeof(Complex<float>));

    int n_columns = hila_fft_my_columns[dir] * elements;

    int direction = (fftdir == fft_direction::forward) ? GPUFFT_FORWARD : GPUFFT_INVERSE;

    int batch = hila_fft_my_columns[dir];
    int n_fft = elements;

    // reduce a very large batch to a smaller one by small prime factors,
    // avoiding large fft buffers
    bool is_divisible = true;
    while (batch > GPUFFT_BATCH_SIZE && is_divisible) {
        is_divisible = false;
        for (int div : {2, 3, 5, 7}) {
            if (batch % div == 0) {
                batch /= div;
                n_fft *= div;
                is_divisible = true;
                break;
            }
        }
    }

    gpufftHandle plan;
    plan = hila_saved_fftplan.get_plan(lattice.size(dir), batch, is_float);

    // allocate the work array
    cmplx_t *fft_wrk = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);

    // Reorganize the data to form columns of a single element;
    // the index arrays must first be copied to the device
    fft_buffer_timer.start();

    cmplx_t **d_ptr = (cmplx_t **)d_malloc(sizeof(cmplx_t *) * rec_p.size());
    int *d_size = (int *)d_malloc(sizeof(int) * rec_p.size());

    gpuMemcpy(d_ptr, rec_p.data(), rec_p.size() * sizeof(cmplx_t *), gpuMemcpyHostToDevice);
    gpuMemcpy(d_size, rec_size.data(), rec_size.size() * sizeof(int), gpuMemcpyHostToDevice);

    int N_blocks = (n_columns + N_threads - 1) / N_threads;

#if defined(CUDA)
    hila_fft_gather_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
                                                             lattice.size(dir), n_columns);
#else
    hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_gather_column<cmplx_t>), dim3(N_blocks),
                       dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
                       lattice.size(dir), n_columns);
#endif

    fft_buffer_timer.stop();

    // do the fft, batch by batch
    fft_execute_timer.start();

    for (int i = 0; i < n_fft; i++) {
        cmplx_t *cp = fft_wrk + i * (batch * lattice.size(dir));
        hila_gpufft_execute(plan, cp, direction);
        check_device_error("FFT execute");
    }

    fft_execute_timer.stop();

    fft_buffer_timer.start();

#if defined(CUDA)
    hila_fft_scatter_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
                                                              lattice.size(dir), n_columns);
#else
    hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_scatter_column<cmplx_t>), dim3(N_blocks),
                       dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
                       lattice.size(dir), n_columns);
#endif

    fft_buffer_timer.stop();

    d_free(d_ptr);
    d_free(d_size);
    d_free(fft_wrk);
}
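// Worked example of the batch reduction in transform(): with hila_fft_my_columns[dir]
// = 6000, elements = 1 and GPUFFT_BATCH_SIZE = 256 (values assumed for illustration),
// the loop peels off factors 2, 2, 2, 2, 3, leaving batch = 125 and n_fft = 48;
// the 48 sub-ffts each transform 125 columns with the same cached plan. If batch is
// a prime larger than GPUFFT_BATCH_SIZE with no factor in {2,3,5,7}, is_divisible
// stays false and the whole batch is done with a single plan.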
/// send column data to nodes
template <typename cmplx_t>
void hila_fft<cmplx_t>::gather_data() {

    extern hila::timer pencil_MPI_timer;
    pencil_MPI_timer.start();

    // post receives and sends
    int n_comms = hila_pencil_comms[dir].size() - 1;

    std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
    std::vector<MPI_Status> stat(n_comms);

#ifndef GPU_AWARE_MPI
    // MPI cannot use device pointers - stage through host buffers
    std::vector<cmplx_t *> send_p(n_comms);
    std::vector<cmplx_t *> receive_p(n_comms);
#endif

    int i = 0;
    int j = 0;

    // make sure the data is ready on the device before communicating
    gpuStreamSynchronize(0);

    size_t mpi_type_size;
    MPI_Datatype mpi_type = get_MPI_complex_type<cmplx_t>(mpi_type_size);

    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {

            size_t siz = fn.recv_buf_size * elements * sizeof(cmplx_t);
            if (siz >= (1ULL << 31) * mpi_type_size) {
                hila::out << "Too large MPI message in pencils! Size " << siz << " bytes ("
                          << siz / mpi_type_size << " elements)\n";
                hila::terminate(1);
            }

#ifndef GPU_AWARE_MPI
            cmplx_t *p = receive_p[i] = (cmplx_t *)memalloc(siz);
#else
            cmplx_t *p = rec_p[j];
#endif
            MPI_Irecv(p, (int)(siz / mpi_type_size), mpi_type, fn.node, WRK_GATHER_TAG,
                      lattice.mpi_comm_lat, &recreq[i]);
            i++;
        }
        j++;
    }

    i = 0;
    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {

            cmplx_t *p = send_buf + fn.column_offset * elements;
            size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);

#ifndef GPU_AWARE_MPI
            send_p[i] = (cmplx_t *)memalloc(n);
            gpuMemcpy(send_p[i], p, n, gpuMemcpyDeviceToHost);
            p = send_p[i];
#endif
            MPI_Isend(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_GATHER_TAG,
                      lattice.mpi_comm_lat, &sendreq[i]);
            i++;
        }
    }

    // wait for the sends and receives to complete
    if (n_comms > 0) {
        MPI_Waitall(n_comms, recreq.data(), stat.data());
        MPI_Waitall(n_comms, sendreq.data(), stat.data());
    }

#ifndef GPU_AWARE_MPI
    // move the received data to the device and release the host staging buffers
    i = j = 0;
    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {
            size_t siz = fn.recv_buf_size * elements;
            gpuMemcpy(rec_p[j], receive_p[i], siz * sizeof(cmplx_t), gpuMemcpyHostToDevice);
            i++;
        }
        j++;
    }

    for (i = 0; i < n_comms; i++) {
        free(receive_p[i]);
        free(send_p[i]);
    }
#endif

    pencil_MPI_timer.stop();
}
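// The size guard in gather_data() keeps the MPI count argument within int range:
// MPI_Irecv takes an int count, so a single message must stay below 2^31 elements.
// For example (sizes assumed for illustration), with mpi_type_size = 16 bytes
// (a double complex) the guard trips only at 2^31 * 16 bytes = 32 GiB per message.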
/// inverse of gather_data
template <typename cmplx_t>
void hila_fft<cmplx_t>::scatter_data() {

    extern hila::timer pencil_MPI_timer;
    pencil_MPI_timer.start();

    int n_comms = hila_pencil_comms[dir].size() - 1;

    std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
    std::vector<MPI_Status> stat(n_comms);

#ifndef GPU_AWARE_MPI
    std::vector<cmplx_t *> send_p(n_comms);
    std::vector<cmplx_t *> receive_p(n_comms);
#endif

    int i = 0, j = 0;

    size_t mpi_type_size;
    MPI_Datatype mpi_type = get_MPI_complex_type<cmplx_t>(mpi_type_size);

    gpuStreamSynchronize(0);

    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {
            size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
#ifdef GPU_AWARE_MPI
            cmplx_t *p = send_buf + fn.column_offset * elements;
#else
            cmplx_t *p = receive_p[i] = (cmplx_t *)memalloc(n);
#endif
            MPI_Irecv(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_SCATTER_TAG,
                      lattice.mpi_comm_lat, &recreq[i]);
            i++;
        }
    }

    i = 0;
    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {
            size_t n = fn.recv_buf_size * elements * sizeof(cmplx_t);
#ifdef GPU_AWARE_MPI
            cmplx_t *p = rec_p[j];
#else
            cmplx_t *p = send_p[i] = (cmplx_t *)memalloc(n);
            gpuMemcpy(p, rec_p[j], n, gpuMemcpyDeviceToHost);
#endif
            MPI_Isend(p, (int)(n / mpi_type_size), mpi_type, fn.node, WRK_SCATTER_TAG,
                      lattice.mpi_comm_lat, &sendreq[i]);
            i++;
        }
        j++;
    }

    // wait for the sends and receives to complete
    if (n_comms > 0) {
        MPI_Waitall(n_comms, recreq.data(), stat.data());
        MPI_Waitall(n_comms, sendreq.data(), stat.data());
    }

#ifndef GPU_AWARE_MPI
    // move the received data back to the device and free the staging buffers
    i = 0;
    for (auto &fn : hila_pencil_comms[dir]) {
        if (fn.node != hila::myrank()) {
            size_t n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
            cmplx_t *p = send_buf + fn.column_offset * elements;
            gpuMemcpy(p, receive_p[i], n, gpuMemcpyHostToDevice);
            i++;
        }
    }

    for (i = 0; i < n_comms; i++) {
        free(receive_p[i]);
        free(send_p[i]);
    }
#endif

    pencil_MPI_timer.stop();
}
/// Reflect the columns in the fft direction: data[x] <-> data[-x mod colsize]
template <typename cmplx_t>
__global__ void hila_reflect_dir_kernel(cmplx_t *RESTRICT data, const int colsize,
                                        const int columns) {

    int ind = threadIdx.x + blockIdx.x * blockDim.x;
    if (ind < columns) {
        const int s = colsize * ind;

        for (int i = 1; i < colsize / 2; i++) {
            int i1 = s + i;
            int i2 = s + colsize - i;
            cmplx_t tmp = data[i1];
            data[i1] = data[i2];
            data[i2] = tmp;
        }
    }
}
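// Example: for colsize = 8 a column is reflected by swapping (s+1, s+7), (s+2, s+6)
// and (s+3, s+5); indices s+0 and s+4 = s + colsize/2 are fixed points. This is the
// site reflection x -> -x (mod colsize) along the transform direction.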
template <typename cmplx_t>
void hila_fft<cmplx_t>::reflect() {

    // this extern defined in fft.cpp
    extern unsigned hila_fft_my_columns[NDIM];

    constexpr bool is_float = (sizeof(cmplx_t) == sizeof(Complex<float>));

    int n_columns = hila_fft_my_columns[dir] * elements;

    // allocate the work array and copy the index arrays to the device,
    // as in transform() above
    cmplx_t *fft_wrk = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);

    cmplx_t **d_ptr = (cmplx_t **)d_malloc(sizeof(cmplx_t *) * rec_p.size());
    int *d_size = (int *)d_malloc(sizeof(int) * rec_p.size());

    gpuMemcpy(d_ptr, rec_p.data(), rec_p.size() * sizeof(cmplx_t *), gpuMemcpyHostToDevice);
    gpuMemcpy(d_size, rec_size.data(), rec_size.size() * sizeof(int), gpuMemcpyHostToDevice);

    int N_blocks = (n_columns + N_threads - 1) / N_threads;

#if defined(CUDA)
    hila_fft_gather_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
                                                             lattice.size(dir), n_columns);
#else
    hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_gather_column<cmplx_t>), dim3(N_blocks),
                       dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
                       lattice.size(dir), n_columns);
#endif

#if defined(CUDA)
    hila_reflect_dir_kernel<cmplx_t>
        <<<N_blocks, N_threads>>>(fft_wrk, lattice.size(dir), n_columns);
#else
    hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_reflect_dir_kernel<cmplx_t>), dim3(N_blocks),
                       dim3(N_threads), 0, 0, fft_wrk, lattice.size(dir), n_columns);
#endif

#if defined(CUDA)
    hila_fft_scatter_column<cmplx_t><<<N_blocks, N_threads>>>(fft_wrk, d_ptr, d_size, rec_p.size(),
                                                              lattice.size(dir), n_columns);
#else
    hipLaunchKernelGGL(HIP_KERNEL_NAME(hila_fft_scatter_column<cmplx_t>), dim3(N_blocks),
                       dim3(N_threads), 0, 0, fft_wrk, d_ptr, d_size, rec_p.size(),
                       lattice.size(dir), n_columns);
#endif

    d_free(d_ptr);
    d_free(d_size);
    d_free(fft_wrk);
}

#endif // FFT_GPU_TRANSFORM_H