reductionvector.h
#ifndef HILA_REDUCTIONVECTOR_H_
#define HILA_REDUCTIONVECTOR_H_

#include "hila.h"


#if defined(HILAPP)

// This is a dummy function which forces, in GPU code, the generation of __device__ functions for
// the + and += operators and the constructors. Without this cub::blockReduce gives wrong answers!
// It does not generate any code; only hilapp sees it. The function is formally called from the
// ReductionVector<T> destructor.

template <typename T>
inline void _hila_init_gpu_ops_vectorreduction() {
    Field<T> a;
    onsites(ALL) {
        a[X] = 0;
        T v = 0;
        v += v + a[X];
        a[X] = v + v;
    }
}

#endif

//////////////////////////////////////////////////////////////////////////////////
/// Special reduction class for arrays: declare a reduction array as
///    ReductionVector<T> a(size);
///
/// This can be used within site loops as
///    onsites(ALL) {
///        int i = ...
///        a[i] += ...
///    }
///
/// or outside site loops as usual.
/// Reductions:  +=  *=   (sum or product reduction)
/// NOTE: the size of the reduction vector must be the same on all nodes!
/// By default the reduction is "allreduce", i.e. all nodes get the same result.
///
/// The reduction can be customized by using the methods:
///    a.size()              : return the size of the array
///    a.resize(new_size)    : change the size of the array (all nodes must use the same size!)
///    a[i]                  : get/set element i
///
///    a.allreduce(bool)     : turn allreduce on/off  (default=true)
///    a.nonblocking(bool)   : turn non-blocking reduction on/off
///    a.delayed(bool)       : turn delayed reduction on/off
///
///    a.wait()              : for non-blocking reduction, wait() has to be called
///                            after the loop to complete the reduction
///    a.reduce()            : for delayed reduction, starts/completes the reduction
///
///    a.is_allreduce()      : query the reduction status
///      is_nonblocking()
///      is_delayed()
///
///    a.push_back(element)  : add one element to the end of the vector
/// The same reduction variable can be used again.
///
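/// A minimal usage sketch (the "..." placeholders stand for application-specific code; the
/// index computation and field values are hypothetical):
///
///    ReductionVector<double> a(100);
///    a.nonblocking();              // start the reduction without blocking the node
///    onsites(ALL) {
///        int i = ...               // some site-dependent index in [0, a.size())
///        a[i] += ...
///    }
///    a.wait();                     // complete the non-blocking reduction before reading a[i]
///
/// With a.delayed() several site loops can accumulate into the same vector, and a single
/// a.reduce() call at the end performs the (delayed) MPI reduction.
///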
template <typename T>
class ReductionVector {

  private:
    std::vector<T> val;

    /// comm_is_on is true if MPI communications are under way.
    bool comm_is_on = false;

    /// status variables of reduction
    bool is_allreduce_ = true;
    bool is_nonblocking_ = false;
    bool is_delayed_ = false;

    bool delay_is_on = false;   // status of the delayed reduction
    bool is_delayed_sum = true; // sum/product

    MPI_Request request;

    void reduce_operation(MPI_Op operation) {

        // if for some reason a previous reduction is still unfinished, wait for it
        wait();

        if (is_nonblocking_)
            comm_is_on = true;

        MPI_Datatype dtype;
        dtype = get_MPI_number_type<T>();

        if (dtype == MPI_BYTE) {
            assert(sizeof(T) < 0 && "Unknown number_type in vector reduction");
        }

        reduction_timer.start();
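        // T is reduced as a flat array of its underlying arithmetic type, so the element
        // count passed to MPI below is sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>)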
        if (is_allreduce_) {
            if (is_nonblocking_) {
                MPI_Iallreduce(MPI_IN_PLACE, (void *)val.data(),
                               sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, lattice->mpi_comm_lat, &request);
            } else {
                MPI_Allreduce(MPI_IN_PLACE, (void *)val.data(),
                              sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                              operation, lattice->mpi_comm_lat);
            }
        } else {
            if (hila::myrank() == 0) {
                if (is_nonblocking_) {
                    MPI_Ireduce(MPI_IN_PLACE, (void *)val.data(),
                                sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                                operation, 0, lattice->mpi_comm_lat, &request);
                } else {
                    MPI_Reduce(MPI_IN_PLACE, (void *)val.data(),
                               sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, 0, lattice->mpi_comm_lat);
                }
            } else {
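                // recvbuf is ignored on non-root ranks (and MPI_IN_PLACE is valid only at the
                // root), so the same buffer is passed as both send and receive buffer here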
                if (is_nonblocking_) {
                    MPI_Ireduce((void *)val.data(), (void *)val.data(),
                                sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                                operation, 0, lattice->mpi_comm_lat, &request);
                } else {
                    MPI_Reduce((void *)val.data(), (void *)val.data(),
                               sizeof(T) * val.size() / sizeof(hila::arithmetic_type<T>), dtype,
                               operation, 0, lattice->mpi_comm_lat);
                }
            }
        }
        reduction_timer.stop();
    }

  public:
    // Define iterators using std::vector iterators
    using iterator = typename std::vector<T>::iterator;
    using const_iterator = typename std::vector<T>::const_iterator;

    iterator begin() {
        return val.begin();
    }
    iterator end() {
        return val.end();
    }
    const_iterator begin() const {
        return val.begin();
    }
    const_iterator end() const {
        return val.end();
    }

    /// Initialize to zero by default (? exception to other variables)
    /// allreduce = true by default
    explicit ReductionVector() {}
    explicit ReductionVector(int size) : val(size, (T)0) {}
    explicit ReductionVector(int size, const T &v) : val(size, v) {}

    /// Destructor cleans up communications if they are in progress
    ~ReductionVector() {
        wait();
#if defined(HILAPP)
        _hila_init_gpu_ops_vectorreduction<T>();
#endif
    }

    /// Access operators - in practice these do everything already!
    T &operator[](const int i) {
        return val[i];
    }

    T operator[](const int i) const {
        return val[i];
    }

    /// allreduce(bool) turns allreduce on or off. On by default.
    ReductionVector &allreduce(bool b = true) {
        is_allreduce_ = b;
        return *this;
    }
    bool is_allreduce() {
        return is_allreduce_;
    }

    /// nonblocking(bool) turns non-blocking reduction on or off. Off by default;
    /// calling nonblocking() turns it on.
    ReductionVector &nonblocking(bool b = true) {
        is_nonblocking_ = b;
        return *this;
    }
    bool is_nonblocking() {
        return is_nonblocking_;
    }

    /// delayed(bool) turns delayed reduction on or off. Off by default;
    /// calling delayed() turns it on.
    ReductionVector &delayed(bool b = true) {
        is_delayed_ = b;
        return *this;
    }
    bool is_delayed() {
        return is_delayed_;
    }

    /// Assignment is used only outside site loops - wait for comms if needed.
    /// This returns void; it is hard to imagine the return value being useful.
    template <typename S, std::enable_if_t<std::is_assignable<T &, S>::value, int> = 0>
    void operator=(const S &rhs) {
        for (auto &vp : val)
            vp = rhs;
    }

    /// Assignment from 0
    void operator=(std::nullptr_t np) {
        for (auto &vp : val)
            vp = 0;
    }

    // Don't even implement compound assignments

    /// Init is to be called before every site loop
    void init_sum() {
        // if a reduction is already in progress, wait for it
        wait();
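        // only rank 0 keeps the current contents; other ranks start from zero so that the
        // existing values are not added once per node in the reduction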
        if (hila::myrank() != 0 && !delay_is_on) {
            for (auto &vp : val)
                vp = 0;
        }
    }
    /// Init is to be called before every site loop
    void init_product() {
        wait();
        if (hila::myrank() != 0 && !delay_is_on) {
            for (auto &vp : val)
                vp = 1;
        }
    }

    /// Start sum reduction -- works only if addition for type T is element-wise
    /// addition. This is true for all hila predefined data types.
    void reduce_sum() {

        if (is_delayed_) {
            if (delay_is_on && is_delayed_sum == false) {
                assert(0 && "Cannot mix sum and product reductions!");
            }
            delay_is_on = true;
            is_delayed_sum = true; // remember that the delayed operation is a sum
        } else {
            reduce_operation(MPI_SUM);
        }
    }

    /// Product reduction -- currently works only for scalar data types.
    /// For Complex, Matrix and Vector data the product is not element-wise.
    /// TODO: Array or std::array ?
    /// TODO: implement using custom MPI ops (if needed)
    void reduce_product() {

        static_assert(std::is_same<T, int>::value || std::is_same<T, long>::value ||
                          std::is_same<T, float>::value || std::is_same<T, double>::value ||
                          std::is_same<T, long double>::value,
                      "Type not implemented for product reduction");

        if (is_delayed_) {
            if (delay_is_on && is_delayed_sum == true) {
                assert(0 && "Cannot mix sum and product reductions!");
            }
            delay_is_on = true;
            is_delayed_sum = false; // remember that the delayed operation is a product
        } else {
            reduce_operation(MPI_PROD);
        }
    }

    /// Wait for MPI to complete, if it is currently going on
    void wait() {

        if (comm_is_on) {
            reduction_wait_timer.start();
            MPI_Status status;
            MPI_Wait(&request, &status);
            reduction_wait_timer.stop();
            comm_is_on = false;
        }
    }

    /// For delayed reduction, start_reduce() starts or completes the reduction operation
    void start_reduce() {
        if (delay_is_on) {
            delay_is_on = false;

            if (is_delayed_sum)
                reduce_operation(MPI_SUM);
            else
                reduce_operation(MPI_PROD);
        }
    }

    /// Complete non-blocking or delayed reduction
    void reduce() {
        start_reduce();
        wait();
    }

    /// data() returns ptr to the raw storage
    T *data() {
        return val.data();
    }

    std::vector<T> vector() {
        return val;
    }

    /// methods from std::vector:

    size_t size() const {
        return val.size();
    }

    void resize(size_t count) {
        val.resize(count);
    }
    void resize(size_t count, const T &v) {
        val.resize(count, v);
    }

    void clear() {
        val.clear();
    }

    void push_back(const T &v) {
        val.push_back(v);
    }
    void pop_back() {
        val.pop_back();
    }

    T &front() {
        return val.front();
    }
    T &back() {
        return val.back();
    }
};


#endif