HILA
reduction.h
#ifndef REDUCTION_H_
#define REDUCTION_H_

#include "hila.h"


//////////////////////////////////////////////////////////////////////////////////
/// @brief Special reduction class: enables delayed and non-blocking reductions, which
/// are not possible with the standard reduction. See the [user guide](@ref reduction_type_guide)
/// for details.
///
template <typename T>
class Reduction {

  private:
    // value holder
    T val;

    // comm_is_on is true if non-blocking MPI communications are under way
    bool comm_is_on = false;

    // Reduction status: is this allreduce, non-blocking, or delayed?
    bool is_allreduce_ = true;
    bool is_nonblocking_ = false;
    bool is_delayed_ = false;

    bool delay_is_on = false;   // status of the delayed reduction
    bool is_delayed_sum = true; // sum/product

    MPI_Request request;

    // start the actual reduction

    void do_reduce_operation(MPI_Op operation) {

        // if a previous reduction is still in flight, wait for it to complete
        wait();

        if (is_nonblocking())
            comm_is_on = true;

        MPI_Datatype dtype = get_MPI_number_type<T>();

        assert(dtype != MPI_BYTE && "Unknown number_type in reduction");

        void *ptr = &val;

        reduction_timer.start();
        if (is_allreduce()) {
            if (is_nonblocking()) {
                MPI_Iallreduce(MPI_IN_PLACE, ptr,
                               sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, lattice.mpi_comm_lat, &request);
            } else {
                MPI_Allreduce(MPI_IN_PLACE, ptr,
                              sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                              operation, lattice.mpi_comm_lat);
            }
        } else {
            if (hila::myrank() == 0) {
                if (is_nonblocking()) {
                    MPI_Ireduce(MPI_IN_PLACE, ptr,
                                sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                                operation, 0, lattice.mpi_comm_lat, &request);
                } else {
                    MPI_Reduce(MPI_IN_PLACE, ptr,
                               sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, 0, lattice.mpi_comm_lat);
                }
            } else {
                if (is_nonblocking()) {
                    MPI_Ireduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                                dtype, operation, 0, lattice.mpi_comm_lat, &request);
                } else {
                    MPI_Reduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                               dtype, operation, 0, lattice.mpi_comm_lat);
                }
            }
        }
        reduction_timer.stop();
    }

    /// Wait for MPI to complete, if communication is currently in progress.
    /// This must be called for a non-blocking reduction before the value is used!

    void wait() {
        if (comm_is_on) {
            reduction_wait_timer.start();
            MPI_Status status;
            MPI_Wait(&request, &status);
            reduction_wait_timer.stop();
            comm_is_on = false;
        }
    }

  public:
    /// Default constructor initializes the value to zero
    /// (an exception to ordinary variables, which are not zero-initialized);
    /// allreduce = true by default
    Reduction() {
        val = 0;
        comm_is_on = false;
    }
    /// Initialize to value only on rank 0 - ensures the expected result for delayed reduction
    ///
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    Reduction(const S &v) {
        if (hila::myrank() == 0) {
            val = v;
        } else {
            val = 0;
        }
        comm_is_on = false;
    }

    /// Copy construction from another Reduction is deleted:
    /// Reduction temporaries must not be created
    Reduction(const Reduction<T> &r) = delete;

    /// Destructor cleans up communications if they are in progress
    ~Reduction() {
        wait();
    }

    /// allreduce(bool) turns allreduce on or off. On by default.
    Reduction &allreduce(bool b = true) {
        is_allreduce_ = b;
        return *this;
    }
    bool is_allreduce() {
        return is_allreduce_;
    }

    /// nonblocking(bool) turns non-blocking reduction on or off. Off by default.
    Reduction &nonblocking(bool b = true) {
        is_nonblocking_ = b;
        return *this;
    }
    bool is_nonblocking() {
        return is_nonblocking_;
    }

    /// delayed(bool) turns delayed reduction on or off. Off by default.
    Reduction &delayed(bool b = true) {
        is_delayed_ = b;
        return *this;
    }
    bool is_delayed() {
        return is_delayed_;
    }

    /// Return the value of the reduction variable. Waits for the communication if needed.
    const T value() {
        reduce();
        return val;
    }


    /// Method set is the same as assignment, but without a return value
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    void set(const S &rhs) {
        *this = rhs;
    }

    /// Assignment is used only outside site loops - any ongoing reduction is
    /// waited for and its result discarded
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    T operator=(const S &rhs) {
        wait();

        comm_is_on = false;
        T ret = rhs;
        if (hila::myrank() == 0) {
            val = ret;
        } else {
            val = 0;
        }
        // all ranks return the same value
        return ret;
    }

    /// Compound operator += is used in reductions, but can also be used outside
    /// onsites loops. Unconventionally it returns void: within site loops the
    /// value of the expression is not well-defined, so a = (r += b); is disallowed.
    template <typename S,
              std::enable_if_t<hila::is_assignable<T &, hila::type_plus<T, S>>::value,
                               int> = 0>
    void operator+=(const S &rhs) {
        val += rhs;
    }

    // template <typename S,
    //           std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value,
    //                            int> = 0>
    // void operator*=(const S &rhs) {
    //     val *= rhs;
    // }

    // template <typename S,
    //           std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value,
    //                            int> = 0>
    // void operator/=(const S &rhs) {
    //     val /= rhs;
    // }

    // Start a sum reduction -- works only if addition of type T is element-wise
    // addition. This is true for all hila predefined data types.
    void reduce_sum_node(const T &v) {

        wait(); // wait for a possibly ongoing reduction

        // add the node value to the reduction variable
        if (hila::myrank() == 0 || is_delayed_)
            val += v;
        else
            val = v;

        if (is_delayed_) {
            if (delay_is_on && is_delayed_sum == false) {
                assert(0 && "Cannot mix sum and product reductions!");
            }
            delay_is_on = true;
            is_delayed_sum = true;
        } else {
            do_reduce_operation(MPI_SUM);
        }
    }
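
    // Conceptually (a sketch, not the literal generated code), the hila
    // preprocessor turns a sum-reduction site loop such as
    //
    //     onsites(par) r += f[X];
    //
    // into a node-local accumulation followed by one call to this hook:
    //
    //     T node_sum = 0;
    //     for (/* all sites X with parity par on this node */)
    //         node_sum += f[X];
    //     r.reduce_sum_node(node_sum);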

    /// Product reduction -- currently works only for scalar data types.
    /// For Complex, Matrix and Vector data the product is not element-wise.
    /// TODO: Array or std::array?
    /// TODO: implement using custom MPI ops (if needed)
    // void reduce_product_node(const T &v) {

    //     reduce();

    //     if (hila::myrank() == 0 || is_delayed_)
    //         val *= v;
    //     else
    //         val = v;

    //     if (is_delayed_) {
    //         if (delay_is_on && is_delayed_sum == true) {
    //             assert(0 && "Cannot mix sum and product reductions!");
    //         }
    //         delay_is_on = true;
    //         is_delayed_sum = false;
    //     } else {
    //         do_reduce_operation(MPI_PROD);
    //     }
    // }


    /// For delayed reduction, start_reduce starts or completes the reduction operation
    void start_reduce() {
        if (!comm_is_on) {
            if (delay_is_on) {
                delay_is_on = false;

                if (is_delayed_sum)
                    do_reduce_operation(MPI_SUM);
                else
                    do_reduce_operation(MPI_PROD);
            }
        }
    }

    /// Complete the reduction - start it if not yet started, and wait if ongoing
    void reduce() {
        start_reduce();
        wait();
    }

};
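
// Example of a non-blocking reduction overlapping communication with computation
// (a sketch; `f` is an assumed Field<double> and do_other_work() a placeholder):
//
//     Reduction<double> r;
//     r.nonblocking();
//     onsites(ALL) r += f[X];    // MPI_Iallreduce is started when the loop ends
//     do_other_work();           // communication proceeds in the background
//     double total = r.value();  // waits for the reduction to complete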


////////////////////////////////////////////////////////////////////////////////////

// #if defined(CUDA) || defined(HIP)
// #include "backend_gpu/gpu_reduction.h"

// template <typename T>
// T Field<T>::sum(Parity par, bool allreduce) const {
//     return gpu_reduce_sum(allreduce, par, false);
// }

// #else
// // This is for the non-gpu branch

template <typename T>
T Field<T>::sum(Parity par, bool allreduce) const {

    Reduction<T> result;
    result.allreduce(allreduce);
    onsites (par)
        result += (*this)[X];
    return result.value();
}

template <typename T>
T Field<T>::product(Parity par, bool allreduce) const {
    static_assert(std::is_arithmetic<T>::value,
                  ".product() reduction only for integer or floating point types");
    Reduction<T> result;
    result = 1;
    result.allreduce(allreduce);
    onsites (par)
        result *= (*this)[X];
    return result.value();
}
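
// Example use of the Field reductions above (a sketch; `f` is an assumed Field<double>):
//
//     double s = f.sum();               // allreduce sum over all sites
//     double se = f.sum(EVEN);          // sum over even-parity sites only
//     double p = f.product(ALL, false); // product, result valid on rank 0 only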


// Get the global minimum/maximum of a Field - meant to be used through .min() and .max()

#ifdef OPENMP
#include <omp.h>
#endif

#if defined(CUDA) || defined(HIP)
#include "backend_gpu/gpu_reduction.h"
#endif

template <typename T>
T Field<T>::minmax(bool is_min, Parity par, CoordinateVector &loc) const {

    static_assert(
        std::is_same<T, int>::value || std::is_same<T, long>::value ||
            std::is_same<T, float>::value || std::is_same<T, double>::value ||
            std::is_same<T, long double>::value,
        "In Field .min() and .max() methods the Field element type must be one of "
        "(int/long/float/double/long double)");

#if defined(CUDA) || defined(HIP)
    T val = gpu_minmax(is_min, par, loc);
#else
    int sgn = is_min ? 1 : -1;
    // get a suitable initial value; lowest() (not min()) is the correct
    // "smallest possible" value also for floating point types
    T val = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

// Write the loop with an explicit OpenMP parallel region. This has negligible effect
// on non-OpenMP code, where the pragmas are simply ignored.
#pragma omp parallel shared(val, loc, sgn, is_min)
    {
        CoordinateVector loc_th(0);
        T val_th =
            is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

// Pragma "hila omp_parallel_region" is necessary here, because this is within
// an omp parallel region
#pragma hila novector omp_parallel_region direct_access(loc_th, val_th)
        onsites (par) {
            if (sgn * (*this)[X] < sgn * val_th) {
                val_th = (*this)[X];
                loc_th = X.coordinates();
            }
        }

// combine the per-thread results
#pragma omp critical
        if (sgn * val_th < sgn * val) {
            val = val_th;
            loc = loc_th;
        }
    }
#endif

    if (hila::number_of_nodes() > 1) {
        size_t size;
        // the flag requests the paired (value, rank) MPI type needed by MPI_MINLOC/MAXLOC
        MPI_Datatype dtype = get_MPI_number_type<T>(size, true);

        struct {
            T v;
            int rank;
        } rdata;

        rdata.v = val;
        rdata.rank = hila::myrank();

        // after the allreduce rdata contains the min/max value and the rank where it is
        if (is_min) {
            MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MINLOC,
                          lattice.mpi_comm_lat);
        } else {
            MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MAXLOC,
                          lattice.mpi_comm_lat);
        }
        val = rdata.v;

        // broadcast the CoordinateVector of the min/max location to all nodes
        MPI_Bcast(&loc, sizeof(CoordinateVector), MPI_BYTE, rdata.rank,
                  lattice.mpi_comm_lat);
    }

    return val;
}


/// Find the minimum value of a Field
template <typename T>
T Field<T>::min(Parity par) const {
    CoordinateVector loc;
    return minmax(true, par, loc);
}

/// Find the minimum value and its location in a Field
template <typename T>
T Field<T>::min(CoordinateVector &loc) const {
    return minmax(true, ALL, loc);
}

/// Find the minimum value and its location in a Field, restricted to parity par
template <typename T>
T Field<T>::min(Parity par, CoordinateVector &loc) const {
    return minmax(true, par, loc);
}


/// Find the maximum value of a Field
template <typename T>
T Field<T>::max(Parity par) const {
    CoordinateVector loc;
    return minmax(false, par, loc);
}

/// Find the maximum value and its location in a Field
template <typename T>
T Field<T>::max(CoordinateVector &loc) const {
    return minmax(false, ALL, loc);
}

/// Find the maximum value and its location in a Field, restricted to parity par
template <typename T>
T Field<T>::max(Parity par, CoordinateVector &loc) const {
    return minmax(false, par, loc);
}
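
// Example use of min/max (a sketch; `f` is an assumed Field<double>):
//
//     CoordinateVector c;
//     double lo = f.min();        // global minimum over all sites
//     double hi = f.max(ODD, c);  // maximum over odd sites; c receives its location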



#endif