1#ifndef FFTW_TRANSFORM_H 
    2#define FFTW_TRANSFORM_H 
   11template <
typename cmplx_t>
 
   13    assert(0 && 
"Don't call this!");
 
   18    extern unsigned hila_fft_my_columns[NDIM];
 
   19    extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
 
   21    size_t n_fft = hila_fft_my_columns[dir] * elements;
 
   24        (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
 
   26    fft_plan_timer.start();
 
   31    fftw_complex *fftwbuf =
 
   32        (fftw_complex *)fftw_malloc(
sizeof(fftw_complex) * lattice.
size(dir));
 
   33    fftw_plan fftwplan = fftw_plan_dft_1d(lattice.
size(dir), fftwbuf, fftwbuf,
 
   34                                          transform_dir, FFTW_ESTIMATE);
 
   36    fft_plan_timer.stop();
 
   38    for (
size_t i = 0; i < n_fft; i++) {
 
   41        fft_buffer_timer.start();
 
   43        fftw_complex *cp = fftwbuf;
 
   44        for (
int j = 0; j < rec_p.size(); j++) {
 
   45            memcpy(cp, rec_p[j] + i * rec_size[j], 
sizeof(fftw_complex) * rec_size[j]);
 
   49        fft_buffer_timer.stop();
 
   52        fft_execute_timer.start();
 
   54        fftw_execute(fftwplan);
 
   56        fft_execute_timer.stop();
 
   58        fft_buffer_timer.start();
 
   61        for (
int j = 0; j < rec_p.size(); j++) {
 
   62            memcpy(rec_p[j] + i * rec_size[j], cp, 
sizeof(fftw_complex) * rec_size[j]);
 
   66        fft_buffer_timer.stop();
 
   69    fftw_destroy_plan(fftwplan);
 
   76    extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
 
   77    extern unsigned hila_fft_my_columns[NDIM];
 
   79    size_t n_fft = hila_fft_my_columns[dir] * elements;
 
   82        (fftdir == fft_direction::forward) ? FFTW_FORWARD : FFTW_BACKWARD;
 
   84    fft_plan_timer.start();
 
   89    fftwf_complex *fftwbuf =
 
   90        (fftwf_complex *)fftwf_malloc(
sizeof(fftwf_complex) * lattice.
size(dir));
 
   91    fftwf_plan fftwplan = fftwf_plan_dft_1d(lattice.
size(dir), fftwbuf, fftwbuf,
 
   92                                            transform_dir, FFTW_ESTIMATE);
 
   94    fft_plan_timer.stop();
 
   96    for (
size_t i = 0; i < n_fft; i++) {
 
   99        fft_buffer_timer.start();
 
  101        fftwf_complex *cp = fftwbuf;
 
  102        for (
int j = 0; j < rec_p.size(); j++) {
 
  103            memcpy(cp, rec_p[j] + i * rec_size[j], 
sizeof(fftwf_complex) * rec_size[j]);
 
  107        fft_buffer_timer.stop();
 
  110        fft_execute_timer.start();
 
  112        fftwf_execute(fftwplan);
 
  114        fft_execute_timer.stop();
 
  116        fft_buffer_timer.start();
 
  119        for (
int j = 0; j < rec_p.size(); j++) {
 
  120            memcpy(rec_p[j] + i * rec_size[j], cp, 
sizeof(fftwf_complex) * rec_size[j]);
 
  124        fft_buffer_timer.stop();
 
  127    fftwf_destroy_plan(fftwplan);
 
  134template <
typename cmplx_t>
 
  139    pencil_MPI_timer.start();
 
  142    int n_comms = hila_pencil_comms[dir].size() - 1;
 
  144    std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
 
  145    std::vector<MPI_Status> stat(n_comms);
 
  149    for (
auto &fn : hila_pencil_comms[dir]) {
 
  152            size_t siz = fn.recv_buf_size * elements * 
sizeof(cmplx_t);
 
  153            if (siz >= (1ULL << 31)) {
 
  154                hila::out << 
"Too large MPI message in pencils! Size " << siz
 
  159            MPI_Irecv(rec_p[j], (
int)siz, MPI_BYTE, fn.node, WRK_GATHER_TAG,
 
  160                      lattice->mpi_comm_lat, &recreq[i]);
 
  168    for (
auto &fn : hila_pencil_comms[dir]) {
 
  171            cmplx_t *p = send_buf + fn.column_offset * elements;
 
  172            int n = fn.column_number * elements * lattice->mynode.size[dir] *
 
  175            MPI_Isend(p, n, MPI_BYTE, fn.node, WRK_GATHER_TAG, lattice->mpi_comm_lat,
 
  183        MPI_Waitall(n_comms, recreq.data(), stat.data());
 
  184        MPI_Waitall(n_comms, sendreq.data(), stat.data());
 
  187    pencil_MPI_timer.stop();
 
  194template <
typename cmplx_t>
 
  199    pencil_MPI_timer.start();
 
  201    int n_comms = hila_pencil_comms[dir].size() - 1;
 
  203    std::vector<MPI_Request> sendreq(n_comms), recreq(n_comms);
 
  204    std::vector<MPI_Status> stat(n_comms);
 
  208    for (
auto &fn : hila_pencil_comms[dir]) {
 
  210            cmplx_t *p = send_buf + fn.column_offset * elements;
 
  211            int n = fn.column_number * elements * lattice->mynode.size[dir] * 
sizeof(cmplx_t);
 
  213            MPI_Irecv(p, n, MPI_BYTE, fn.node, WRK_SCATTER_TAG,
 
  214                      lattice->mpi_comm_lat, &recreq[i]);
 
  222    for (
auto &fn : hila_pencil_comms[dir]) {
 
  225            MPI_Isend(rec_p[j], (
int)(fn.recv_buf_size * elements * 
sizeof(cmplx_t)), MPI_BYTE, fn.node,
 
  226                      WRK_SCATTER_TAG, lattice->mpi_comm_lat, &sendreq[i]);
 
  235        MPI_Waitall(n_comms, recreq.data(), stat.data());
 
  236        MPI_Waitall(n_comms, sendreq.data(), stat.data());
 
  239    pencil_MPI_timer.stop();
 
  251template <
typename cmplx_t>
 
  253    extern unsigned hila_fft_my_columns[NDIM];
 
  254    extern hila::timer fft_plan_timer, fft_buffer_timer, fft_execute_timer;
 
  256    const int ncols = hila_fft_my_columns[dir] * elements;
 
  258    const int length = lattice.
size(dir);
 
  260    cmplx_t *buf = (cmplx_t *)memalloc(
sizeof(cmplx_t) * length);
 
  262    for (
int i = 0; i < ncols; i++) {
 
  266        for (
int j = 0; j < rec_p.size(); j++) {
 
  267            memcpy(cp, rec_p[j] + i * rec_size[j], 
sizeof(cmplx_t) * rec_size[j]);
 
  272        for (
int j = 1; j < length / 2; j++) {
 
  273            std::swap(buf[j], buf[length - j]);
 
  277        for (
int j = 0; j < rec_p.size(); j++) {
 
  278            memcpy(rec_p[j] + i * rec_size[j], cp, 
sizeof(cmplx_t) * rec_size[j]);
 
int size(Direction d) const
lattice.size() -> CoordinateVector or lattice.size(d) -> int returns the dimensions of the lattice,...
void gather_data()
send column data to nodes
void scatter_data()
inverse of gather_data
void transform()
transform does the actual fft.
int myrank()
rank of this node
std::ostream out
this is our default output file stream
void terminate(int status)