6#if !defined(CUDA) && !defined(HIP)
14#include "plumbing/timing.h"
21#define WRK_GATHER_TAG 42
22#define WRK_SCATTER_TAG 43
39 int n =
pmod((*this).e(d), lattice.
size(d));
40 if (n > lattice.
size(d) / 2)
43 k[d] = n * 2.0 * M_PI / lattice.
size(d);
56 unsigned column_offset;
57 unsigned column_number;
63extern std::vector<pencil_struct> hila_pencil_comms[NDIM];
78template <
typename T,
typename cmplx_t>
81 cmplx_t c[
sizeof(T) /
sizeof(cmplx_t)];
87template <
typename cmplx_t>
102 std::vector<cmplx_t *> rec_p;
103 std::vector<int> rec_size;
107 extern size_t pencil_recv_buf_size[NDIM];
109 elements = _elements;
111 only_reflect = _reflect;
113 local_volume = lattice->mynode.volume;
120 if (pencil_recv_buf_size[d] > buf_size)
121 buf_size = pencil_recv_buf_size[d];
123 if (buf_size < local_volume)
124 buf_size = local_volume;
127 send_buf = (cmplx_t *)d_malloc(buf_size *
sizeof(cmplx_t) * elements);
128 receive_buf = (cmplx_t *)d_malloc(buf_size *
sizeof(cmplx_t) * elements);
159 rec_p.resize(hila_pencil_comms[dir].size());
160 rec_size.resize(hila_pencil_comms[dir].size());
162 cmplx_t *p = receive_buf;
164 for (pencil_struct &fn : hila_pencil_comms[dir]) {
170 rec_size[i] = fn.size_to_dir;
171 p += fn.recv_buf_size * elements;
177 rec_p[i] = send_buf + fn.column_offset * elements;
178 rec_size[i] = fn.size_to_dir;
188 template <
typename T>
192 pencil_collect_timer.start();
194 constexpr int elements =
sizeof(T) /
sizeof(cmplx_t);
200 const size_t elem_offset =
203 cmplx_t *sb = send_buf;
206#pragma hila novector direct_access(sb)
209 T_union<T, cmplx_t> v;
211 int off = offset.dot(X.coordinates() - nmin);
212 for (
int i = 0; i < elements; i++) {
213 sb[off + i * elem_offset] = v.c[i];
217 pencil_collect_timer.stop();
222 template <
typename T>
225 constexpr int elements =
sizeof(T) /
sizeof(cmplx_t);
228 pencil_save_timer.start();
235 cmplx_t *rb = receive_buf;
238#pragma hila novector direct_access(rb)
241 T_union<T, cmplx_t> v;
243 size_t off = offset.dot(X.coordinates() - nmin);
244 for (
int i = 0; i < elements; i++) {
245 v.c[i] = rb[off + i * elem_offset];
250 pencil_save_timer.stop();
261 pencil_reshuffle_timer.start();
270 cmplx_t *sb = send_buf;
271 cmplx_t *rb = receive_buf;
273#pragma hila novector direct_access(sb, rb)
276 size_t off_in = offset_in.dot(v);
277 size_t off_out = offset_out.dot(v);
278 for (
int e = 0; e < elem; e++) {
279 sb[off_out + e * e_offset_out] = rb[off_in + e * e_offset_in];
283 pencil_reshuffle_timer.stop();
290 void swap_buffers() {
291 std::swap(send_buf, receive_buf);
301 template <
typename T>
308 bool first_dir =
true;
312 if (directions[dir]) {
346 result.mark_changed(
ALL);
351void FFT_delete_plans();
358#include "plumbing/fft_fftw_transform.h"
360#elif defined(CUDA) || defined(HIP)
362#include "plumbing/backend_gpu/fft_gpu_transform.h"
384 static_assert(hila::contains_complex<T>::value,
385 "FFT_field argument fields must contain complex type");
389 constexpr size_t elements =
sizeof(T) /
sizeof(cmplx_t);
477 "FFT_real_to_complex can be applied only to Field<real-type> variable");
481 return cf.FFT(fftdir);
516 if (cv[d] > 0 && cv[d] < lattice.
size(d) / 2)
518 if (cv[d] > lattice.
size(d) / 2)
530 static_assert(hila::is_complex<T>::value,
531 "FFT_complex_to_real can be applied only to Field<Complex<>> type variable");
534 assert(lattice.
size(d) % 2 == 0 &&
535 "FFT_complex_to_real works only with even lattice size to all directions");
539 auto rf = this->reflect();
542 int type = hila::FFT_complex_to_real_site(X.coordinates());
545 }
else if (type == -1) {
546 rf[X] = rf[X].conj();
548 rf[X].real() = (*this)[X].real();
563 onsites(
ALL) res[X] = rf[X].
real();
596 constexpr int elements = 1;
600 hila_fft<T> refl(elements, fft_direction::forward,
true);
Array< n, m, hila::arithmetic_type< T > > imag(const Array< n, m, T > &arg)
Return imaginary part of Array.
hila::arithmetic_type< T > squarenorm(const Array< n, m, T > &rhs)
Return square norm of Array.
Array< n, m, hila::arithmetic_type< T > > real(const Array< n, m, T > &arg)
Return real part of Array.
Vector< 4, double > convert_to_k() const
Convert momentum space CoordinateVector to wave number k, where -pi/2 < k_i <= pi/2. Utility function ...
The field class implements the standard methods for accessing Fields. Hilapp replaces the parity acce...
Field< A > real() const
Returns real part of Field.
Field< Complex< hila::arithmetic_type< T > > > FFT_real_to_complex(fft_direction fdir=fft_direction::forward) const
void check_alloc()
Allocate Field if it is not already allocated.
int size(Direction d) const
lattice.size() -> CoordinateVector or lattice.size(d) -> int returns the dimensions of the lattice,...
const auto & fill(const S rhs)
Matrix fill.
Matrix class which defines matrix operations.
void gather_data()
send column data to nodes
void setup_direction(Direction _dir)
Initialize fft to Direction dir.
void collect_data(const Field< T > &f)
void scatter_data()
inverse of gather_data
void reshuffle_data(Direction prev_dir)
void transform()
transform does the actual fft.
void save_result(Field< T > &f)
Inverse of the fft_collect_data: write fft'd data from receive_buf to field.
void full_transform(const Field< T > &input, Field< T > &result, const CoordinateVector &directions)
Do the transform itself (fft or reflect only)
Definition of Complex types.
This header file defines:
int pmod(const int a, const int b)
Positive mod(): mods the result so that 0 <= a % b < b.
#define foralldir(d)
Macro to loop over (all) Direction(s)
Direction
Enumerator for direction that assigns integer to direction to be interpreted as unit vector.
constexpr Parity ALL
bit pattern: 011
This file defines all includes for HILA.
fft_direction
define a class for FFT direction
void init_pencil_direction(Direction d)
Initialize fft direction - defined in fft.cpp.
size_t pencil_get_buffer_offsets(const Direction dir, const size_t elements, CoordinateVector &offset, CoordinateVector &nmin)
void FFT_field(const Field< T > &input, Field< T > &result, const CoordinateVector &directions, fft_direction fftdir=fft_direction::forward)
This file contains definitions for the Field class and the classes required to define it such as fi...
Implement hila::swap for gauge fields.
int myrank()
rank of this node