HILA
Loading...
Searching...
No Matches
fft.h
Go to the documentation of this file.
1/** @file fft.h */
2#ifndef HILA_FFT_H
3#define HILA_FFT_H
4
5#if !defined(CUDA) && !defined(HIP)
6#define USE_FFTW
7#endif
8
9#include "plumbing/defs.h"
10#include "datatypes/cmplx.h"
12#include "plumbing/field.h"
13#include "plumbing/timing.h"
14
15#ifdef USE_FFTW
16#include <fftw3.h>
17#endif
18
19// just some values here, make less than 100 just in case
20#define WRK_GATHER_TAG 42
21#define WRK_SCATTER_TAG 43
22
23/// Convert a lattice CoordinateVector to a k-vector (needs lattice.size(), so it is defined in fft.h)
24/// Note: each coordinate is first reduced modulo the lattice size in that direction
25
26
// Member CoordinateVector_t<T>::convert_to_k() const (its signature lines are not
// visible in this view - NOTE(review)).
// For every direction d the coordinate is reduced with pmod (assumed to be a
// non-negative modulo - confirm) and then shifted into (-size(d)/2, size(d)/2];
// the wave number is k_d = 2*pi*n / size(d), which gives -pi < k_d <= pi.
27#pragma hila novector
28template<typename T>
31 foralldir(d) {
32 int n = pmod((*this).e(d), lattice.size(d));
// map the upper half of the range to negative wave numbers
33 if (n > lattice.size(d) / 2)
34 n -= lattice.size(d);
35
36 k[d] = n * 2.0 * M_PI / lattice.size(d);
37 }
38 return k;
39}
40
41inline Vector<NDIM, double> convert_to_k(const CoordinateVector &cv) {
42 return cv.convert_to_k();
43}
44
45// hold static fft node data structures
/// Static per-node communication record for one fft direction.
/// hila_pencil_comms[dir] holds one of these per node taking part in the
/// fft columns of direction dir.
/// All members default to zero, so a default-constructed record is well defined
/// (previously the members were left indeterminate).
struct pencil_struct {
    int node = 0;               // node rank to send stuff for fft:ing
    unsigned size_to_dir = 0;   // size of "node" to fft-dir
    unsigned column_offset = 0; // first perp-plane column to be handled by "node"
    unsigned column_number = 0; // and number of columns to be sent
    size_t recv_buf_size = 0;   // size of my fft collect buffer (in units of sizeof(T)),
                                // for stuff received from / returned to "node"
};
54
55//
56extern std::vector<pencil_struct> hila_pencil_comms[NDIM];
57
58/// Build offsets to buffer arrays:
59/// Fastest Direction = dir, offset 1
60/// next fastest index of complex_t elements in T, offset elem_offset
61/// and then other directions, in order
62/// Returns element_offset and sets offset and nmin vectors
63
64size_t pencil_get_buffer_offsets(const Direction dir, const size_t elements,
65 CoordinateVector &offset, CoordinateVector &nmin);
66
67/// Initialize fft direction - defined in fft.cpp
69
70// Helper class to transform data
/// Helper union for viewing a value of type T as an array of cmplx_t numbers.
/// It is used to copy the complex components of a field element to and from
/// the fft buffers.
/// T must consist of a whole, non-zero number of cmplx_t values. Otherwise the
/// trailing part of T would be silently lost, because only the c[] components are copied.
template <typename T, typename cmplx_t>
union T_union {
    static_assert(sizeof(T) >= sizeof(cmplx_t) && sizeof(T) % sizeof(cmplx_t) == 0,
                  "T_union: sizeof(T) must be a non-zero multiple of sizeof(cmplx_t)");
    T val;
    cmplx_t c[sizeof(T) / sizeof(cmplx_t)];
};
76
77/// Class to hold fft relevant variables - note: fft.cpp holds static info, which is not
78/// here
79
80template <typename cmplx_t>
81class hila_fft {
82 public:
83 Direction dir;
84 int elements;
85 fft_direction fftdir;
86 size_t buf_size;
87 size_t local_volume;
88
89 bool only_reflect;
90
91 cmplx_t *send_buf;
92 cmplx_t *receive_buf;
93
94 // data structures which point to to-be-copied buffers
95 std::vector<cmplx_t *> rec_p;
96 std::vector<int> rec_size;
97
98 // initialize fft, allocate buffers
// _elements: number of cmplx_t values per field element (stored in member elements)
// _fftdir:   transform direction (stored in fftdir)
// _reflect:  if true, full_transform only reflects and does no fft (stored in only_reflect)
99 hila_fft(int _elements, fft_direction _fftdir, bool _reflect = false) {
// per-direction receive buffer sizes, defined elsewhere (presumably fft.cpp)
100 extern size_t pencil_recv_buf_size[NDIM];
101
102 elements = _elements;
103 fftdir = _fftdir;
104 only_reflect = _reflect;
105
106 local_volume = lattice.mynode.volume();
107
// NOTE(review): original line 109 is not visible in this view; presumably it
// initializes the pencil directions (init_pencil_direction) - confirm.
108 // init dirs here at one go
110
// buf_size counts sites (each site takes `elements` cmplx_t values): the largest
// per-direction receive buffer, but at least the local volume
111 buf_size = 1;
112 foralldir(d) {
113 if (pencil_recv_buf_size[d] > buf_size)
114 buf_size = pencil_recv_buf_size[d];
115 }
116 if (buf_size < local_volume)
117 buf_size = local_volume;
118
// both buffers hold buf_size * elements cmplx_t values; released in ~hila_fft()
119 // get fully aligned buffer space
120 send_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
121 receive_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
122 // if (buf_size > 0)
123 // fft_wrk_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) *
124 // elements);
125 }
126
127 ~hila_fft() {
128 d_free(send_buf);
129 d_free(receive_buf);
130 // if (buf_size > 0)
131 // d_free(fft_wrk_buf);
132 }
133
134 // make_plan does the fft plan, as appropriate
135 void make_plan();
136 // the actual transform is done here. Custom for fftw and others
137 void transform();
138
139 // reflection using special call
140 void reflect();
141
142 /////////////////////////////////////////////////////////////////////////
143 /// Initialize fft to Direction dir.
/// Sets member dir and fills rec_p / rec_size with one entry per pencil_struct
/// in hila_pencil_comms[dir]: where the column data for that node lives and its
/// extent (size_to_dir) in the fft direction.
/// (The signature line, presumably setup_direction(Direction _dir), is not visible here.)
144
146
147 dir = _dir;
148
149 // now in transform itself
150 // make_fft_plan();
151
152 rec_p.resize(hila_pencil_comms[dir].size());
153 rec_size.resize(hila_pencil_comms[dir].size());
154
// remote nodes get consecutive slices of receive_buf, walked with p
155 cmplx_t *p = receive_buf;
156 int i = 0;
157 for (pencil_struct &fn : hila_pencil_comms[dir]) {
158
159 if (fn.node != hila::myrank()) {
160
161 // usually, out/in buffer is the same
162 rec_p[i] = p;
163 rec_size[i] = fn.size_to_dir;
// advance by this node's share: recv_buf_size sites * elements cmplx_t values
164 p += fn.recv_buf_size * elements;
165
166 } else {
167
168 // for local node, point directly to send_buf arrays
169
170 rec_p[i] = send_buf + fn.column_offset * elements;
171 rec_size[i] = fn.size_to_dir;
172 }
173 i++;
174 }
175 }
176
177 /// Collect the data from field to send_buf for sending or fft'ing.
178 /// Order: Direction dir goes fastest, then the index to complex data in T,
179 /// and other directions are slowest.
180
181 template <typename T>
182 void collect_data(const Field<T> &f) {
183
184 extern hila::timer pencil_collect_timer;
185 pencil_collect_timer.start();
186
187 constexpr int elements = sizeof(T) / sizeof(cmplx_t);
188
189 // Build vector offset, which encodes where the data should be written
190 // elem_offset is the same for the offset of the elements of T
191 CoordinateVector offset, nmin;
192
193 const size_t elem_offset =
194 pencil_get_buffer_offsets(dir, sizeof(T) / sizeof(cmplx_t), offset, nmin);
195
196 cmplx_t *sb = send_buf;
197
198 // and collect the data
199#pragma hila novector direct_access(sb)
200 onsites(ALL) {
201
202 T_union<T, cmplx_t> v;
203 v.val = f[X];
204 int off = offset.dot(X.coordinates() - nmin);
205 for (int i = 0; i < elements; i++) {
206 sb[off + i * elem_offset] = v.c[i];
207 }
208 }
209
210 pencil_collect_timer.stop();
211 }
212
213 /// Inverse of the fft_collect_data: write fft'd data from receive_buf to field.
/// This is the inverse of collect_data(): it reads each site's complex components
/// from receive_buf (same layout as collect_data) and stores them into f.
/// (The signature line, presumably save_result(Field<T> &f), is not visible here.)
214
215 template <typename T>
217
218 constexpr int elements = sizeof(T) / sizeof(cmplx_t);
219
220 extern hila::timer pencil_save_timer;
221 pencil_save_timer.start();
222
// here offset/nmin encode where each site's data is read from in receive_buf
223 // Build vector offset, which encodes where the data should be written
224 CoordinateVector offset, nmin;
225
226 const size_t elem_offset = pencil_get_buffer_offsets(dir, elements, offset, nmin);
227
228 cmplx_t *rb = receive_buf;
229
230// and collect the data from buffers
231#pragma hila novector direct_access(rb)
232 onsites(ALL) {
233
234 T_union<T, cmplx_t> v;
235
236 size_t off = offset.dot(X.coordinates() - nmin);
237 for (int i = 0; i < elements; i++) {
238 v.c[i] = rb[off + i * elem_offset];
239 }
240 f[X] = v.val;
241 }
242
243 pencil_save_timer.stop();
244 }
245
246 /////////////////////////////////////////////////////////////////////////////
247 /// Reshuffle data, given that previous fft dir was to prev_dir and now to dir
248 /// Assuming here that the data is in receive_buf after fft and copy to send_buf
249 /// This requires swapping send_buf and receive_buf ptrs after 1 fft
250
251 void reshuffle_data(Direction prev_dir) {
252
253 extern hila::timer pencil_reshuffle_timer;
254 pencil_reshuffle_timer.start();
255
256 int elem = elements;
257
258 CoordinateVector offset_in, offset_out, nmin;
259
260 const size_t e_offset_in = pencil_get_buffer_offsets(prev_dir, elements, offset_in, nmin);
261 const size_t e_offset_out = pencil_get_buffer_offsets(dir, elements, offset_out, nmin);
262
263 cmplx_t *sb = send_buf;
264 cmplx_t *rb = receive_buf;
265
266#pragma hila novector direct_access(sb, rb)
267 onsites(ALL) {
268 CoordinateVector v = X.coordinates() - nmin;
269 size_t off_in = offset_in.dot(v);
270 size_t off_out = offset_out.dot(v);
271 for (int e = 0; e < elem; e++) {
272 sb[off_out + e * e_offset_out] = rb[off_in + e * e_offset_in];
273 }
274 }
275
276 pencil_reshuffle_timer.stop();
277 }
278
279 // free the work buffers
280 void cleanup() {}
281
282 // just swap the buf pointers
283 void swap_buffers() {
284 std::swap(send_buf, receive_buf);
285 }
286
287 // communication functions for slicing the lattice
288 void scatter_data();
289 void gather_data();
290
291 ////////////////////////////////////////////////////////////////////////
292 /// Do the transform itself (fft or reflect only)
293
294 template <typename T>
295 void full_transform(const Field<T> &input, Field<T> &result,
296 const CoordinateVector &directions) {
297
298 assert(lattice.id() == input.fs->lattice_id && "Default lattice mismatch in fft");
299
300 // Make sure the result is allocated and mark it changed
301 result.check_alloc();
302
303 bool first_dir = true;
304 Direction prev_dir;
305
306 foralldir(dir) {
307 if (directions[dir]) {
308
309 setup_direction(dir);
310
311 if (first_dir) {
312 collect_data(input);
313 // in_p = &result;
314 } else {
315 reshuffle_data(prev_dir);
316 }
317
318 gather_data();
319
320 if (!only_reflect)
321 transform();
322 else
323 reflect();
324
325 scatter_data();
326
327 cleanup();
328
329 // fft_save_result( result, dir, receive_buf );
330
331 prev_dir = dir;
332 first_dir = false;
333
334 // swap the pointers
335 swap_buffers();
336 }
337 }
338
339 save_result(result);
340
341 result.mark_changed(ALL);
342 }
343};
344
345// prototype for plan deletion
346void FFT_delete_plans();
347
348
349// Implementation dependent core fft collect and transforms are defined here
350
351#if defined(USE_FFTW)
352
353#include "plumbing/fft_fftw_transform.h"
354
355#elif defined(CUDA) || defined(HIP)
356
357#include "plumbing/backend_gpu/fft_gpu_transform.h"
358
359#endif
360
361/////////////////////////////////////////////////////////////////////////////////////////
362/// Complex-to-complex FFT transform of a field input, result in result.
363/// input and result can be same, "in-place".
364/// Both input and output are of type Field<T>, where T must contain complex type,
365/// Complex<float> or Complex<double>.
366/// directions: if directions[dir] == false (or 0), transform is not done to
367/// direction dir. fftdir: direction of the transform itself:
368/// fft_direction::forward (default) x -> k
369/// fft_direction::inverse k-> x
370/// FFT is unnormalized: transform + inverse transform yields source multiplied
371/// by the product of the size of the lattice to active directions
372/// If all directions are active, result = source * lattice.volume():
373/////////////////////////////////////////////////////////////////////////////////////////
374
375template <typename T>
376inline void FFT_field(const Field<T> &input, Field<T> &result, const CoordinateVector &directions,
377 fft_direction fftdir = fft_direction::forward) {
378
379 static_assert(hila::contains_complex<T>::value,
380 "FFT_field argument fields must contain complex type");
381
382 // get the type of the complex number here
383 using cmplx_t = Complex<hila::arithmetic_type<T>>;
384 constexpr size_t elements = sizeof(T) / sizeof(cmplx_t);
385
386 extern hila::timer fft_timer;
387 fft_timer.start();
388
389 hila_fft<cmplx_t> fft(elements, fftdir);
390
391 fft.full_transform(input, result, directions);
392
393 fft_timer.stop();
394}
395
396//////////////////////////////////////////////////////////////////////////////////
397///
398/// Complex-to-complex FFT transform of a field input, result in result.
399/// Same as FFT_field(input,result,directions,fftdir)
400/// with all directions active.
401///
402//////////////////////////////////////////////////////////////////////////////////
403
404template <typename T>
405inline void FFT_field(const Field<T> &input, Field<T> &result,
406 fft_direction fftdir = fft_direction::forward) {
407
408 CoordinateVector dirs;
409 dirs.fill(true); // set all directions OK
410
411 FFT_field(input, result, dirs, fftdir);
412}
413
414
415/**
416 * @brief Field method for performing FFT
417 * @details
418 * By default calling without arguments will execute FFT in all directions.
419 * @code{.cpp}
420 * .
421 * . // Field f is defined
422 * .
423 * auto res = f.FFT();                        // forward transform
424 * auto res_2 = res.FFT(fft_direction::back);  // res_2 == f * lattice.volume() (FFT is unnormalized)
425 * @endcode
426 *
427 * One can also specify the direction of the FFT with a coordinate vector:
428 * @code{.cpp}
429 * .
430 * . // Field f is defined
431 * .
432 * auto res = f.FFT(e_x);                         // forward transform in x-direction
433 * auto res_2 = res.FFT(e_x, fft_direction::back); // res_2 == f * lattice.size(e_x) (unnormalized)
434 * @endcode
435 *
436 * With this in mind `f.FFT(e_x+e_y+e_z) = f.FFT()`
437 *
438 * @tparam T
439 * @param dirs Direction to perform FFT in, default is all directions
440 * @param fftdir fft_direction::forward (default) or fft_direction::back
441 * @return Field<T> Transformed field
442 */
// Field<T>::FFT(dirs, fftdir) (signature line not visible here): transforms *this
// into a new field with FFT_field and returns it; *this is not modified.
443template <typename T>
445 Field<T> res;
446 FFT_field(*this, res, dirs, fftdir);
447 return res;
448}
449
// Field<T>::FFT(fftdir) (signature lines not visible here): FFT in all directions,
// implemented by filling cv with true and forwarding to FFT(cv, fftdir).
450template <typename T>
453 cv.fill(true);
454 return FFT(cv, fftdir);
455}
456
457
458//////////////////////////////////////////////////////////////////////////////////
459/// FFT_real_to_complex:
460/// Field must be a real-valued field, result is a complex-valued field of the same type
461/// Implemented just by doing a FFT with a complex field with im=0;
462/// fft_direction::back gives a complex conjugate of the forward transform
463/// Result is f(-x) = f(L - x) = f(x)^*
464//////////////////////////////////////////////////////////////////////////////////
465
// Field<T>::FFT_real_to_complex: copies the real field into a complex field cf
// (imaginary part 0.0) and returns cf.FFT(fftdir).
// NOTE(review): the declaration in the doxygen index names the parameter `fdir`,
// but this body uses `fftdir` - confirm that the definition's parameter is named fftdir.
466template <typename T>
468
469 static_assert(hila::is_arithmetic<T>::value,
470 "FFT_real_to_complex can be applied only to Field<real-type> variable");
471
473 cf[ALL] = Complex<T>((*this)[X], 0.0);
474 return cf.FFT(fftdir);
475}
477//////////////////////////////////////////////////////////////////////////////////
478/// FFT_complex_to_real;
479/// Field must be a complex-valued field, result is a real field of the same number type
480/// Not optimized, should not be used on a hot path
481///
482/// Because the complex field must have the property f(-x) = f(L-x) = f(x)^*, only
483/// half of the values in input field are significant, the routine does the appropriate
484/// symmetrization.
485///
486/// Routine FFT_complex_to_real_loc(CoordinateVector cv) gives the significant values at
487/// location cv:
488/// = +1 significant complex value,
489/// = 0 significant real part, imag ignored
490/// = -1 value ignored here
491/// Example: in 2d 8x8 lattice the sites are: (* = (0,0), value 0)
492///
493/// - + + + - - - - - - - 0 + + + 0
494/// - + + + - - - - - - - + + + + +
495/// - + + + - - - - after centering - - - + + + + +
496/// 0 + + + 0 - - - (0,0) to center - - - + + + + +
497/// + + + + + - - - -----------------> - - - * + + + 0
498/// + + + + + - - - - - - - + + + -
499/// + + + + + - - - - - - - + + + -
500/// * + + + 0 - - - - - - - + + + -
501///
502//////////////////////////////////////////////////////////////////////////////////
503
505
// FFT_complex_to_real_loc(cv) (signature line not visible here) classifies site cv:
//   +1: significant complex value, 0: only the real part is significant,
//   -1: value ignored (it is reconstructed from its mirror site).
// Directions are examined in order; the first direction whose coordinate is
// neither 0 nor size(d)/2 decides the result.
// NOTE(review): assumes 0 <= cv[d] < lattice.size(d); negative coordinates would
// fall through both tests - confirm callers pass lattice coordinates.
506 // foralldir continues only if cv[d] == 0 or cv[d] == size(d)/2
507 foralldir(d) {
508 if (cv[d] > 0 && cv[d] < lattice.size(d) / 2)
509 return 1;
510 if (cv[d] > lattice.size(d) / 2)
511 return -1;
512 }
513 // we get here only if all coords are 0 or size(d)/2
514 return 0;
515}
516
517
// Field<T>::FFT_complex_to_real (signature line not visible here): symmetrizes the
// complex field so that f(-x) = f(x)^*, FFTs it and returns the real part.
518template <typename T>
520
521 static_assert(hila::is_complex<T>::value,
522 "FFT_complex_to_real can be applied only to Field<Complex<>> type variable");
523
524 // first, do a full reflection of the field, giving rf(x) = f(L-x) = "f(-x)"
525 auto rf = this->reflect();
526 // And symmetrize the field appropriately - can use rf
// type +1: keep f(x); type -1: use f(-x)^*; type 0: keep only Re f(x)
527 onsites(ALL) {
528 int type = FFT_complex_to_real_loc(X.coordinates());
529 if (type == 1) {
530 rf[X] = (*this)[X];
531 } else if (type == -1) {
532 rf[X] = rf[X].conj();
533 } else {
534 rf[X].real() = (*this)[X].real();
535 rf[X].imag() = 0;
536 }
537 }
538
539 FFT_field(rf, rf, fftdir);
540
// NOTE(review): ims and rss are accumulated but never used in the visible code.
// This is a full extra pass over the lattice (plus reductions) with no effect;
// consider removing it or using it as a check that the imaginary residue is small.
541 double ims = 0;
542 double rss = 0;
543 onsites(ALL) {
544 ims += ::squarenorm(rf[X].imag());
545 rss += ::squarenorm(rf[X].real());
546 }
547
549 onsites(ALL) res[X] = rf[X].real();
550 return res;
551}
552
553
554//////////////////////////////////////////////////////////////////////////////////
555/// Field<T>::reflect() reflects the field around the desired axis
556/// This is here because it uses similar communications as fft
557/// TODO: refactorise so that there is separate "make columns" class!
558
// Field<T>::reflect(dirs) (signature line not visible here): reflects the field in
// the directions with dirs[d] != 0 and returns the result as a new field.
// It reuses the fft column machinery: hila_fft is instantiated with cmplx_t = T,
// so each field element moves as a single unit (elements = 1), and the `true`
// argument makes full_transform call reflect() instead of transform().
559template <typename T>
561
562 constexpr int elements = 1;
563
564 Field<T> result;
565
566 hila_fft<T> refl(elements, fft_direction::forward, true);
567 refl.full_transform(*this, result, dirs);
568
569 return result;
570}
571
// Field<T>::reflect() (signature lines not visible here): reflection in all
// directions, done by filling c with true and calling reflect(c).
572template <typename T>
574
576 c.fill(true);
577 return reflect(c);
578}
579
// Field<T>::reflect(dir) (signature lines not visible here): reflection only in
// direction dir, i.e. c is all false except c[dir].
580template <typename T>
582
584 c.fill(false);
585 c[dir] = true;
586 return reflect(c);
587}
588
589
590#endif
Array< n, m, hila::arithmetic_type< T > > imag(const Array< n, m, T > &arg)
Return imaginary part of Array.
Definition array.h:702
hila::arithmetic_type< T > squarenorm(const Array< n, m, T > &rhs)
Return square norm of Array.
Definition array.h:1018
Array< n, m, hila::arithmetic_type< T > > real(const Array< n, m, T > &arg)
Return real part of Array.
Definition array.h:688
Complex definition.
Definition cmplx.h:56
Vector< 4, double > convert_to_k() const
Convert momentum space CoordinateVector to wave number k, where -pi < k_i <= pi. Utility function ...
Definition fft.h:29
The field class implements the standard methods for accessing Fields. Hilapp replaces the parity acce...
Definition field.h:61
Field< A > real() const
Returns real part of Field.
Definition field.h:1137
field_struct *__restrict__ fs
Field::fs holds all of the field content in Field::field_struct.
Definition field.h:243
Field< Complex< hila::arithmetic_type< T > > > FFT_real_to_complex(fft_direction fdir=fft_direction::forward) const
Definition fft.h:467
void check_alloc()
Allocate Field if it is not already allocated.
Definition field.h:459
const auto & fill(const S rhs)
Matrix fill.
Definition matrix.h:937
Matrix class which defines matrix operations.
Definition matrix.h:1620
Definition fft.h:81
void gather_data()
send column data to nodes
void setup_direction(Direction _dir)
Initialize fft to Direction dir.
Definition fft.h:145
void collect_data(const Field< T > &f)
Definition fft.h:182
void scatter_data()
inverse of gather_data
void reshuffle_data(Direction prev_dir)
Definition fft.h:251
void transform()
transform does the actual fft.
void save_result(Field< T > &f)
Inverse of the fft_collect_data: write fft'd data from receive_buf to field.
Definition fft.h:216
void full_transform(const Field< T > &input, Field< T > &result, const CoordinateVector &directions)
Do the transform itself (fft or reflect only)
Definition fft.h:295
Definition of Complex types.
This header file defines:
int pmod(const int a, const int b)
#define foralldir(d)
Macro to loop over (all) Direction(s)
Definition coordinates.h:78
Direction
Enumerator for direction that assigns integer to direction to be interpreted as unit vector.
Definition coordinates.h:34
constexpr Parity ALL
bit pattern: 011
This file defines all includes for HILA.
fft_direction
define a class for FFT direction
Definition defs.h:150
void init_pencil_direction(Direction d)
Initialize fft direction - defined in fft.cpp.
Definition fft.cpp:61
size_t pencil_get_buffer_offsets(const Direction dir, const size_t elements, CoordinateVector &offset, CoordinateVector &nmin)
Definition fft.cpp:39
void FFT_field(const Field< T > &input, Field< T > &result, const CoordinateVector &directions, fft_direction fftdir=fft_direction::forward)
Definition fft.h:376
int FFT_complex_to_real_loc(const CoordinateVector &cv)
Definition fft.h:504
This file contains definitions for the Field class and the classes required to define it such as fi...
int myrank()
rank of this node
Definition com_mpi.cpp:234