HILA
Loading...
Searching...
No Matches
fft.h
Go to the documentation of this file.
1#ifndef HILA_FFT_H
2#define HILA_FFT_H
3
4/** @file fft.h */
5
6#if !defined(CUDA) && !defined(HIP)
7#define USE_FFTW
8#endif
9
10#include "plumbing/defs.h"
11#include "datatypes/cmplx.h"
13#include "plumbing/field.h"
14#include "plumbing/timing.h"
15
16#ifdef USE_FFTW
17#include <fftw3.h>
18#endif
19
20// just some values here, make less than 100 just in case
21#define WRK_GATHER_TAG 42
22#define WRK_SCATTER_TAG 43
23
24
25/**
26 * @brief Convert momentum space CoordinateVector to wave number k, where -pi < k_i <= pi
27 *
28 * CoordinateVector is (periodically) modded to a valid lattice coordinate,
29 * and folded so that if n_i > lattice.size(i) / 2, n_i = n_i - lattice.size(i).
30 * Then k_i = 2 pi n_i / lattice.size(i)
31 */
32
33
34#pragma hila novector
35template <typename T>
38 foralldir(d) {
39 int n = pmod((*this).e(d), lattice.size(d));
40 if (n > lattice.size(d) / 2)
41 n -= lattice.size(d);
42
43 k[d] = n * 2.0 * M_PI / lattice.size(d);
44 }
45 return k;
46}
47
48inline Vector<NDIM, double> convert_to_k(const CoordinateVector &cv) {
49 return cv.convert_to_k();
50}
51
/// Static per-node bookkeeping used when slicing the lattice into fft columns.
struct pencil_struct {
    /// node rank to send stuff to for fft:ing
    int node;
    /// size of "node" to fft-dir
    unsigned size_to_dir;
    /// first perp-plane column to be handled by "node"
    unsigned column_offset;
    /// number of columns to be sent
    unsigned column_number;
    /// size of my fft collect buffer (in units of sizeof(T)) for stuff
    /// received from / returned to "node"
    size_t recv_buf_size;
};
61
62//
63extern std::vector<pencil_struct> hila_pencil_comms[NDIM];
64
65/// Build offsets to buffer arrays:
66/// Fastest Direction = dir, offset 1
67/// next fastest index of complex_t elements in T, offset elem_offset
68/// and then other directions, in order
69/// Returns element_offset and sets offset and nmin vectors
70
71size_t pencil_get_buffer_offsets(const Direction dir, const size_t elements,
72 CoordinateVector &offset, CoordinateVector &nmin);
73
74/// Initialize fft direction - defined in fft.cpp
76
/// Helper union to view a value of type T as an array of cmplx_t elements.
///
/// `val` holds the value itself and `c` exposes the same storage as
/// sizeof(T) / sizeof(cmplx_t) consecutive cmplx_t elements.
///
/// @tparam T        type being viewed
/// @tparam cmplx_t  element type of the array view
template <typename T, typename cmplx_t>
union T_union {
    // The integer division below would silently drop trailing bytes of T (or give
    // a zero-length array) if T were not a whole number of cmplx_t elements.
    static_assert(sizeof(T) >= sizeof(cmplx_t) && sizeof(T) % sizeof(cmplx_t) == 0,
                  "T_union: sizeof(T) must be a non-zero multiple of sizeof(cmplx_t)");

    T val;
    cmplx_t c[sizeof(T) / sizeof(cmplx_t)];
};
83
84/// Class to hold fft relevant variables - note: fft.cpp holds static info, which is not
85/// here
86
87template <typename cmplx_t>
88class hila_fft {
89 public:
90 Direction dir;
91 int elements;
92 fft_direction fftdir;
93 size_t buf_size;
94 size_t local_volume;
95
96 bool only_reflect;
97
98 cmplx_t *send_buf;
99 cmplx_t *receive_buf;
100
101 // data structures which point to to-be-copied buffers
102 std::vector<cmplx_t *> rec_p;
103 std::vector<int> rec_size;
104
105 // initialize fft, allocate buffers
106 hila_fft(int _elements, fft_direction _fftdir, bool _reflect = false) {
107 extern size_t pencil_recv_buf_size[NDIM];
108
109 elements = _elements;
110 fftdir = _fftdir;
111 only_reflect = _reflect;
112
113 local_volume = lattice->mynode.volume;
114
115 // init dirs here at one go
117
118 buf_size = 1;
119 foralldir(d) {
120 if (pencil_recv_buf_size[d] > buf_size)
121 buf_size = pencil_recv_buf_size[d];
122 }
123 if (buf_size < local_volume)
124 buf_size = local_volume;
125
126 // get fully aligned buffer space
127 send_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
128 receive_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) * elements);
129 // if (buf_size > 0)
130 // fft_wrk_buf = (cmplx_t *)d_malloc(buf_size * sizeof(cmplx_t) *
131 // elements);
132 }
133
134 ~hila_fft() {
135 d_free(send_buf);
136 d_free(receive_buf);
137 // if (buf_size > 0)
138 // d_free(fft_wrk_buf);
139 }
140
141 // make_plan does the fft plan, as appropriate
142 void make_plan();
143 // the actual transform is done here. Custom for fftw and others
144 void transform();
145
146 // reflection using special call
147 void reflect();
148
149 /////////////////////////////////////////////////////////////////////////
150 /// Initialize fft to Direction dir.
151
153
154 dir = _dir;
155
156 // now in transform itself
157 // make_fft_plan();
158
159 rec_p.resize(hila_pencil_comms[dir].size());
160 rec_size.resize(hila_pencil_comms[dir].size());
161
162 cmplx_t *p = receive_buf;
163 int i = 0;
164 for (pencil_struct &fn : hila_pencil_comms[dir]) {
165
166 if (fn.node != hila::myrank()) {
167
168 // usually, out/in buffer is the same
169 rec_p[i] = p;
170 rec_size[i] = fn.size_to_dir;
171 p += fn.recv_buf_size * elements;
172
173 } else {
174
175 // for local node, point directly to send_buf arrays
176
177 rec_p[i] = send_buf + fn.column_offset * elements;
178 rec_size[i] = fn.size_to_dir;
179 }
180 i++;
181 }
182 }
183
184 /// Collect the data from field to send_buf for sending or fft'ing.
185 /// Order: Direction dir goes fastest, then the index to complex data in T,
186 /// and other directions are slowest.
187
188 template <typename T>
189 void collect_data(const Field<T> &f) {
190
191 extern hila::timer pencil_collect_timer;
192 pencil_collect_timer.start();
193
194 constexpr int elements = sizeof(T) / sizeof(cmplx_t);
195
196 // Build vector offset, which encodes where the data should be written
197 // elem_offset is the same for the offset of the elements of T
198 CoordinateVector offset, nmin;
199
200 const size_t elem_offset =
201 pencil_get_buffer_offsets(dir, sizeof(T) / sizeof(cmplx_t), offset, nmin);
202
203 cmplx_t *sb = send_buf;
204
205 // and collect the data
206#pragma hila novector direct_access(sb)
207 onsites(ALL) {
208
209 T_union<T, cmplx_t> v;
210 v.val = f[X];
211 int off = offset.dot(X.coordinates() - nmin);
212 for (int i = 0; i < elements; i++) {
213 sb[off + i * elem_offset] = v.c[i];
214 }
215 }
216
217 pencil_collect_timer.stop();
218 }
219
220 /// Inverse of the fft_collect_data: write fft'd data from receive_buf to field.
221
222 template <typename T>
224
225 constexpr int elements = sizeof(T) / sizeof(cmplx_t);
226
227 extern hila::timer pencil_save_timer;
228 pencil_save_timer.start();
229
230 // Build vector offset, which encodes where the data should be written
231 CoordinateVector offset, nmin;
232
233 const size_t elem_offset = pencil_get_buffer_offsets(dir, elements, offset, nmin);
234
235 cmplx_t *rb = receive_buf;
236
237// and collect the data from buffers
238#pragma hila novector direct_access(rb)
239 onsites(ALL) {
240
241 T_union<T, cmplx_t> v;
242
243 size_t off = offset.dot(X.coordinates() - nmin);
244 for (int i = 0; i < elements; i++) {
245 v.c[i] = rb[off + i * elem_offset];
246 }
247 f[X] = v.val;
248 }
249
250 pencil_save_timer.stop();
251 }
252
253 /////////////////////////////////////////////////////////////////////////////
254 /// Reshuffle data, given that previous fft dir was to prev_dir and now to dir
255 /// Assuming here that the data is in receive_buf after fft and copy to send_buf
256 /// This requires swapping send_buf and receive_buf ptrs after 1 fft
257
258 void reshuffle_data(Direction prev_dir) {
259
260 extern hila::timer pencil_reshuffle_timer;
261 pencil_reshuffle_timer.start();
262
263 int elem = elements;
264
265 CoordinateVector offset_in, offset_out, nmin;
266
267 const size_t e_offset_in = pencil_get_buffer_offsets(prev_dir, elements, offset_in, nmin);
268 const size_t e_offset_out = pencil_get_buffer_offsets(dir, elements, offset_out, nmin);
269
270 cmplx_t *sb = send_buf;
271 cmplx_t *rb = receive_buf;
272
273#pragma hila novector direct_access(sb, rb)
274 onsites(ALL) {
275 CoordinateVector v = X.coordinates() - nmin;
276 size_t off_in = offset_in.dot(v);
277 size_t off_out = offset_out.dot(v);
278 for (int e = 0; e < elem; e++) {
279 sb[off_out + e * e_offset_out] = rb[off_in + e * e_offset_in];
280 }
281 }
282
283 pencil_reshuffle_timer.stop();
284 }
285
286 // hook for freeing work buffers after a transform; currently a no-op
287 void cleanup() {}
288
289 // just swap the buf pointers
290 void swap_buffers() {
291 std::swap(send_buf, receive_buf);
292 }
293
294 // communication functions for slicing the lattice
295 void scatter_data();
296 void gather_data();
297
298 ////////////////////////////////////////////////////////////////////////
299 /// Do the transform itself (fft or reflect only)
300
301 template <typename T>
302 void full_transform(const Field<T> &input, Field<T> &result,
303 const CoordinateVector &directions) {
304
305 // Make sure the result is allocated and mark it changed
306 result.check_alloc();
307
308 bool first_dir = true;
309 Direction prev_dir;
310
311 foralldir(dir) {
312 if (directions[dir]) {
313
314 setup_direction(dir);
315
316 if (first_dir) {
317 collect_data(input);
318 // in_p = &result;
319 } else {
320 reshuffle_data(prev_dir);
321 }
322
323 gather_data();
324
325 if (!only_reflect)
326 transform();
327 else
328 reflect();
329
330 scatter_data();
331
332 cleanup();
333
334 // fft_save_result( result, dir, receive_buf );
335
336 prev_dir = dir;
337 first_dir = false;
338
339 // swap the pointers
340 swap_buffers();
341 }
342 }
343
344 save_result(result);
345
346 result.mark_changed(ALL);
347 }
348};
349
350// prototype for plan deletion
351void FFT_delete_plans();
352
353
354// Implementation dependent core fft collect and transforms are defined here
355
356#if defined(USE_FFTW)
357
358#include "plumbing/fft_fftw_transform.h"
359
360#elif defined(CUDA) || defined(HIP)
361
362#include "plumbing/backend_gpu/fft_gpu_transform.h"
363
364#endif
365
366/////////////////////////////////////////////////////////////////////////////////////////
367/// Complex-to-complex FFT transform of a field input, result in result.
368/// input and result can be same, "in-place".
369/// Both input and output are of type Field<T>, where T must contain complex type,
370/// Complex<float> or Complex<double>.
371/// directions: if directions[dir] == false (or 0), transform is not done to
372/// direction dir. fftdir: direction of the transform itself:
373/// fft_direction::forward (default) x -> k
374/// fft_direction::inverse k-> x
375/// FFT is unnormalized: transform + inverse transform yields source multiplied
376/// by the product of the size of the lattice to active directions
377/// If all directions are active, result = source * lattice.volume():
378/////////////////////////////////////////////////////////////////////////////////////////
379
380template <typename T>
381inline void FFT_field(const Field<T> &input, Field<T> &result, const CoordinateVector &directions,
382 fft_direction fftdir = fft_direction::forward) {
383
384 static_assert(hila::contains_complex<T>::value,
385 "FFT_field argument fields must contain complex type");
386
387 // get the type of the complex number here
388 using cmplx_t = Complex<hila::arithmetic_type<T>>;
389 constexpr size_t elements = sizeof(T) / sizeof(cmplx_t);
390
391 extern hila::timer fft_timer;
392 fft_timer.start();
393
394 hila_fft<cmplx_t> fft(elements, fftdir);
395
396 fft.full_transform(input, result, directions);
397
398 fft_timer.stop();
399}
400
401//////////////////////////////////////////////////////////////////////////////////
402///
403/// Complex-to-complex FFT transform of a field input, result in result.
404/// Same as FFT_field(input,result,directions,fftdir)
405/// with all directions active.
406///
407//////////////////////////////////////////////////////////////////////////////////
408
409template <typename T>
410inline void FFT_field(const Field<T> &input, Field<T> &result,
411 fft_direction fftdir = fft_direction::forward) {
412
413 CoordinateVector dirs;
414 dirs.fill(true); // set all directions OK
415
416 FFT_field(input, result, dirs, fftdir);
417}
418
419
420/**
421 * @brief Field method for performing FFT
422 * @details
423 * By default calling without arguments will execute FFT in all directions.
424 * @code{.cpp}
425 * .
426 * . // Field f is defined
427 * .
428 * auto res = f.FFT(); // Forward transform
429 * auto res_2 = res.FFT(fft_direction::back); // res_2 is f * lattice.volume() (FFT is unnormalized)
430 * @endcode
431 *
432 * One can also specify the direction of the FFT with a coordinate vector:
433 * @code{.cpp}
434 * .
435 * . // Field f is defined
436 * .
437 * auto res = f.FFT(e_x); // Forward transform in x-direction
438 * auto res_2 = res.FFT(e_x, fft_direction::back); // res_2 is f * lattice.size(e_x) (FFT is unnormalized)
439 * @endcode
440 *
441 * With this in mind, on a 3-dimensional lattice `f.FFT(e_x+e_y+e_z) = f.FFT()`
442 *
443 * @tparam T
444 * @param dirs Direction to perform FFT in, default is all directions
445 * @param fftdir fft_direction::forward (default) or fft_direction::back
446 * @return Field<T> Transformed field
447 */
448template <typename T>
450 Field<T> res;
451 FFT_field(*this, res, dirs, fftdir);
452 return res;
453}
454
455template <typename T>
458 cv.fill(true);
459 Field<T> res;
460 FFT_field(*this, res, cv, fftdir);
461 return res;
462}
463
464
465//////////////////////////////////////////////////////////////////////////////////
466/// FFT_real_to_complex:
467/// Field must be a real-valued field, result is a complex-valued field of the same type
468/// Implemented just by doing a FFT with a complex field with im=0;
469/// fft_direction::back gives a complex conjugate of the forward transform
470/// Result is f(-x) = f(L - x) = f(x)^*
471//////////////////////////////////////////////////////////////////////////////////
472
473template <typename T>
475
476 static_assert(hila::is_arithmetic<T>::value,
477 "FFT_real_to_complex can be applied only to Field<real-type> variable");
478
480 cf[ALL] = Complex<T>((*this)[X], 0.0);
481 return cf.FFT(fftdir);
482}
483
484//////////////////////////////////////////////////////////////////////////////////
485/// FFT_complex_to_real;
486/// Field must be a complex-valued field, result is a real field of the same number type
487/// Not optimized, should not be used on a hot path
488///
489/// Because the complex field must have the property f(-x) = f(L-x) = f(x)^*, only
490/// half of the values in input field are significant, the routine does the appropriate
491/// symmetrization.
492///
493/// Routine hila::FFT_complex_to_real_site(CoordinateVector cv) gives the significant values at
494/// location cv:
495/// = +1 significant complex value,
496/// = 0 significant real part, imag ignored
497/// = -1 value ignored here
498/// Example: in 2d 8x8 lattice the sites are: (* = (0,0), value 0)
499///
500/// - + + + - - - - - - - 0 + + + 0
501/// - + + + - - - - - - - + + + + +
502/// - + + + - - - - after centering - - - + + + + +
503/// 0 + + + 0 - - - (0,0) to center - - - + + + + +
504/// + + + + + - - - -----------------> - - - * + + + 0
505/// + + + + + - - - - - - - + + + -
506/// + + + + + - - - - - - - + + + -
507/// * + + + 0 - - - - - - - + + + -
508///
509//////////////////////////////////////////////////////////////////////////////////
510
511namespace hila {
512inline int FFT_complex_to_real_site(const CoordinateVector &cv) {
513
514 // foralldir continues only if cv[d] == 0 or cv[d] == size(d)/2
515 foralldir(d) {
516 if (cv[d] > 0 && cv[d] < lattice.size(d) / 2)
517 return 1;
518 if (cv[d] > lattice.size(d) / 2)
519 return -1;
520 }
521 // we get here only if all coords are 0 or size(d)/2
522 return 0;
523}
524
525} // namespace hila
526
527template <typename T>
529
530 static_assert(hila::is_complex<T>::value,
531 "FFT_complex_to_real can be applied only to Field<Complex<>> type variable");
532
533 foralldir(d) {
534 assert(lattice.size(d) % 2 == 0 &&
535 "FFT_complex_to_real works only with even lattice size to all directions");
536 }
537
538 // first, do a full reflection of the field, giving rf(x) = f(L-x) = "f(-x)"
539 auto rf = this->reflect();
540 // And symmetrize the field appropriately - can use rf
541 onsites(ALL) {
542 int type = hila::FFT_complex_to_real_site(X.coordinates());
543 if (type == 1) {
544 rf[X] = (*this)[X];
545 } else if (type == -1) {
546 rf[X] = rf[X].conj();
547 } else {
548 rf[X].real() = (*this)[X].real();
549 rf[X].imag() = 0;
550 }
551 }
552
553 FFT_field(rf, rf, fftdir);
554
555 double ims = 0;
556 double rss = 0;
557 onsites(ALL) {
558 ims += ::squarenorm(rf[X].imag());
559 rss += ::squarenorm(rf[X].real());
560 }
561
563 onsites(ALL) res[X] = rf[X].real();
564 return res;
565}
566
567
568//////////////////////////////////////////////////////////////////////////////////
569/// Field<T>::reflect() reflects the field around the desired axis
570/// This is here because it uses similar communications as fft
571/// TODO: refactorise so that there is separate "make columns" class!
572
573/**
574 * @brief Reflect the Field around the desired axis
575 * @details Can be called in the following ways:
576 *
577 * __Reflect on all axes:__
578 *
579 * \code {.cpp}
580 * Field<MyType> f;
581 * .
582 * .
583 * .
584 * f.reflect()
585 * \endcode
586 *
587 * @todo refactorise so that there is separate "make columns" class!
588 * @tparam T
589 * @param dirs
590 * @return Field<T>
591 */
592
593template <typename T>
595
596 constexpr int elements = 1;
597
598 Field<T> result;
599
600 hila_fft<T> refl(elements, fft_direction::forward, true);
601 refl.full_transform(*this, result, dirs);
602
603 return result;
604}
605
606template <typename T>
608
610 c.fill(true);
611 return reflect(c);
612}
613
614template <typename T>
616
618 c.fill(false);
619 c[dir] = true;
620 return reflect(c);
621}
622
623
624#endif
Array< n, m, hila::arithmetic_type< T > > imag(const Array< n, m, T > &arg)
Return imaginary part of Array.
Definition array.h:703
hila::arithmetic_type< T > squarenorm(const Array< n, m, T > &rhs)
Return square norm of Array.
Definition array.h:1019
Array< n, m, hila::arithmetic_type< T > > real(const Array< n, m, T > &arg)
Return real part of Array.
Definition array.h:689
Complex definition.
Definition cmplx.h:58
Vector< 4, double > convert_to_k() const
Convert momentum space CoordinateVector to wave number k, where -pi < k_i <= pi Utility function ...
Definition fft.h:36
The field class implements the standard methods for accessing Fields. Hilapp replaces the parity acce...
Definition field.h:62
Field< A > real() const
Returns real part of Field.
Definition field.h:1202
Field< Complex< hila::arithmetic_type< T > > > FFT_real_to_complex(fft_direction fdir=fft_direction::forward) const
Definition fft.h:474
void check_alloc()
Allocate Field if it is not already allocated.
Definition field.h:458
int size(Direction d) const
lattice.size() -> CoordinateVector or lattice.size(d) -> int returns the dimensions of the lattice,...
Definition lattice.h:433
const auto & fill(const S rhs)
Matrix fill.
Definition matrix.h:1022
Matrix class which defines matrix operations.
Definition matrix.h:1742
Definition fft.h:88
void gather_data()
send column data to nodes
void setup_direction(Direction _dir)
Initialize fft to Direction dir.
Definition fft.h:152
void collect_data(const Field< T > &f)
Definition fft.h:189
void scatter_data()
inverse of gather_data
void reshuffle_data(Direction prev_dir)
Definition fft.h:258
void transform()
transform does the actual fft.
void save_result(Field< T > &f)
Inverse of the fft_collect_data: write fft'd data from receive_buf to field.
Definition fft.h:223
void full_transform(const Field< T > &input, Field< T > &result, const CoordinateVector &directions)
Do the transform itself (fft or reflect only)
Definition fft.h:302
Definition of Complex types.
This header file defines:
int pmod(const int a, const int b)
Positive mod(): mods the result so that 0 <= a % b < b.
#define foralldir(d)
Macro to loop over (all) Direction(s)
Definition coordinates.h:80
Direction
Enumerator for direction that assigns integer to direction to be interpreted as unit vector.
Definition coordinates.h:34
constexpr Parity ALL
bit pattern: 011
This file defines all includes for HILA.
fft_direction
define a class for FFT direction
Definition defs.h:159
void init_pencil_direction(Direction d)
Initialize fft direction - defined in fft.cpp.
Definition fft.cpp:60
size_t pencil_get_buffer_offsets(const Direction dir, const size_t elements, CoordinateVector &offset, CoordinateVector &nmin)
Definition fft.cpp:38
void FFT_field(const Field< T > &input, Field< T > &result, const CoordinateVector &directions, fft_direction fftdir=fft_direction::forward)
Definition fft.h:381
This file contains definitions for the Field class and the classes required to define it such as fi...
Implement hila::swap for gauge fields.
Definition array.h:982
int myrank()
rank of this node
Definition com_mpi.cpp:237