1#ifndef FFTW_TRANSFORM_H
2#define FFTW_TRANSFORM_H
// Generic template fallback. The function signature on the original lines in
// between is not visible in this excerpt. This body asserts unconditionally,
// so presumably only the type-specific (FFTW-backed) versions that follow are
// meant to be called.
11template <
typename cmplx_t>
13 assert(0 &&
"Don't call this!");
// Double-precision FFTW path (fftw_* API). The enclosing signature is not
// visible in this excerpt.
// Performs n_fft = hila_fft_my_columns[dir] * elements independent 1D
// transforms of length lattice.size(dir). The sign is chosen from fftdir:
// FFTW_FORWARD for fft_direction::forward, otherwise FFTW_BACKWARD.
18 extern unsigned hila_fft_my_columns[NDIM];
19 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
21 size_t n_fft = hila_fft_my_columns[dir] * elements;
24 (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
26 fft_plan_timer.start();
// One column-sized buffer from fftw_malloc. The plan transforms it in place,
// since input and output are both fftwbuf. FFTW_ESTIMATE keeps planning cheap,
// which matters because the plan is created and destroyed on every call.
31 fftw_complex *fftwbuf =
32 (fftw_complex *)fftw_malloc(
sizeof(fftw_complex) * lattice.size(dir));
33 fftw_plan fftwplan = fftw_plan_dft_1d(lattice.size(dir), fftwbuf, fftwbuf,
34 transform_dir, FFTW_ESTIMATE);
36 fft_plan_timer.stop();
38 for (
size_t i = 0; i < n_fft; i++) {
41 fft_buffer_timer.start();
// Copy the pieces of column i into the FFTW buffer. Each piece is rec_size[j]
// elements taken from rec_p[j] at offset i * rec_size[j].
43 fftw_complex *cp = fftwbuf;
44 for (
int j = 0; j < rec_p.size(); j++) {
45 memcpy(cp, rec_p[j] + i * rec_size[j],
sizeof(fftw_complex) * rec_size[j]);
49 fft_buffer_timer.stop();
52 fft_execute_timer.start();
54 fftw_execute(fftwplan);
56 fft_execute_timer.stop();
58 fft_buffer_timer.start();
// Copy the transformed column back into the same rec_p[j] slots.
61 for (
int j = 0; j < rec_p.size(); j++) {
62 memcpy(rec_p[j] + i * rec_size[j], cp,
sizeof(fftw_complex) * rec_size[j]);
66 fft_buffer_timer.stop();
// NOTE(review): fftwbuf came from fftw_malloc, but the matching fftw_free is
// not visible in this excerpt. Confirm it is released after the plan is
// destroyed.
69 fftw_destroy_plan(fftwplan);
// Single-precision FFTW path (fftwf_* API). It mirrors the double-precision
// version. The enclosing signature is not visible in this excerpt.
// Performs n_fft = hila_fft_my_columns[dir] * elements independent 1D
// transforms of length lattice.size(dir). The sign is chosen from fftdir.
76 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
77 extern unsigned hila_fft_my_columns[NDIM];
79 size_t n_fft = hila_fft_my_columns[dir] * elements;
82 (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
84 fft_plan_timer.start();
// One column-sized buffer from fftwf_malloc, with an in-place plan
// (input == output == fftwbuf). FFTW_ESTIMATE is used because the plan is
// rebuilt on every call.
89 fftwf_complex *fftwbuf =
90 (fftwf_complex *)fftwf_malloc(
sizeof(fftwf_complex) * lattice.size(dir));
91 fftwf_plan fftwplan = fftwf_plan_dft_1d(lattice.size(dir), fftwbuf, fftwbuf,
92 transform_dir, FFTW_ESTIMATE);
94 fft_plan_timer.stop();
96 for (
size_t i = 0; i < n_fft; i++) {
99 fft_buffer_timer.start();
// Copy the pieces of column i into the FFTW buffer. Each piece is rec_size[j]
// elements taken from rec_p[j] at offset i * rec_size[j].
101 fftwf_complex *cp = fftwbuf;
102 for (
int j = 0; j < rec_p.size(); j++) {
103 memcpy(cp, rec_p[j] + i * rec_size[j],
sizeof(fftwf_complex) * rec_size[j]);
107 fft_buffer_timer.stop();
110 fft_execute_timer.start();
112 fftwf_execute(fftwplan);
114 fft_execute_timer.stop();
116 fft_buffer_timer.start();
// Copy the transformed column back into the same rec_p[j] slots.
119 for (
int j = 0; j < rec_p.size(); j++) {
120 memcpy(rec_p[j] + i * rec_size[j], cp,
sizeof(fftwf_complex) * rec_size[j]);
124 fft_buffer_timer.stop();
// NOTE(review): fftwbuf came from fftwf_malloc, but the matching fftwf_free is
// not visible in this excerpt. Confirm it is released after the plan is
// destroyed.
127 fftwf_destroy_plan(fftwplan);
// Gather step of the pencil FFT in direction dir. The enclosing signature is
// not visible in this excerpt. The steps are:
//   1. Post non-blocking receives into rec_p[j] from the nodes listed in
//      hila_pencil_comms[dir].
//   2. Send this node's column slice of send_buf to each of those nodes.
//   3. Wait for all receives, then all sends.
// All traffic uses WRK_GATHER_TAG on lattice.mpi_comm_lat and is timed by
// pencil_MPI_timer.
134template <
typename cmplx_t>
139 pencil_MPI_timer.start();
// NOTE(review): there is one fewer request than comm entries. Presumably the
// entry for this node itself is handled locally without MPI; that branch is
// not visible in this excerpt.
142 int n_comms = hila_pencil_comms[dir].size() - 1;
144 std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
145 std::vector<MPI_Status> stat(n_comms);
149 for (
auto &fn : hila_pencil_comms[dir]) {
// MPI counts are int, so a byte count of 2^31 or more cannot be passed to
// MPI_Irecv. This guard reports it before the (int)siz cast below.
152 size_t siz = fn.recv_buf_size * elements *
sizeof(cmplx_t);
153 if (siz >= (1ULL << 31)) {
154 hila::out <<
"Too large MPI message in pencils! Size " << siz
159 MPI_Irecv(rec_p[j], (
int)siz, MPI_BYTE, fn.node, WRK_GATHER_TAG,
160 lattice.mpi_comm_lat, &recreq[i]);
168 for (
auto &fn : hila_pencil_comms[dir]) {
171 cmplx_t *p = send_buf + fn.column_offset * elements;
// NOTE(review): unlike the receive side, this send byte count is computed
// directly in int with no 2^31 guard. The tail of the expression is not
// visible here. For very large columns the product could overflow.
172 int n = fn.column_number * elements * lattice.mynode.size[dir] *
175 MPI_Isend(p, n, MPI_BYTE, fn.node, WRK_GATHER_TAG, lattice.mpi_comm_lat,
183 MPI_Waitall(n_comms, recreq.data(), stat.data());
184 MPI_Waitall(n_comms, sendreq.data(), stat.data());
187 pencil_MPI_timer.stop();
// Scatter step, the inverse of gather_data. The enclosing signature is not
// visible in this excerpt. The steps are:
//   1. Post non-blocking receives into this node's column slices of send_buf.
//   2. Send each rec_p[j] buffer back to fn.node.
//   3. Wait for all receives, then all sends.
// All traffic uses WRK_SCATTER_TAG on lattice.mpi_comm_lat and is timed by
// pencil_MPI_timer.
194template <
typename cmplx_t>
199 pencil_MPI_timer.start();
201 int n_comms = hila_pencil_comms[dir].size() - 1;
203 std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
204 std::vector<MPI_Status> stat(n_comms);
208 for (
auto &fn : hila_pencil_comms[dir]) {
210 cmplx_t *p = send_buf + fn.column_offset * elements;
211 int n = fn.column_number * elements * lattice.mynode.size[dir] *
sizeof(cmplx_t);
213 MPI_Irecv(p, n, MPI_BYTE, fn.node, WRK_SCATTER_TAG,
214 lattice.mpi_comm_lat, &recreq[i]);
222 for (
auto &fn : hila_pencil_comms[dir]) {
// NOTE(review): gather_data checks this same quantity,
// fn.recv_buf_size * elements * sizeof(cmplx_t), against 2^31 before the
// (int) cast. The cast here has no such check, so an oversized message would
// not be caught.
225 MPI_Isend(rec_p[j], (
int)(fn.recv_buf_size * elements *
sizeof(cmplx_t)), MPI_BYTE, fn.node,
226 WRK_SCATTER_TAG, lattice.mpi_comm_lat, &sendreq[i]);
235 MPI_Waitall(n_comms, recreq.data(), stat.data());
236 MPI_Waitall(n_comms, sendreq.data(), stat.data());
239 pencil_MPI_timer.stop();
// Index reflection along dir. The enclosing signature is not visible in this
// excerpt. For each of the ncols = hila_fft_my_columns[dir] * elements
// columns, the code:
//   1. Copies the column pieces from rec_p[j] into the scratch buffer buf
//      (presumably via cp = buf; that line is not visible).
//   2. Swaps buf[j] with buf[length - j], i.e. x -> -x mod length, leaving
//      index 0 in place.
//   3. Copies the column back.
// NOTE(review): buf comes from memalloc, but its release is not visible in
// this excerpt.
251template <
typename cmplx_t>
253 extern unsigned hila_fft_my_columns[NDIM];
254 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
256 const int ncols = hila_fft_my_columns[dir] * elements;
258 const int length = lattice.size(dir);
260 cmplx_t *buf = (cmplx_t *)memalloc(
sizeof(cmplx_t) * length);
262 for (
int i = 0; i < ncols; i++) {
266 for (
int j = 0; j < rec_p.size(); j++) {
267 memcpy(cp, rec_p[j] + i * rec_size[j],
sizeof(cmplx_t) * rec_size[j]);
// NOTE(review): with the bound `j < length / 2`, the middle pair is skipped
// when length is odd. For example, length = 5 swaps only 1<->4 and leaves
// 2<->3 in place, and length = 3 swaps nothing. A bound of `2 * j < length`
// covers every pair for both parities; for even length, index length/2
// correctly maps to itself.
272 for (
int j = 1; j < length / 2; j++) {
273 std::swap(buf[j], buf[length - j]);
277 for (
int j = 0; j < rec_p.size(); j++) {
278 memcpy(rec_p[j] + i * rec_size[j], cp,
sizeof(cmplx_t) * rec_size[j]);
void gather_data()
send column data to nodes
void scatter_data()
inverse of gather_data
void transform()
transform does the actual fft.
int myrank()
rank of this node
std::ostream out
this is our default output file stream
void terminate(int status)