    bool comm_is_on = false;

    // reduction mode flags
    bool is_allreduce_ = true;
    bool is_nonblocking_ = false;
    bool is_delayed_ = false;

    bool delay_is_on = false;   // is a delayed reduction pending?
    bool is_delayed_sum = true; // true for a delayed sum, false for a delayed product
    void do_reduce_operation(MPI_Op operation) {
        // ...
        MPI_Datatype dtype = get_MPI_number_type<T>();

        assert(dtype != MPI_BYTE && "Unknown number_type in reduction");

        // ptr points to the value stored in this reduction object; a composite T is
        // reduced as sizeof(T) / sizeof(hila::arithmetic_type<T>) base-type elements
        reduction_timer.start();
        if (is_allreduce()) {
            if (is_nonblocking()) {
                MPI_Iallreduce(MPI_IN_PLACE, ptr,
                               sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, lattice.mpi_comm_lat, &request);
            } else {
                MPI_Allreduce(MPI_IN_PLACE, ptr,
                              sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                              operation, lattice.mpi_comm_lat);
            }
        } else if (hila::myrank() == 0) {
            // rank 0 is the root of the plain reduce and contributes in place
            if (is_nonblocking()) {
                MPI_Ireduce(MPI_IN_PLACE, ptr,
                            sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                            operation, 0, lattice.mpi_comm_lat, &request);
            } else {
                MPI_Reduce(MPI_IN_PLACE, ptr,
                           sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                           operation, 0, lattice.mpi_comm_lat);
            }
        } else {
            // other ranks send their local value to rank 0
            if (is_nonblocking()) {
                MPI_Ireduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                            dtype, operation, 0, lattice.mpi_comm_lat, &request);
            } else {
                MPI_Reduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                           dtype, operation, 0, lattice.mpi_comm_lat);
            }
        }
        reduction_timer.stop();
    }
    // Waiting for a pending nonblocking reduction to complete:
    reduction_wait_timer.start();
    MPI_Wait(&request, &status);
    reduction_wait_timer.stop();
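The element count handed to MPI above is computed as sizeof(T) / sizeof(hila::arithmetic_type<T>), i.e. a composite element is reduced as a flat array of its arithmetic base type, and the nonblocking variant is completed later with MPI_Wait. A minimal standalone sketch of that pattern, in plain MPI rather than HILA code (the Cmplx struct and all names are made up for illustration):

#include <mpi.h>
#include <cstdio>

// stand-in for a composite reduction type T built from doubles
struct Cmplx {
    double re, im;
};

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    Cmplx z{1.0 * rank, 2.0 * rank};

    // reduce the struct as sizeof(Cmplx) / sizeof(double) = 2 doubles, in place
    MPI_Request request;
    MPI_Iallreduce(MPI_IN_PLACE, &z, sizeof(Cmplx) / sizeof(double), MPI_DOUBLE,
                   MPI_SUM, MPI_COMM_WORLD, &request);

    // ... other work could overlap with the reduction here ...

    MPI_Status status;
    MPI_Wait(&request, &status); // corresponds to waiting before the value is used

    if (rank == 0)
        std::printf("sum = (%g, %g)\n", z.re, z.im);

    MPI_Finalize();
}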
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    // ...
    // the destructor cancels a communication that is still in progress
    ~Reduction() {
        if (comm_is_on)
            MPI_Cancel(&request);
    }

    bool is_allreduce() {
        return is_allreduce_;
    }
    bool is_nonblocking() {
        return is_nonblocking_;
    }

    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    // ...
    // assignment drops an ongoing reduction: no need to wait for its result
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    T operator=(const S &rhs) {
        if (comm_is_on)
            MPI_Cancel(&request);
        // ...
    }
    // compound add is what a sum reduction uses inside site loops
    template <typename S,
              std::enable_if_t<hila::is_assignable<T &, hila::type_plus<T, S>>::value,
                               int> = 0>
    void operator+=(const S &rhs) {
        // ...
    }
    // reduce_sum_node adds the node-local value and performs (or delays) the sum
    void reduce_sum_node(const T &v) {
        // ...
        if (delay_is_on && is_delayed_sum == false) {
            assert(0 && "Cannot mix sum and product reductions!");
        }
        // ...
        is_delayed_sum = true;
        // ...
        do_reduce_operation(MPI_SUM);
    }

    // When a delayed reduction is finally started (in reduce() / start_reduce()),
    // the recorded operation type selects the MPI operation:
    if (is_delayed_sum)
        do_reduce_operation(MPI_SUM);
    else
        do_reduce_operation(MPI_PROD);
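A usage sketch of the delayed mode, assuming a HILA program where the lattice is already set up, a Field<double> f exists, and the standard onsites() loop syntax is available; this is an illustration of the interface, not code from the library:

Reduction<double> r;
r.allreduce(false).delayed();   // result on the root rank only, reduction deferred

onsites(ALL) r += f[X];         // accumulates node-locally; no MPI traffic yet
onsites(ALL) r += 2 * f[X];     // further sums may be added to a delayed reduction

r.start_reduce();               // optionally launch the communication early
// ... other work can overlap with the communication here ...
double total = r.value();       // completes the delayed reduction and returns the result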
// In Field<T>::sum(par, allreduce) the sites are accumulated into a Reduction<T>:
    onsites(par) result += (*this)[X];
    return result.value();

// Field<T>::product(par, allreduce) is analogous, accumulating with *= :
    static_assert(std::is_arithmetic<T>::value,
                  ".product() reduction only for integer or floating point types");
    onsites(par) result *= (*this)[X];
    return result.value();
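A short usage sketch of these field reductions, again assuming an initialized HILA program and the standard onsites() loop (the fill values are arbitrary):

Field<double> f;
onsites(ALL) f[X] = 1.0;          // fill the field with something

double s  = f.sum();              // sum over all sites, allreduce by default
double se = f.sum(EVEN, false);   // sum over even sites, result on the root rank only
double p  = f.product();          // product over all sites (arithmetic element types only)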
#if defined(CUDA) || defined(HIP)
#include "backend_gpu/gpu_reduction.h"
#endif
template <typename T>
T Field<T>::minmax(bool is_min, Parity par, CoordinateVector &loc) const {
    static_assert(std::is_same<T, int>::value || std::is_same<T, long>::value ||
                      std::is_same<T, float>::value || std::is_same<T, double>::value ||
                      std::is_same<T, long double>::value,
                  "In Field .min() and .max() methods the Field element type must be one of "
                  "(int/long/float/double/long double)");

#if defined(CUDA) || defined(HIP)
    // the GPU backend does the node-local search on the device
    T val = gpu_minmax(is_min, par, loc);
#else
    // sgn folds max into min: maximizing x is the same as minimizing -x
    int sgn = is_min ? 1 : -1;
    T val = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
    // ...
#pragma omp parallel shared(val, loc, sgn, is_min)
    {
        // each OpenMP thread keeps its own candidate value and location
        T val_th = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
        CoordinateVector loc_th;
#pragma hila novector omp_parallel_region direct_access(loc_th, val_th)
        onsites(par) {
            if (sgn * (*this)[X] < sgn * val_th) {
                val_th = (*this)[X];
                loc_th = X.coordinates();
            }
        }
        // merge the per-thread candidates into the node-wide result
        if (sgn * val_th < sgn * val) {
            val = val_th;
            loc = loc_th;
        }
    }
#endif

    // combine the per-node results with a location-carrying MPI reduction; rdata
    // below is a {value, rank} pair as used by MPI_MINLOC / MPI_MAXLOC
    // ...
    MPI_Datatype dtype = get_MPI_number_type<T>(size, true);
    if (is_min)
        MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MINLOC, lattice.mpi_comm_lat);
    else
        MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MAXLOC, lattice.mpi_comm_lat);
    // ... finally the winning coordinates are broadcast from the owning rank over
    //     lattice.mpi_comm_lat
}
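The cross-node step relies on MPI's location-carrying reductions. A small standalone sketch of that mechanism in plain MPI, independent of HILA; all names here are illustrative:

#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
    MPI_Init(&argc, &argv);
    int myrank;
    MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

    // {value, rank} pair as required by MPI_MINLOC / MPI_MAXLOC
    struct {
        double value;
        int rank;
    } rdata;
    rdata.value = 100.0 - myrank;  // pretend this is the node-local minimum
    rdata.rank = myrank;

    MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, MPI_DOUBLE_INT, MPI_MINLOC, MPI_COMM_WORLD);

    // every rank now knows the global minimum and which rank owns it, so the owner
    // can broadcast extra data such as the coordinates of the minimum site
    if (myrank == 0)
        std::printf("min = %g on rank %d\n", rdata.value, rdata.rank);

    MPI_Finalize();
    return 0;
}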
// The public min()/max() variants forward to minmax():
template <typename T>
T Field<T>::min(Parity par) const {
    CoordinateVector loc;
    return minmax(true, par, loc);
}
template <typename T>
T Field<T>::min(CoordinateVector &loc) const {
    return minmax(true, ALL, loc);
}
template <typename T>
T Field<T>::min(Parity par, CoordinateVector &loc) const {
    return minmax(true, par, loc);
}
template <typename T>
T Field<T>::max(Parity par) const {
    CoordinateVector loc;
    return minmax(false, par, loc);
}
template <typename T>
T Field<T>::max(CoordinateVector &loc) const {
    return minmax(false, ALL, loc);
}
template <typename T>
T Field<T>::max(Parity par, CoordinateVector &loc) const {
    return minmax(false, par, loc);
}
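A usage sketch of the min/max interface, assuming the parity-and-location overloads reconstructed above and an initialized HILA program; hila::random() is used only as an illustrative fill:

Field<double> f;
onsites(ALL) f[X] = hila::random();   // fill with something non-trivial

CoordinateVector loc;
double lo = f.min();              // global minimum over all sites
double hi = f.max(ODD, loc);      // maximum over odd sites; loc receives its coordinates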
Referenced symbols, with their brief descriptions:

T max(Parity par = ALL) const
    Find the maximum value of the Field.
T product(Parity par = Parity::all, bool allreduce = true) const
    Product reduction of the Field.
T minmax(bool is_min, Parity par, CoordinateVector &loc) const
    Performs the min or max operation (selected by is_min).
T min(Parity par = ALL) const
    Find the minimum value of the Field.
T sum(Parity par = Parity::all, bool allreduce = true) const
    Sum reduction of the Field.

Reduction<T>
    Special reduction class: enables delayed and non-blocking reductions, which are not
    possible with the ...
const T value()
    Return the value of the reduction variable; waits for the communication if needed.
~Reduction()
    The destructor cleans up communications if they are still in progress.
T operator=(const S &rhs)
    Assignment is used only outside site loops; it drops an ongoing communication, so
    there is no need to wait for it.
void operator+=(const S &rhs)
Reduction(const Reduction<T> &r) = delete
Reduction &allreduce(bool b = true)
    Turns allreduce on or off; on by default.
Reduction &nonblocking(bool b = true)
    Turns non-blocking reduction on or off; the default argument turns it on.
Reduction &delayed(bool b = true)
    Turns delayed reduction on or off; the default argument turns it on.
void set(const S &rhs)
    set() is the same as assignment, but without a return value.
void reduce()
    Complete the reduction: start it if not yet started, and wait if it is ongoing.
void start_reduce()
    For delayed reduction, start_reduce() starts or completes the reduction operation.

Parity
    Parity enum with values EVEN, ODD and ALL; refers to the parity of a site (x, ...).
constexpr Parity ALL
    Bit pattern 011.
int hila::myrank()
    Rank of this node.
int hila::number_of_nodes()
    Number of MPI nodes.
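To tie the interface together, a sketch of a non-blocking reduction using the methods summarized above; this assumes an initialized HILA program with a Field<double> f and is an illustration, not library code:

Reduction<double> r;
r.allreduce().nonblocking();    // the MPI reduction is launched when the loop completes

onsites(ALL) r += f[X];         // accumulate the node-local sum

// ... other computation can overlap with the communication here ...

double total = r.value();       // waits for the communication if it is still in flight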