reduction.h
#ifndef REDUCTION_H_
#define REDUCTION_H_

#include "hila.h"


//////////////////////////////////////////////////////////////////////////////////
/// @brief Special reduction class: enables delayed and non-blocking reductions, which
/// are not possible with the standard reduction. See the [user guide](@ref reduction_type_guide)
/// for details.
///
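/// A minimal usage sketch (illustrative only, not from the user guide; assumes an
/// initialized lattice and an existing Field<double> f): with delayed() the sum is
/// accumulated over several loops and the MPI reduction happens only once, when
/// value() is called.
/// @code
/// Reduction<double> r;           // initialized to zero, allreduce on by default
/// r.delayed();
/// onsites(ALL) r += f[X];        // accumulates node-locally, no MPI communication yet
/// onsites(ALL) r += f[X] * f[X];
/// double total = r.value();      // the single (delayed) MPI reduction happens here
/// @endcode
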
template <typename T>
class Reduction {

  private:
    // value holder
    T val;

    // comm_is_on is true if non-blocking MPI communications are under way.
    bool comm_is_on = false;

    // Reduction status: is this allreduce, nonblocking, or delayed
    bool is_allreduce_ = true;
    bool is_nonblocking_ = false;
    bool is_delayed_ = false;

    bool delay_is_on = false;   // status of the delayed reduction
    bool is_delayed_sum = true; // sum/product

    MPI_Request request;

    // start the actual reduction
    void do_reduce_operation(MPI_Op operation) {

        // if for some reason a reduction is going on unfinished, wait
        wait();

        if (is_nonblocking())
            comm_is_on = true;

        MPI_Datatype dtype;

        dtype = get_MPI_number_type<T>();

        assert(dtype != MPI_BYTE && "Unknown number_type in reduction");

        void *ptr = &val;

        reduction_timer.start();
        if (is_allreduce()) {
            if (is_nonblocking()) {
                MPI_Iallreduce(MPI_IN_PLACE, ptr,
                               sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, lattice.mpi_comm_lat, &request);
            } else {
                MPI_Allreduce(MPI_IN_PLACE, ptr,
                              sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                              operation, lattice.mpi_comm_lat);
            }
        } else {
            if (hila::myrank() == 0) {
                if (is_nonblocking()) {
                    MPI_Ireduce(MPI_IN_PLACE, ptr,
                                sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                                operation, 0, lattice.mpi_comm_lat, &request);
                } else {
                    MPI_Reduce(MPI_IN_PLACE, ptr,
                               sizeof(T) / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, 0, lattice.mpi_comm_lat);
                }
            } else {
                if (is_nonblocking()) {
                    MPI_Ireduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                                dtype, operation, 0, lattice.mpi_comm_lat, &request);
                } else {
                    MPI_Reduce(ptr, ptr, sizeof(T) / sizeof(hila::arithmetic_type<T>),
                               dtype, operation, 0, lattice.mpi_comm_lat);
                }
            }
        }
        reduction_timer.stop();
    }

    /// Wait for MPI to complete, if it is currently going on.
    /// This must be called for non-blocking reduce before use!
    void wait() {
        if (comm_is_on) {
            reduction_wait_timer.start();
            MPI_Status status;
            MPI_Wait(&request, &status);
            reduction_wait_timer.stop();
            comm_is_on = false;
        }
    }

  public:
    /// Initialize to zero by default (an exception compared to other variables)
    /// allreduce = true by default
    Reduction() {
        val = 0;
        comm_is_on = false;
    }

    /// Initialize to value only on rank 0 - ensures the expected result for delayed
    /// reduction
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    Reduction(const S &v) {
        if (hila::myrank() == 0) {
            val = v;
        } else {
            val = 0;
        }
        comm_is_on = false;
    }

    /// Copy construction from another Reduction is deleted, so that Reduction
    /// temporaries cannot be made accidentally
    Reduction(const Reduction<T> &r) = delete;

    /// Destructor cleans up communications if they are in progress
    ~Reduction() {
        if (comm_is_on) {
            MPI_Cancel(&request);
        }
    }

    /// allreduce(bool) turns allreduce on or off. On by default.
    Reduction &allreduce(bool b = true) {
        is_allreduce_ = b;
        return *this;
    }
    bool is_allreduce() {
        return is_allreduce_;
    }

    /// nonblocking(bool) turns non-blocking reduction on or off. Off by default.
    Reduction &nonblocking(bool b = true) {
        is_nonblocking_ = b;
        return *this;
    }
    bool is_nonblocking() {
        return is_nonblocking_;
    }

    /// delayed(bool) turns delayed reduction on or off. Off by default.
    Reduction &delayed(bool b = true) {
        is_delayed_ = b;
        return *this;
    }
    bool is_delayed() {
        return is_delayed_;
    }

    /// Return value of the reduction variable. Wait for the comms if needed.
    const T value() {
        reduce();
        return val;
    }


    /// Method set is the same as assignment, but without return value
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    void set(const S &rhs) {
        *this = rhs;
    }

    /// Assignment is used only outside site loops - drop comms if on, no need to wait
    template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
    T operator=(const S &rhs) {
        if (comm_is_on)
            MPI_Cancel(&request);

        comm_is_on = false;
        T ret = rhs;
        if (hila::myrank() == 0) {
            val = ret;
        } else {
            val = 0;
        }
        // all ranks return the same value
        return ret;
    }
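
    // Note (illustrative): after `r = 1;` rank 0 stores 1 and the other ranks store 0,
    // so a subsequent sum reduction counts the initial value exactly once.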

    /// Compound operator += is used in reduction but can be used outside onsites too.
    /// Returns void, unconventionally: these are used within site loops but their
    /// value is not well-defined. Thus a = (r += b); is disallowed.
    template <typename S,
              std::enable_if_t<hila::is_assignable<T &, hila::type_plus<T, S>>::value,
                               int> = 0>
    void operator+=(const S &rhs) {
        val += rhs;
    }

    // template <typename S,
    //           std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value,
    //                            int> = 0>
    // void operator*=(const S &rhs) {
    //     val *= rhs;
    // }

    // template <typename S,
    //           std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value,
    //                            int> = 0>
    // void operator/=(const S &rhs) {
    //     val /= rhs;
    // }

    // Start sum reduction -- works only if addition for type T is element-wise
    // addition. This is true for all hila predefined data types.
    void reduce_sum_node(const T &v) {

        wait(); // wait for a possibly ongoing reduction

        // add the node values to the reduction var
        if (hila::myrank() == 0 || is_delayed_)
            val += v;
        else
            val = v;

        if (is_delayed_) {
            if (delay_is_on && is_delayed_sum == false) {
                assert(0 && "Cannot mix sum and product reductions!");
            }
            delay_is_on = true;
            is_delayed_sum = true;
        } else {
            do_reduce_operation(MPI_SUM);
        }
    }

    /// Product reduction -- currently works only for scalar data types.
    /// For Complex, Matrix and Vector data the product is not element-wise.
    /// TODO: Array or std::array ?
    /// TODO: implement using custom MPI ops (if needed)
    // void reduce_product_node(const T &v) {

    //     reduce();

    //     if (hila::myrank() == 0 || is_delayed_)
    //         val *= v;
    //     else
    //         val = v;

    //     if (is_delayed_) {
    //         if (delay_is_on && is_delayed_sum == true) {
    //             assert(0 && "Cannot mix sum and product reductions!");
    //         }
    //         delay_is_on = true;
    //         is_delayed_sum = false;
    //     } else {
    //         do_reduce_operation(MPI_PROD);
    //     }
    // }


    /// For delayed reduction, start_reduce starts or completes the reduction operation
    void start_reduce() {
        if (!comm_is_on) {
            if (delay_is_on) {
                delay_is_on = false;

                if (is_delayed_sum)
                    do_reduce_operation(MPI_SUM);
                else
                    do_reduce_operation(MPI_PROD);
            }
        }
    }

    /// Complete the reduction - start if not done, and wait if ongoing
    void reduce() {
        start_reduce();
        wait();
    }

};
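
// A sketch of non-blocking use (illustrative; do_other_work() stands for any
// independent computation, and Field<double> f is assumed to exist): the reduction
// starts when the onsites loop finishes and completes only in value(), so
// communication can overlap with computation.
//
//   Reduction<double> r;
//   r.nonblocking();
//   onsites(ALL) r += f[X];    // launches MPI_Iallreduce on loop exit
//   do_other_work();           // overlap independent work with the communication
//   double total = r.value();  // waits for the reduction to finish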


////////////////////////////////////////////////////////////////////////////////////

// #if defined(CUDA) || defined(HIP)
// #include "backend_gpu/gpu_reduction.h"

// template <typename T>
// T Field<T>::sum(Parity par, bool allreduce) const {
//     return gpu_reduce_sum(allreduce, par, false);
// }

// #else
// // This is for the non-gpu branch

template <typename T>
T Field<T>::sum(Parity par, bool allreduce) const {

    Reduction<T> result;
    result.allreduce(allreduce);
    onsites (par)
        result += (*this)[X];
    return result.value();
}

template <typename T>
T Field<T>::product(Parity par, bool allreduce) const {
    static_assert(std::is_arithmetic<T>::value,
                  ".product() reduction only for integer or floating point types");
    Reduction<T> result;
    result = 1;
    result.allreduce(allreduce);
    onsites (par)
        result *= (*this)[X];
    return result.value();
}
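
// Usage sketch (illustrative; assumes an initialized lattice and a Field<double> f):
//
//   double s = f.sum();             // allreduced sum over all sites
//   double se = f.sum(EVEN, false); // sum over even sites, result valid on rank 0 only
//   double p = f.product();         // product, arithmetic element types only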


// get global minimum/maximum - meant to be used through .min() and .max()

#ifdef OPENMP
#include <omp.h>
#endif

#if defined(CUDA) || defined(HIP)
#include "backend_gpu/gpu_reduction.h"
#endif

template <typename T>
T Field<T>::minmax(bool is_min, Parity par, CoordinateVector &loc) const {

    static_assert(
        std::is_same<T, int>::value || std::is_same<T, long>::value ||
            std::is_same<T, float>::value || std::is_same<T, double>::value ||
            std::is_same<T, long double>::value,
        "In Field .min() and .max() methods the Field element type must be one of "
        "(int/long/float/double/long double)");

#if defined(CUDA) || defined(HIP)
    T val = gpu_minmax(is_min, par, loc);
#else
    int sgn = is_min ? 1 : -1;
    // get a suitable initial value; lowest() is the most negative finite value
    // also for floating point types, unlike min()
    T val = is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

// write the loop with an explicit OpenMP parallel region. It has negligible effect
// on non-OpenMP code, where the pragmas are ignored.
#pragma omp parallel shared(val, loc, sgn, is_min)
    {
        CoordinateVector loc_th(0);
        T val_th =
            is_min ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

// Pragma "hila omp_parallel_region" is necessary here, because this is within an
// omp parallel region
#pragma hila novector omp_parallel_region direct_access(loc_th, val_th)
        onsites (par) {
            if (sgn * (*this)[X] < sgn * val_th) {
                val_th = (*this)[X];
                loc_th = X.coordinates();
            }
        }

#pragma omp critical
        if (sgn * val_th < sgn * val) {
            val = val_th;
            loc = loc_th;
        }
    }
#endif

    if (hila::number_of_nodes() > 1) {
        size_t size;
        MPI_Datatype dtype = get_MPI_number_type<T>(size, true);

        struct {
            T v;
            int rank;
        } rdata;

        rdata.v = val;
        rdata.rank = hila::myrank();

        // after the allreduce, rdata contains the min/max value and the rank where it is
        if (is_min) {
            MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MINLOC,
                          lattice.mpi_comm_lat);
        } else {
            MPI_Allreduce(MPI_IN_PLACE, &rdata, 1, dtype, MPI_MAXLOC,
                          lattice.mpi_comm_lat);
        }
        val = rdata.v;

        // broadcast the CoordinateVector of the min/max location to all nodes
        MPI_Bcast(&loc, sizeof(CoordinateVector), MPI_BYTE, rdata.rank,
                  lattice.mpi_comm_lat);
    }

    return val;
}


/// Find minimum value from Field
template <typename T>
T Field<T>::min(Parity par) const {
    CoordinateVector loc;
    return minmax(true, par, loc);
}

/// Find minimum value and location from Field
template <typename T>
T Field<T>::min(CoordinateVector &loc) const {
    return minmax(true, ALL, loc);
}

/// Find minimum value and location from Field
template <typename T>
T Field<T>::min(Parity par, CoordinateVector &loc) const {
    return minmax(true, par, loc);
}


/// Find maximum value from Field
template <typename T>
T Field<T>::max(Parity par) const {
    CoordinateVector loc;
    return minmax(false, par, loc);
}

/// Find maximum value and location from Field
template <typename T>
T Field<T>::max(CoordinateVector &loc) const {
    return minmax(false, ALL, loc);
}

/// Find maximum value and location from Field
template <typename T>
T Field<T>::max(Parity par, CoordinateVector &loc) const {
    return minmax(false, par, loc);
}
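
// Usage sketch (illustrative; assumes an initialized lattice and a Field<double> f):
//
//   CoordinateVector loc;
//   double lo = f.min(loc);   // global minimum and the coordinates where it occurs
//   double hi = f.max(ODD);   // maximum over odd-parity sites only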



#endif