HILA
Loading...
Searching...
No Matches
fft_fftw_transform.h
1#ifndef FFTW_TRANSFORM_H
2#define FFTW_TRANSFORM_H
3
/// Define hila_fft<>::transform() functions and
/// gather_data() / scatter_data() functions for FFTW
6///
7/// This is not a standalone header, it is meant to be #include'd from
8/// fft.h .
9
10/// transform does the actual fft.
11template <typename cmplx_t>
13 assert(0 && "Don't call this!");
14}
15
16template <>
17inline void hila_fft<Complex<double>>::transform() {
18 extern unsigned hila_fft_my_columns[NDIM];
19 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
20
21 size_t n_fft = hila_fft_my_columns[dir] * elements;
22
23 int transform_dir =
24 (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
25
26 fft_plan_timer.start();
27
28 // allocate here fftw plans. TODO: perhaps store, if plans take appreciable time?
29 // Timer will tell the proportional timing
30
31 fftw_complex *fftwbuf =
32 (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * lattice.size(dir));
33 fftw_plan fftwplan = fftw_plan_dft_1d(lattice.size(dir), fftwbuf, fftwbuf,
34 transform_dir, FFTW_ESTIMATE);
35
36 fft_plan_timer.stop();
37
38 for (size_t i = 0; i < n_fft; i++) {
39 // collect stuff from buffers
40
41 fft_buffer_timer.start();
42
43 fftw_complex *cp = fftwbuf;
44 for (int j = 0; j < rec_p.size(); j++) {
45 memcpy(cp, rec_p[j] + i * rec_size[j], sizeof(fftw_complex) * rec_size[j]);
46 cp += rec_size[j];
47 }
48
49 fft_buffer_timer.stop();
50
51 // do the fft
52 fft_execute_timer.start();
53
54 fftw_execute(fftwplan);
55
56 fft_execute_timer.stop();
57
58 fft_buffer_timer.start();
59
60 cp = fftwbuf;
61 for (int j = 0; j < rec_p.size(); j++) {
62 memcpy(rec_p[j] + i * rec_size[j], cp, sizeof(fftw_complex) * rec_size[j]);
63 cp += rec_size[j];
64 }
65
66 fft_buffer_timer.stop();
67 }
68
69 fftw_destroy_plan(fftwplan);
70 fftw_free(fftwbuf);
71}
72
73template <>
74inline void hila_fft<Complex<float>>::transform() {
75
76 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
77 extern unsigned hila_fft_my_columns[NDIM];
78
79 size_t n_fft = hila_fft_my_columns[dir] * elements;
80
81 int transform_dir =
82 (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
83
84 fft_plan_timer.start();
85
86 // allocate here fftw plans. TODO: perhaps store, if plans take appreciable time?
87 // Timer will tell the proportional timing
88
89 fftwf_complex *fftwbuf =
90 (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * lattice.size(dir));
91 fftwf_plan fftwplan = fftwf_plan_dft_1d(lattice.size(dir), fftwbuf, fftwbuf,
92 transform_dir, FFTW_ESTIMATE);
93
94 fft_plan_timer.stop();
95
96 for (size_t i = 0; i < n_fft; i++) {
97 // collect stuff from buffers
98
99 fft_buffer_timer.start();
100
101 fftwf_complex *cp = fftwbuf;
102 for (int j = 0; j < rec_p.size(); j++) {
103 memcpy(cp, rec_p[j] + i * rec_size[j], sizeof(fftwf_complex) * rec_size[j]);
104 cp += rec_size[j];
105 }
106
107 fft_buffer_timer.stop();
108
109 // do the fft
110 fft_execute_timer.start();
111
112 fftwf_execute(fftwplan);
113
114 fft_execute_timer.stop();
115
116 fft_buffer_timer.start();
117
118 cp = fftwbuf;
119 for (int j = 0; j < rec_p.size(); j++) {
120 memcpy(rec_p[j] + i * rec_size[j], cp, sizeof(fftwf_complex) * rec_size[j]);
121 cp += rec_size[j];
122 }
123
124 fft_buffer_timer.stop();
125 }
126
127 fftwf_destroy_plan(fftwplan);
128 fftwf_free(fftwbuf);
129}
130
131////////////////////////////////////////////////////////////////////
132/// send column data to nodes
133
134template <typename cmplx_t>
136
137
138 extern hila::timer pencil_MPI_timer;
139 pencil_MPI_timer.start();
140
141 // post receive and send
142 int n_comms = hila_pencil_comms[dir].size() - 1;
143
144 std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
145 std::vector<MPI_Status> stat(n_comms);
146
147 int i = 0;
148 int j = 0;
149 for (auto &fn : hila_pencil_comms[dir]) {
150 if (fn.node != hila::myrank()) {
151
152 size_t siz = fn.recv_buf_size * elements * sizeof(cmplx_t);
153 if (siz >= (1ULL << 31)) {
154 hila::out << "Too large MPI message in pencils! Size " << siz
155 << " bytes\n";
157 }
158
159 MPI_Irecv(rec_p[j], (int)siz, MPI_BYTE, fn.node, WRK_GATHER_TAG,
160 lattice.mpi_comm_lat, &recreq[i]);
161
162 i++;
163 }
164 j++;
165 }
166
167 i = 0;
168 for (auto &fn : hila_pencil_comms[dir]) {
169 if (fn.node != hila::myrank()) {
170
171 cmplx_t *p = send_buf + fn.column_offset * elements;
172 int n = fn.column_number * elements * lattice.mynode.size[dir] *
173 sizeof(cmplx_t);
174
175 MPI_Isend(p, n, MPI_BYTE, fn.node, WRK_GATHER_TAG, lattice.mpi_comm_lat,
176 &sendreq[i]);
177 i++;
178 }
179 }
180
181 // and wait for the send and receive to complete
182 if (n_comms > 0) {
183 MPI_Waitall(n_comms, recreq.data(), stat.data());
184 MPI_Waitall(n_comms, sendreq.data(), stat.data());
185 }
186
187 pencil_MPI_timer.stop();
188
189}
190
191//////////////////////////////////////////////////////////////////////////////////////
192/// inverse of gather_data
193
194template <typename cmplx_t>
196
197
198 extern hila::timer pencil_MPI_timer;
199 pencil_MPI_timer.start();
200
201 int n_comms = hila_pencil_comms[dir].size() - 1;
202
203 std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
204 std::vector<MPI_Status> stat(n_comms);
205
206 int i = 0;
207
208 for (auto &fn : hila_pencil_comms[dir]) {
209 if (fn.node != hila::myrank()) {
210 cmplx_t *p = send_buf + fn.column_offset * elements;
211 int n = fn.column_number * elements * lattice.mynode.size[dir] * sizeof(cmplx_t);
212
213 MPI_Irecv(p, n, MPI_BYTE, fn.node, WRK_SCATTER_TAG,
214 lattice.mpi_comm_lat, &recreq[i]);
215
216 i++;
217 }
218 }
219
220 i = 0;
221 int j = 0;
222 for (auto &fn : hila_pencil_comms[dir]) {
223 if (fn.node != hila::myrank()) {
224
225 MPI_Isend(rec_p[j], (int)(fn.recv_buf_size * elements * sizeof(cmplx_t)), MPI_BYTE, fn.node,
226 WRK_SCATTER_TAG, lattice.mpi_comm_lat, &sendreq[i]);
227
228 i++;
229 }
230 j++;
231 }
232
233 // and wait for the send and receive to complete
234 if (n_comms > 0) {
235 MPI_Waitall(n_comms, recreq.data(), stat.data());
236 MPI_Waitall(n_comms, sendreq.data(), stat.data());
237 }
238
239 pencil_MPI_timer.stop();
240
241}
242
243///////////////////////////////////////////////////////////////////////////////////
244/// Separate reflect operation
245/// Reflect flips the coordinates so that negative direction becomes positive,
246/// and x=0 plane remains,
247/// r(x) <- f(L - x) - note that x == 0 layer is as before
248/// r(0) = f(0), r(1) = f(L-1), r(2) = f(L-2) ...
249///////////////////////////////////////////////////////////////////////////////////
250
251template <typename cmplx_t>
252inline void hila_fft<cmplx_t>::reflect() {
253 extern unsigned hila_fft_my_columns[NDIM];
254 extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
255
256 const int ncols = hila_fft_my_columns[dir] * elements;
257
258 const int length = lattice.size(dir);
259
260 cmplx_t *buf = (cmplx_t *)memalloc(sizeof(cmplx_t) * length);
261
262 for (int i = 0; i < ncols; i++) {
263 // collect stuff from buffers
264
265 cmplx_t *cp = buf;
266 for (int j = 0; j < rec_p.size(); j++) {
267 memcpy(cp, rec_p[j] + i * rec_size[j], sizeof(cmplx_t) * rec_size[j]);
268 cp += rec_size[j];
269 }
270
271 // reflect
272 for (int j = 1; j < length / 2; j++) {
273 std::swap(buf[j], buf[length - j]);
274 }
275
276 cp = buf;
277 for (int j = 0; j < rec_p.size(); j++) {
278 memcpy(rec_p[j] + i * rec_size[j], cp, sizeof(cmplx_t) * rec_size[j]);
279 cp += rec_size[j];
280 }
281 }
282
283 free(buf);
284}
285
286
287#endif
Definition fft.h:81
void gather_data()
send column data to nodes
void scatter_data()
inverse of gather_data
void transform()
transform does the actual fft.
int myrank()
rank of this node
Definition com_mpi.cpp:235
std::ostream out
this is our default output file stream
void terminate(int status)