HILA
Loading...
Searching...
No Matches
reduction.h
1#ifndef HILA_REDUCTION_H_
2#define HILA_REDUCTION_H_
3
4#include "hila.h"
5
6
7//////////////////////////////////////////////////////////////////////////////////
8/// @brief Special reduction class: enables delayed and non-blocking reductions, which
9/// are not possible with the standard reduction. See [user guide](@ref reduction_type_guide)
10/// for details
11///
12
13template <typename T>
14class Reduction {
15
16 private:
17 // value holder
18 T val;
19
20 // comm_is_on is true if non-blocking MPI communications are under way.
21 bool comm_is_on = false;
22
23 // Reduction status : is this allreduce, nonblocking, or delayed
24 bool is_allreduce_ = true;
25 bool is_nonblocking_ = false;
26 bool is_delayed_ = false;
27
28 bool delay_is_on = false; // status of the delayed reduction
29 bool is_delayed_sum = true; // sum/product
30
31 MPI_Request request;
32
33 // start the actual reduction
34
35 void do_reduce_operation(MPI_Op operation) {
36
37 // if for some reason reduction is going on unfinished, wait.
38 wait();
39
40 if (is_nonblocking())
41 comm_is_on = true;
42
43 MPI_Datatype dtype;
44
45 dtype = get_MPI_number_type<T>();
46
47 assert(dtype != MPI_BYTE && "Unknown number_type in reduction");
48
49 void *ptr = &val;
50
51 reduction_timer.start();
52 if (is_allreduce()) {
53 if (is_nonblocking()) {
54 MPI_Iallreduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
55 dtype, operation, lattice->mpi_comm_lat, &request);
56 } else {
57 MPI_Allreduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
58 dtype, operation, lattice->mpi_comm_lat);
59 }
60 } else {
61 if (hila::myrank() == 0) {
62 if (is_nonblocking()) {
63 MPI_Ireduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
64 dtype, operation, 0, lattice->mpi_comm_lat, &request);
65 } else {
66 MPI_Reduce(MPI_IN_PLACE, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
67 dtype, operation, 0, lattice->mpi_comm_lat);
68 }
69 } else {
70 if (is_nonblocking()) {
71 MPI_Ireduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
72 operation, 0, lattice->mpi_comm_lat, &request);
73 } else {
74 MPI_Reduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
75 operation, 0, lattice->mpi_comm_lat);
76 }
77 }
78 }
79 reduction_timer.stop();
80 }
81
82 /// Wait for MPI to complete, if it is currently going on
83 /// This must be called for non-blocking reduce before use!
84
85 void wait() {
86 if (comm_is_on) {
87 reduction_wait_timer.start();
88 MPI_Status status;
89 MPI_Wait(&request, &status);
90 reduction_wait_timer.stop();
91 comm_is_on = false;
92 }
93 }
94
95 public:
96 /// Initialize to zero by default (? exception to other variables)
97 /// allreduce = true by default
99 val = 0;
100 comm_is_on = false;
101 }
102 /// Initialize to value only on rank 0 - ensures expected result for delayed reduction
103 ///
104 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
105 Reduction(const S &v) {
106 if (hila::myrank() == 0) {
107 val = v;
108 } else {
109 val = 0;
110 }
111 comm_is_on = false;
112 }
113
114 /// Delete copy construct from another reduction just in case
115 /// don't make Reduction temporaries
116 Reduction(const Reduction<T> &r) = delete;
117
118 /// Destructor cleans up communications if they are in progress
120 wait();
121 }
122
123 /// allreduce(bool) turns allreduce on or off. By default on.
124 Reduction &allreduce(bool b = true) {
125 is_allreduce_ = b;
126 return *this;
127 }
128 bool is_allreduce() {
129 return is_allreduce_;
130 }
131
132 /// nonblocking(bool) turns allreduce on or off. By default on.
133 Reduction &nonblocking(bool b = true) {
134 is_nonblocking_ = b;
135 return *this;
136 }
137 bool is_nonblocking() {
138 return is_nonblocking_;
139 }
140
141 /// deferred(bool) turns deferred on or off. By default turns on.
142 Reduction &delayed(bool b = true) {
143 is_delayed_ = b;
144 return *this;
145 }
146 bool is_delayed() {
147 return is_delayed_;
148 }
149
150 /// Return value of the reduction variable. Wait for the comms if needed.
151 const T value() {
152 reduce();
153 return val;
154 }
155
156
157 /// Method set is the same as assignment, but without return value
158 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
159 void set(const S &rhs) {
160 *this = rhs;
161 }
162
163 /// Assignment is used only outside site loops - drop comms if on, no need to wait
164 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
165 T operator=(const S &rhs) {
166 wait();
167
168 comm_is_on = false;
169 T ret = rhs;
170 if (hila::myrank() == 0) {
171 val = ret;
172 } else {
173 val = 0;
174 }
175 // all ranks return the same value
176 return ret;
177 }
178
179 /// Compound operator += is used in reduction but can be used outside onsites too.
180 /// returns void, unconventionally:
181 /// these are used within site loops but their value is not well-defined.
182 /// Thus a = (r += b); is disallowed
183 template <typename S,
184 std::enable_if_t<hila::is_assignable<T &, hila::type_plus<T, S>>::value, int> = 0>
185 void operator+=(const S &rhs) {
186 val += rhs;
187 }
188
189 // template <typename S,
190 // std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value,
191 // int> = 0>
192 // void operator*=(const S &rhs) {
193 // val *= rhs;
194 // }
195
196 // template <typename S,
197 // std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value,
198 // int> = 0>
199 // void operator/=(const S &rhs) {
200 // val /= rhs;
201 // }
202
203 // Start sum reduction -- works only if the type T addition == element-wise
204 // addition. This is true for all hila predefined data types
205 void reduce_sum_node(const T &v) {
206
207 wait(); // wait for possible ongoing
208
209 // add the node values to reduction var
210 if (hila::myrank() == 0 || is_delayed_)
211 val += v;
212 else
213 val = v;
214
215 if (is_delayed_) {
216 if (delay_is_on && is_delayed_sum == false) {
217 assert(0 && "Cannot mix sum and product reductions!");
218 }
219 delay_is_on = true;
220 is_delayed_sum = true;
221 } else {
222 do_reduce_operation(MPI_SUM);
223 }
224 }
225
226 /// Product reduction -- currently works only for scalar data types.
227 /// For Complex, Matrix and Vector data product is not element-wise.
228 /// TODO: Array or std::array ?
229 /// TODO: implement using custom MPI ops (if needed)
230 // void reduce_product_node(const T &v) {
231
232 // reduce();
233
234 // if (hila::myrank() == 0 || is_delayed_)
235 // val *= v;
236 // else
237 // val = v;
238
239 // if (is_delayed_) {
240 // if (delay_is_on && is_delayed_sum == true) {
241 // assert(0 && "Cannot mix sum and product reductions!");
242 // }
243 // delay_is_on = true;
244 // is_delayed_sum = false;
245 // } else {
246 // do_reduce_operation(MPI_PROD);
247 // }
248 // }
249
250
251 /// For delayed reduction, start_reduce starts or completes the reduction operation
253 if (!comm_is_on) {
254 if (delay_is_on) {
255 delay_is_on = false;
256
257 if (is_delayed_sum)
258 do_reduce_operation(MPI_SUM);
259 else
260 do_reduce_operation(MPI_PROD);
261 }
262 }
263 }
264
265 /// Complete the reduction - start if not done, and wait if ongoing
266 void reduce() {
267 start_reduce();
268 wait();
269 }
270};
271
272
273////////////////////////////////////////////////////////////////////////////////////
274
275// #if defined(CUDA) || defined(HIP)
276// #include "backend_gpu/gpu_reduction.h"
277
278// template <typename T>
279// T Field<T>::sum(Parity par, bool allreduce) const {
280// return gpu_reduce_sum(allreduce, par, false);
281// }
282
283// #else
284// // This for not-gpu branch
285
286template <typename T>
287T Field<T>::sum(Parity par, bool allreduce) const {
288
289 Reduction<T> result;
290 result.allreduce(allreduce);
291 onsites(par) result += (*this)[X];
292 return result.value();
293}
294
295template <typename T>
296T Field<T>::product(Parity par, bool allreduce) const {
297 static_assert(std::is_arithmetic<T>::value,
298 ".product() reduction only for integer or floating point types");
299 Reduction<T> result;
300 result = 1;
301 result.allreduce(allreduce);
302 onsites(par) result *= (*this)[X];
303 return result.value();
304}
305
306
307// get global minimum/maximums - meant to be used through .min() and .max()
308
309#ifdef OPENMP
310#include <omp.h>
311#endif
312
313#if defined(CUDA) || defined(HIP)
314// #include "backend_gpu/gpu_reduction.h"
315#include "backend_gpu/gpu_minmax.h"
316#endif
317
318template <typename T>
319T Field<T>::minmax(bool is_min, Parity par, CoordinateVector &loc) const {
320
321 static_assert(std::is_same<T, int>::value || std::is_same<T, long>::value ||
322 std::is_same<T, float>::value || std::is_same<T, double>::value ||
323 std::is_same<T, long double>::value,
324 "In Field .min() and .max() methods the Field element type must be one of "
325 "(int/long/float/double/long double)");
326
327#if defined(CUDA) || defined(HIP)
328 T val = gpu_minmax(is_min, par, loc);
329#else
330 int sgn = is_min ? 1 : -1;
331 // get suitable initial value
332 T val = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
333
334// write the loop with explicit OpenMP parallel region. It has negligible effect
335// on non-OpenMP code, and the pragmas are ignored.
336#pragma omp parallel shared(val, loc, sgn, is_min)
337 {
338 CoordinateVector loc_th(0);
339 T val_th = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
340
341// Pragma "hila omp_parallel_region" is necessary here, because this is within
342// omp parallel
343#pragma hila novector omp_parallel_region direct_access(loc_th, val_th)
344 onsites(par) {
345 if (sgn * (*this)[X] < sgn * val_th) {
346 val_th = (*this)[X];
347 loc_th = X.coordinates();
348 }
349 }
350
351#pragma omp critical
352 if (sgn * val_th < sgn * val) {
353 val = val_th;
354 loc = loc_th;
355 }
356 }
357#endif
358
359 if (hila::number_of_nodes() > 1) {
360 size_t size;
361 // this returns MPI+int -type, for example MPI_DOUBLE_INT
362 MPI_Datatype dtype = get_MPI_number_type<T>(size, true);
363
364 struct {
365 T v;
366 int rank;
367 } rdata;
368
369 static_assert(sizeof(T) % sizeof(int) == 0,
370 "min/max reduction: datatype struct not packed!");
371
372 rdata.v = val;
373 rdata.rank = hila::myrank();
374
375 // after allreduce rdata contains the min value and rank where it is
376 if (is_min) {
377 MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MINLOC, lattice->mpi_comm_lat);
378 } else {
379 MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MAXLOC, lattice->mpi_comm_lat);
380 }
381 val = rdata.v;
382
383 // send the coordinatevector of the minloc to all nodes
384 MPI_Bcast(&loc, sizeof(CoordinateVector), MPI_BYTE, rdata.rank, lattice->mpi_comm_lat);
385 }
386
387 return val;
388}
389
390
391/// Find minimum value from Field
392template <typename T>
393T Field<T>::min(Parity par) const {
395 return minmax(true, par, loc);
396}
397
398/// Find minimum value and location from Field
399template <typename T>
401 return minmax(true, ALL, loc);
402}
403
404/// Find minimum value and location from Field
405template <typename T>
407 return minmax(true, par, loc);
408}
409
410
411/// Find maximum value from Field
412template <typename T>
413T Field<T>::max(Parity par) const {
415 return minmax(false, par, loc);
416}
417
418/// Find maximum value and location from Field
419template <typename T>
421 return minmax(false, ALL, loc);
422}
423
424/// Find maximum value and location from Field
425template <typename T>
427 return minmax(false, par, loc);
428}
429
430
431#endif
T max(Parity par=ALL) const
Find maximum value from Field.
Definition reduction.h:413
T product(Parity par=Parity::all, bool allreduce=true) const
Product reduction of Field.
Definition reduction.h:296
T minmax(bool is_min, Parity par, CoordinateVector &loc) const
Function to perform min or max operations.
Definition reduction.h:319
T min(Parity par=ALL) const
Find minimum value from Field.
Definition reduction.h:393
T sum(Parity par=Parity::all, bool allreduce=true) const
Sum reduction of Field.
Definition reduction.h:287
Special reduction class: enables delayed and non-blocking reductions, which are not possible with the...
Definition reduction.h:14
const T value()
Return value of the reduction variable. Wait for the comms if needed.
Definition reduction.h:151
~Reduction()
Destructor cleans up communications if they are in progress.
Definition reduction.h:119
T operator=(const S &rhs)
Assignment is used only outside site loops - any ongoing communication is waited for before the value is overwritten.
Definition reduction.h:165
void operator+=(const S &rhs)
Definition reduction.h:185
Reduction(const Reduction< T > &r)=delete
Reduction & allreduce(bool b=true)
allreduce(bool) turns allreduce on or off. By default on.
Definition reduction.h:124
Reduction & nonblocking(bool b=true)
nonblocking(bool) turns non-blocking communication on or off. Off by default; calling it without an argument turns it on.
Definition reduction.h:133
Reduction & delayed(bool b=true)
delayed(bool) turns delayed reduction on or off. Off by default; calling it without an argument turns it on.
Definition reduction.h:142
Reduction(const S &v)
Definition reduction.h:105
void set(const S &rhs)
Method set is the same as assignment, but without return value.
Definition reduction.h:159
void reduce()
Complete the reduction - start if not done, and wait if ongoing.
Definition reduction.h:266
void start_reduce()
For delayed reduction, start_reduce starts or completes the reduction operation.
Definition reduction.h:252
Parity
Parity enum with values EVEN, ODD, ALL; refers to parity of the site. Parity of site (x,...
constexpr Parity ALL
bit pattern: 011
int myrank()
rank of this node
Definition com_mpi.cpp:237
int number_of_nodes()
how many nodes there are
Definition com_mpi.cpp:248