#ifndef HILA_REDUCTION_H_
#define HILA_REDUCTION_H_
    // ...
    bool comm_is_on = false;        // true while a non-blocking MPI reduction is in flight
    bool is_allreduce_ = true;      // use MPI_Allreduce (default) instead of MPI_Reduce to rank 0
    bool is_nonblocking_ = false;   // use the non-blocking MPI calls
    bool is_delayed_ = false;       // defer the MPI reduction until reduce()/start_reduce()
    bool delay_is_on = false;       // a delayed reduction has been accumulated but not completed
    bool is_delayed_sum = true;     // the delayed operation is a sum (false: a product)
    void do_reduce_operation(MPI_Op operation) {
        // ...
        dtype = get_MPI_number_type<T>();

        assert(dtype != MPI_BYTE && "Unknown number_type in reduction");
        // ...
        reduction_timer.start();
        if (is_allreduce()) {
            if (is_nonblocking()) {
                MPI_Iallreduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                               dtype, operation, lattice->mpi_comm_lat, &request);
            } else {
                MPI_Allreduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                              dtype, operation, lattice->mpi_comm_lat);
            }
        } else if (hila::myrank() == 0) {
            // rank 0 receives the result and reduces in place
            if (is_nonblocking()) {
                MPI_Ireduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                            dtype, operation, 0, lattice->mpi_comm_lat, &request);
            } else {
                MPI_Reduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                           dtype, operation, 0, lattice->mpi_comm_lat);
            }
        } else {
            // other ranks only send their contribution
            if (is_nonblocking()) {
                MPI_Ireduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                            operation, 0, lattice->mpi_comm_lat, &request);
            } else {
                MPI_Reduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                           operation, 0, lattice->mpi_comm_lat);
            }
        }
        reduction_timer.stop();
        // waiting for a pending non-blocking reduction to complete:
        reduction_wait_timer.start();
        MPI_Wait(&request, &status);
        reduction_wait_timer.stop();
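    // Illustration (not part of this header): a minimal standalone sketch of the same
    // in-place pattern used by do_reduce_operation() above, where a composite element
    // type is reduced as sizeof(T)/sizeof(scalar) values of its underlying arithmetic
    // type. The example program and its names are assumptions for illustration only.
    //
    //   #include <mpi.h>
    //   #include <complex>
    //
    //   int main(int argc, char **argv) {
    //       MPI_Init(&argc, &argv);
    //       std::complex<double> z(1.0, 2.0);                  // T = std::complex<double>
    //       constexpr int count = sizeof(z) / sizeof(double);  // 2 underlying doubles
    //       // every rank reduces in place; afterwards all ranks hold the element-wise sum
    //       MPI_Allreduce(MPI_IN_PLACE, &z, count, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    //       MPI_Finalize();
    //       return 0;
    //   }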
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    // ...
    bool is_allreduce() {
        return is_allreduce_;
    }

    bool is_nonblocking() {
        return is_nonblocking_;
    }
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    // ...

    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    // ...
    template <typename S,
              std::enable_if_t<hila::is_assignable<T &, hila::type_plus<T, S>>::value, int> = 0>
    // ...
    void reduce_sum_node(const T &v) {
        // ...
        if (delay_is_on && is_delayed_sum == false) {
            assert(0 && "Cannot mix sum and product reductions!");
        }
        // ...
        is_delayed_sum = true;
        // ...
        do_reduce_operation(MPI_SUM);
        // completing a delayed reduction: the operation recorded earlier selects sum or product
        if (is_delayed_sum)
            do_reduce_operation(MPI_SUM);
        else
            do_reduce_operation(MPI_PROD);
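    // Illustration (not part of this header): hedged usage sketch of a delayed sum
    // reduction, based on the Reduction members documented at the end of this page
    // (allreduce(), delayed(), operator+=, reduce(), value()). The Field<double> f is
    // assumed to exist and be initialized elsewhere.
    //
    //   Reduction<double> rsum;
    //   rsum.allreduce(false).delayed();  // defer MPI; plain reduce to rank 0 when it runs
    //   onsites(ALL) rsum += f[X];        // accumulates only into the node-local value
    //   onsites(ALL) rsum += 2 * f[X];    // further loops may keep accumulating
    //   rsum.reduce();                    // the single MPI_SUM reduction happens here
    //   double s = rsum.value();          // result valid on rank 0 (allreduce was off)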
    // Field<T>::sum(): accumulate into the reduction variable and return its value
    onsites(par) result += (*this)[X];
    return result.value();

    // Field<T>::product(): allowed only for arithmetic element types
    static_assert(std::is_arithmetic<T>::value,
                  ".product() reduction only for integer or floating point types");
    onsites(par) result *= (*this)[X];
    return result.value();
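    // Illustration (not part of this header): hedged usage sketch of the sum() and
    // product() reductions above, using the signatures documented at the end of this
    // page; the Field<double> f is assumed to exist and be initialized elsewhere.
    //
    //   double s  = f.sum();               // sum over ALL sites, allreduce by default
    //   double se = f.sum(EVEN);           // sum over even-parity sites only
    //   double p  = f.product(ALL, false); // product; result valid on rank 0 only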
#if defined(CUDA) || defined(HIP)
#include "backend_gpu/gpu_minmax.h"
#endif

    // Field<T>::minmax(): the element type must be a plain number type
    static_assert(std::is_same<T, int>::value || std::is_same<T, long>::value ||
                      std::is_same<T, float>::value || std::is_same<T, double>::value ||
                      std::is_same<T, long double>::value,
                  "In Field .min() and .max() methods the Field element type must be one of "
                  "(int/long/float/double/long double)");

#if defined(CUDA) || defined(HIP)
    // GPU backends search on the device
    T val = gpu_minmax(is_min, par, loc);
#else
    // CPU backend: the sign trick turns the max search into a min comparison
    int sgn = is_min ? 1 : -1;
    T val = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();

#pragma omp parallel shared(val, loc, sgn, is_min)
    {
        // each thread keeps its own extremum and location
        T val_th = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
        // ...
#pragma hila novector omp_parallel_region direct_access(loc_th, val_th)
        if (sgn * (*this)[X] < sgn * val_th) {
            val_th = (*this)[X];
            loc_th = X.coordinates();
        }
        // ...
        // merge the per-thread results into val and loc
        if (sgn * val_th < sgn * val) {
            // ...
        }
    }
#endif

    // combine the node results over MPI using a packed {value, rank} pair
    MPI_Datatype dtype = get_MPI_number_type<T>(size, true);
    static_assert(sizeof(T) % sizeof(int) == 0, "min/max reduction: datatype struct not packed!");
    // ...
    if (is_min)
        MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MINLOC, lattice->mpi_comm_lat);
    else
        MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MAXLOC, lattice->mpi_comm_lat);
    // the rank owning the extremum broadcasts its coordinates
    MPI_Bcast(&loc, sizeof(CoordinateVector), MPI_BYTE, rdata.rank, lattice->mpi_comm_lat);
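    // Illustration (not part of this header): standalone sketch of the {value, rank}
    // MINLOC pattern used above, written with plain MPI. Each rank submits its local
    // minimum together with its own rank; MPI_MINLOC keeps the smallest value and the
    // rank owning it, which can then broadcast the location. All names below are
    // illustrative only.
    //
    //   #include <mpi.h>
    //
    //   struct ValRank {
    //       double val;  // local minimum
    //       int rank;    // rank that owns it
    //   };
    //
    //   void global_min_with_owner(double local_min, MPI_Comm comm) {
    //       int my_rank;
    //       MPI_Comm_rank(comm, &my_rank);
    //       ValRank r{local_min, my_rank};
    //       // MPI_DOUBLE_INT matches a packed {double, int} pair
    //       MPI_Allreduce(MPI_IN_PLACE, &r, 1, MPI_DOUBLE_INT, MPI_MINLOC, comm);
    //       // r.val is now the global minimum and r.rank the owning rank on every rank
    //   }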
    // The min()/max() overloads forward to minmax(); the first argument selects
    // min (true) or max (false).
    return minmax(true, par, loc);
    return minmax(true, ALL, loc);
    return minmax(true, par, loc);

    return minmax(false, par, loc);
    return minmax(false, ALL, loc);
    return minmax(false, par, loc);
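    // Illustration (not part of this header): hedged usage sketch of the min()/max()
    // methods above. The Field<double> f is assumed to exist and be initialized
    // elsewhere; the overload taking only a CoordinateVector is inferred from the
    // minmax(true, ALL, loc) call above and may differ in detail from the real header.
    //
    //   CoordinateVector loc;
    //   double lo = f.min(loc);   // global minimum and the coordinates of its site
    //   double hi = f.max(EVEN);  // maximum restricted to even-parity sites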
T max(Parity par=ALL) const
Find maximum value from Field.
T product(Parity par=Parity::all, bool allreduce=true) const
Product reduction of Field.
T minmax(bool is_min, Parity par, CoordinateVector &loc) const
Function to perform min or max operations.
T min(Parity par=ALL) const
Find minimum value from Field.
T sum(Parity par=Parity::all, bool allreduce=true) const
Sum reduction of Field.
Special reduction class: enables delayed and non-blocking reductions, which are not possible with the...
const T value()
Return value of the reduction variable. Wait for the comms if needed.
~Reduction()
Destructor cleans up communications if they are in progress.
T operator=(const S &rhs)
Assignment is used only outside site loops - drop comms if on, no need to wait.
void operator+=(const S &rhs)
Reduction(const Reduction< T > &r)=delete
Reduction & allreduce(bool b=true)
allreduce(bool) turns allreduce on or off. By default on.
Reduction & nonblocking(bool b=true)
nonblocking(bool) turns non-blocking reduction on or off. By default on.
Reduction & delayed(bool b=true)
delayed(bool) turns delayed reduction on or off. By default turns it on.
void set(const S &rhs)
Method set is the same as assignment, but without return value.
void reduce()
Complete the reduction - start if not done, and wait if ongoing.
void start_reduce()
For delayed reduction, start_reduce starts or completes the reduction operation.
Parity
Parity enum with values EVEN, ODD, ALL; refers to parity of the site. Parity of site (x,...
constexpr Parity ALL
bit pattern: 011
int myrank()
rank of this node
int number_of_nodes()
how many nodes there are
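A hedged sketch combining the Reduction members documented above (nonblocking(), delayed(), start_reduce(), value()); the field f and the overlapped work are assumptions, not part of the header:

    Reduction<double> r;
    r.nonblocking().delayed();   // non-blocking, deferred reduction
    onsites(ALL) r += f[X];      // node-local accumulation only
    r.start_reduce();            // launches the non-blocking MPI reduction shown earlier
    do_other_work();             // assumed unrelated computation, overlapped with comms
    double total = r.value();    // waits for the reduction to finish, then returns it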