HILA
com_mpi.h
#ifndef HILA_COM_MPI_H_
#define HILA_COM_MPI_H_

#include "plumbing/defs.h"

#include "plumbing/lattice.h"


// let us house the partitions struct here
namespace hila {
class partitions_struct {
  private:
    unsigned _number, _mylattice;
    bool _sync;

  public:
    unsigned number() const {
        return _number;
    }

    void set_number(const unsigned u) {
        _number = u;
    }

    unsigned mylattice() const {
        return _mylattice;
    }

    void set_mylattice(const unsigned l) {
        _mylattice = l;
    }

    bool sync() const {
        return _sync;
    }

    void set_sync(bool s) {
        _sync = s;
    }
};

extern partitions_struct partitions;

} // namespace hila

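// Illustrative example (not part of the interface): the partition layout can be
// queried from application code through the hila::partitions object, e.g.
//
//     if (hila::partitions.number() > 1) {
//         unsigned my_part = hila::partitions.mylattice();
//         // ... partition-dependent work ...
//     }
//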
// Pile of timers associated with MPI calls
// clang-format off
extern hila::timer
    start_send_timer,
    wait_send_timer,
    post_receive_timer,
    wait_receive_timer,
    synchronize_timer,
    reduction_timer,
    reduction_wait_timer,
    broadcast_timer,
    send_timer,
    drop_comms_timer,
    partition_sync_timer;
// clang-format on

///***********************************************************
/// Implementations of communication routines.
///

// The MPI tag generator
int get_next_msg_tag();

/// Obtain the MPI data type (MPI_XXX) corresponding to a native number type.
///
/// @brief Return MPI data type compatible with native number type
///
/// The boolean flag "with_int" is used to return the corresponding "type + int"
/// MPI datatypes, which are used in maxloc/minloc reductions.
///
/// @tparam T hila type; hila::arithmetic_type<T> is converted into the MPI type
/// @param size (output) size of the MPI type in bytes
/// @param with_int if true, return the paired "type + int" datatype (e.g. MPI_DOUBLE_INT)
/// @return MPI_Datatype (e.g. MPI_INT, MPI_DOUBLE etc.)
template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {

    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED; // MPI does not contain MPI_UNSIGNED_INT
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T; // need to use LONG_INT
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T; // ditto
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    }

    size = 1;
    return MPI_BYTE;
}


/// Overload without the size argument
template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}

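// Example (illustrative sketch): using the type map when posting a custom MPI
// call; "local" is a hypothetical per-rank buffer, and lattice->mpi_comm_lat is
// the lattice communicator used throughout this file.
//
//     double local[8] = {0};
//     MPI_Datatype dt = get_MPI_number_type<double>();   // MPI_DOUBLE
//     MPI_Allreduce(MPI_IN_PLACE, local, 8, dt, MPI_SUM, lattice->mpi_comm_lat);
//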
/// @brief Return the MPI complex type equivalent to a hila type
///
/// For example, if T is Complex<double>, get_MPI_complex_type<T>() returns MPI_C_DOUBLE_COMPLEX
/// @tparam T type to be converted
/// @param siz (output) size of the type in bytes
/// @return MPI_Datatype
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(Complex<double>);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(Complex<float>);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        // dependent false: trigger a compile error if this branch is instantiated
        static_assert(sizeof(T) < 0,
                      "get_MPI_complex_type<T>() called without T being a complex type");
        return MPI_BYTE;
    }
}


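// Example (illustrative sketch): obtaining the complex MPI datatype for a
// hand-written reduction; "csum" is a hypothetical local partial sum.
//
//     Complex<double> csum(1.0, 0.0);
//     size_t csize;
//     MPI_Datatype cdt = get_MPI_complex_type<Complex<double>>(csize);
//     MPI_Allreduce(MPI_IN_PLACE, &csum, 1, cdt, MPI_SUM, lattice->mpi_comm_lat);
//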
namespace hila {

/**
 * @brief Broadcast the value of _var_ to all MPI ranks from _rank_ (default=0).
 *
 * NOTE: the function must be called by all MPI ranks, otherwise the program will deadlock.
 *
 * The type of the variable _var_ can be any standard plain datatype (trivial type),
 * std::string, std::vector or std::array.
 *
 * For trivial types, the input _var_ can be a non-modifiable value. In this case
 * the broadcast value is obtained from the return value of broadcast().
 *
 * Example:
 * @code{.cpp}
 * auto rnd = hila::broadcast(hila::random()); // all MPI ranks get the same random value
 * @endcode
 *
 * @param var variable to be synchronized across all MPI ranks
 * @param rank MPI rank from which the value is broadcast
 * @return the broadcast value
 */


template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");
    if (hila::check_input)
        return var;

    assert(0 <= rank && rank < hila::number_of_nodes() && "Invalid sender rank in broadcast()");

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}

/// Version of broadcast with non-modifiable var
template <typename T>
T broadcast(const T &var, int rank = 0) {
    T tmp = var;
    return broadcast(tmp, rank);
}

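// Example (illustrative sketch): broadcasting a trivially copyable parameter
// struct from rank 0; "RunParams" is a hypothetical user type.
//
//     struct RunParams {
//         int n_iterations;
//         double beta;
//     };
//
//     RunParams p;
//     if (hila::myrank() == 0) {
//         p.n_iterations = 100;
//         p.beta = 5.6;
//     }
//     hila::broadcast(p);   // every rank now holds the rank-0 values
//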
/// Broadcast for std::vector
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    size_t size = list.size();
    MPI_Bcast(&size, sizeof(size_t), MPI_BYTE, rank, lattice->mpi_comm_lat);
    if (hila::myrank() != rank) {
        list.resize(size);
    }

    if (size > 0) {
        // move vectors directly to the storage
        MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice->mpi_comm_lat);
    }

    broadcast_timer.stop();
}

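// Example (illustrative sketch): only the sending rank needs to fill the
// vector; receiving ranks are resized automatically.
//
//     std::vector<double> v;
//     if (hila::myrank() == 0)
//         v = {1.0, 2.0, 3.0};
//     hila::broadcast(v);   // now v == {1.0, 2.0, 3.0} on every rank
//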
/// And broadcast for std::array
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");

    if (hila::check_input || n <= 0)
        return;

    broadcast_timer.start();

    // broadcast the array contents directly from/to the storage
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice->mpi_comm_lat);

    broadcast_timer.stop();
}


///
/// Bare pointers cannot be broadcast

template <typename T>
void broadcast(T *var, int rank = 0) {
    // dependent false: this overload must never be instantiated
    static_assert(sizeof(T) < 0,
                  "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                  "int size)' to broadcast an array");
}

///
/// Broadcast for arrays where the size must be known and the same on all nodes

template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {

    if (hila::check_input || n <= 0)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
}

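// Example (illustrative sketch): every rank must pass the same array length.
//
//     double coeffs[4];
//     if (hila::myrank() == 0) {
//         coeffs[0] = 0.5; coeffs[1] = 1.0; coeffs[2] = 2.0; coeffs[3] = 4.0;
//     }
//     hila::broadcast_array(coeffs, 4);   // all ranks receive the rank-0 values
//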
// Do the string broadcasts separately
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);

/// Broadcast two values of (possibly) different types in a single message
template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {

    if (hila::check_input)
        return;

    struct {
        T tv;
        U uv;
    } s = {t, u};

    hila::broadcast(s, rank);
    t = s.tv;
    u = s.uv;
}


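// Example (illustrative sketch): broadcast an int and a double with one call.
//
//     int steps = 0;
//     double dt = 0.0;
//     if (hila::myrank() == 0) {
//         steps = 1000;
//         dt = 0.01;
//     }
//     hila::broadcast2(steps, dt);   // both values arrive on every rank
//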

/// Send a value of type T (as raw bytes) to MPI rank "to_rank"; the message tag is the sender's rank
template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice->mpi_comm_lat);
    send_timer.stop();
}

/// Receive a value of type T (as raw bytes) from MPI rank "from_rank"
template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}

/// Send a std::vector to MPI rank "to_rank"; the size is transmitted first
template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice->mpi_comm_lat);

    if (s > 0)
        MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(),
                 lattice->mpi_comm_lat);
    send_timer.stop();
}

/// Receive a std::vector from MPI rank "from_rank"; the vector is resized to the sent size
template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);

    if (s > 0)
        MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
                 MPI_STATUS_IGNORE);
    send_timer.stop();
}


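// Example (illustrative sketch): rank 0 sends a vector to rank 1. The sender
// and the receiver must both make the matching call, otherwise the program hangs.
//
//     std::vector<int> payload;
//     if (hila::myrank() == 0) {
//         payload = {1, 2, 3, 4};
//         hila::send_to(1, payload);
//     } else if (hila::myrank() == 1) {
//         hila::receive_from(0, payload);   // payload is resized and filled
//     }
//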
///
/// Reduce (sum) an array across nodes

template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {

    if (hila::check_input || send_count == 0)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;
    dtype = get_MPI_number_type<T>();
    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice->mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}

///
/// Reduce a single variable across nodes.
/// A bit suboptimal, because it goes through a temporary std::vector.

template <typename T>
T reduce_node_sum(T &var, bool allreduce = true) {
    hila::reduce_node_sum(&var, 1, allreduce);
    return var;
}

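// Example (illustrative sketch): sum a per-rank value over all nodes. With the
// default allreduce=true every rank receives the total; the argument is also
// updated in place.
//
//     double local_energy = 0.5 * hila::myrank();   // some per-rank number
//     double total = hila::reduce_node_sum(local_energy);
//     // total is the same on all ranks
//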
// Product reduction template - so far implemented only for int, float and double

template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {

    if (hila::check_input)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice->mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}

template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}

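// Example (illustrative sketch): multiply per-rank factors together.
//
//     double factor = 1.0 + 0.1 * hila::myrank();
//     double prod = hila::reduce_node_product(factor);   // product over all ranks
//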
/// Set allreduce on (default) or off on the next reduction
void set_allreduce(bool on = true);
/// Query the current allreduce setting
bool get_allreduce();

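// Illustrative note: set_allreduce(false) makes the next reduction deliver its
// result to rank 0 only (hila_reduce_sum_setup() below reads the flag via
// get_allreduce()); the default is a full allreduce.
//
//     hila::set_allreduce(false);
//     // ... perform the reduction; the summed value is valid on rank 0 only
//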

} // namespace hila


void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();
void reduce_node_sum_extended(ExtendedPrecision *value, int send_count, bool allreduce = true);

extern MPI_Datatype MPI_ExtendedPrecision_type;
extern MPI_Op MPI_ExtendedPrecision_sum_op;

void create_extended_MPI_type();
void create_extended_MPI_operation();
void extended_sum_op(void *in, void *inout, int *len, MPI_Datatype *datatype);


template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if constexpr (std::is_same<b_t, ExtendedPrecision>::value) {
        reduce_node_sum_extended((ExtendedPrecision *)value, sizeof(T) / sizeof(ExtendedPrecision),
                                 hila::get_allreduce());
    } else if constexpr (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if constexpr (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        hila::reduce_node_sum(value, 1, hila::get_allreduce());
    }
}

#endif // HILA_COM_MPI_H_