HILA
com_mpi.h
#ifndef COM_MPI_H
#define COM_MPI_H

#include "plumbing/defs.h"

#include "plumbing/lattice.h"

// let us house the partitions-struct here
namespace hila {
class partitions_struct {
  private:
    unsigned _number, _mylattice;
    bool _sync;

  public:

    unsigned number() const {
        return _number;
    }

    void set_number(const unsigned u) {
        _number = u;
    }

    unsigned mylattice() const {
        return _mylattice;
    }

    void set_mylattice(const unsigned l) {
        _mylattice = l;
    }

    bool sync() const {
        return _sync;
    }

    void set_sync(bool s) {
        _sync = s;
    }
};

extern partitions_struct partitions;

} // namespace hila

// Pile of timers associated with MPI calls
// clang-format off
extern hila::timer
    start_send_timer,
    wait_send_timer,
    post_receive_timer,
    wait_receive_timer,
    synchronize_timer,
    reduction_timer,
    reduction_wait_timer,
    broadcast_timer,
    send_timer,
    drop_comms_timer,
    partition_sync_timer;
// clang-format on

///***********************************************************
/// Implementations of communication routines.
///

// The MPI tag generator
int get_next_msg_tag();

/// Obtain the MPI data type (MPI_XXX) matching a particular native number type.
///
/// @brief Return MPI data type compatible with native number type
///
/// If the boolean flag "with_int" is true, the corresponding "type + int" MPI pair type
/// is returned instead, as used in maxloc/minloc reductions.
///
/// @tparam T hila type; hila::arithmetic_type<T> is converted into the MPI type
/// @param size (output) size of the MPI type in bytes
/// @param with_int if true, return the "type + int" pair type (default false)
/// @return MPI datatype (e.g. MPI_INT, MPI_DOUBLE etc.)
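///
/// A usage sketch (Complex<double> here is just an illustration of a compound hila type):
/// @code{.cpp}
/// size_t elem_size;
/// MPI_Datatype dt = get_MPI_number_type<Complex<double>>(elem_size);
/// // dt == MPI_DOUBLE, elem_size == sizeof(double)
/// MPI_Datatype dti = get_MPI_number_type<double>(elem_size, true);
/// // dti == MPI_DOUBLE_INT, usable in MPI_MAXLOC / MPI_MINLOC reductions
/// @endcode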
template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {

    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED; // MPI does not contain MPI_UNSIGNED_INT
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T; // need to use LONG_INT
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T; // ditto
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    }

    size = 1;
    return MPI_BYTE;
}


template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}


/// @brief Return MPI complex type equivalent to hila type
///
/// For example, if T is Complex<double>, get_MPI_complex_type<T>() returns MPI_C_DOUBLE_COMPLEX
/// @tparam T type to be converted
/// @param siz (output) size of the type in bytes
/// @return MPI_Datatype
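///
/// A usage sketch:
/// @code{.cpp}
/// size_t csize;
/// MPI_Datatype ct = get_MPI_complex_type<Complex<float>>(csize);
/// // ct == MPI_C_FLOAT_COMPLEX, csize == sizeof(Complex<float>)
/// @endcode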
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(Complex<double>);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(Complex<float>);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        // dependent-false condition, so the assert fires only if this branch is instantiated
        static_assert(sizeof(T) < 0,
                      "get_MPI_complex_type<T>() called without T being a complex type");
        return MPI_BYTE;
    }
}


namespace hila {

/**
 * @brief Broadcast the value of _var_ to all MPI ranks from _rank_ (default=0).
 *
 * NOTE: the function must be called by all MPI ranks, otherwise the program will deadlock.
 *
 * The type of the variable _var_ can be any standard plain datatype (trivial type),
 * std::string, std::vector or std::array.
 *
 * For trivial types, the input _var_ can be a non-modifiable value. In this case
 * the broadcast value is obtained from the return value of the function.
 *
 * Example:
 * @code{.cpp}
 * auto rnd = hila::broadcast(hila::random()); // all MPI ranks get the same random value
 * @endcode
 *
 * @param var variable to be synchronized across all MPI ranks
 * @param rank MPI rank from which the value is broadcast (default 0)
 * @return the broadcast value of var
 */
template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");
    if (hila::check_input)
        return var;

    assert(0 <= rank && rank < hila::number_of_nodes() && "Invalid sender rank in broadcast()");

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}

/// Version of broadcast with non-modifiable var
template <typename T>
T broadcast(const T &var, int rank = 0) {
    T tmp = var;
    return broadcast(tmp, rank);
}

/// Broadcast for std::vector
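///
/// Only the sending rank needs to have the vector sized and filled beforehand; the other
/// ranks resize their vector to the broadcast size. A usage sketch:
/// @code{.cpp}
/// std::vector<double> v;
/// if (hila::myrank() == 0)
///     v = {1.0, 2.0, 3.0};
/// hila::broadcast(v);   // now v == {1.0, 2.0, 3.0} on every rank
/// @endcode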
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    int size = list.size();
    MPI_Bcast(&size, sizeof(int), MPI_BYTE, rank, lattice.mpi_comm_lat);
    if (hila::myrank() != rank) {
        list.resize(size);
    }

    // broadcast the vector contents directly from/to the storage
    MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}

/// And broadcast for std::array
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    // broadcast the array contents directly from/to the storage
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}


///
/// Bare pointers cannot be broadcast

template <typename T>
void broadcast(T *var, int rank = 0) {
    // dependent-false condition: this overload exists only to give a compile-time error
    static_assert(sizeof(T) < 0,
                  "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                  "int size)' to broadcast an array");
}

///
/// Broadcast for arrays where the size must be known and the same on all nodes
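///
/// A usage sketch:
/// @code{.cpp}
/// double arr[3];
/// if (hila::myrank() == 0) {
///     arr[0] = 1;
///     arr[1] = 2;
///     arr[2] = 3;
/// }
/// hila::broadcast_array(arr, 3);   // all ranks now hold {1, 2, 3}
/// @endcode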

template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {

    if (hila::check_input)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}

// Do string broadcasts separately
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);

/// Broadcast two values of (possibly) different types in a single MPI call
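///
/// The values are packed into one struct and broadcast together. A usage sketch:
/// @code{.cpp}
/// int n;
/// double x;
/// if (hila::myrank() == 0) {
///     n = 42;
///     x = 3.14;
/// }
/// hila::broadcast2(n, x);   // both n and x now hold rank 0's values on every rank
/// @endcode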
template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {

    if (hila::check_input)
        return;

    struct {
        T tv;
        U uv;
    } s = {t, u};

    hila::broadcast(s, rank);
    t = s.tv;
    u = s.uv;
}


/// Send a single value to MPI rank to_rank; must be paired with receive_from() on the receiver
template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

/// Receive a single value sent by send_to() from MPI rank from_rank
template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}

/// Send a std::vector to rank to_rank; the size is sent first, so the receiver need not know it
template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);

    MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

/// Receive a std::vector sent by send_to(); the vector is resized to the incoming size
template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);

    MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}

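// A point-to-point usage sketch: send_to() and receive_from() are blocking calls and must be
// paired between the two ranks involved:
//
//     std::vector<int> v;
//     if (hila::myrank() == 0) {
//         v = {1, 2, 3};
//         hila::send_to(1, v);        // rank 0 sends to rank 1
//     } else if (hila::myrank() == 1) {
//         hila::receive_from(0, v);   // rank 1 receives; v is resized automatically
//     }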

///
/// Reduce an array across nodes

template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {

    if (hila::check_input)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;
    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}

///
/// Reduce a single variable across nodes.
/// A bit suboptimal, because it goes through a std::vector

template <typename T>
T reduce_node_sum(T &var, bool allreduce = true) {
    hila::reduce_node_sum(&var, 1, allreduce);
    return var;
}
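
// A reduction usage sketch: every rank contributes its local value, and with allreduce
// (the default) every rank receives the summed result:
//
//     double local = some_local_measurement();      // hypothetical per-rank value
//     double total = hila::reduce_node_sum(local);  // sum of local over all ranks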

// Product reduction template - so far only for int, float, double

template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {
    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;

    if (hila::check_input)
        return;

    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}

template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}
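
// A product-reduction sketch (element types limited to plain int/float/double as noted above):
//
//     double f = 1.0 + hila::myrank();
//     double prod = hila::reduce_node_product(f);   // product of f over all ranks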

/// set allreduce on (default) or off on the next reduction
void set_allreduce(bool on = true);
bool get_allreduce();


} // namespace hila


void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();


template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        hila::reduce_node_sum(value, 1, hila::get_allreduce());
    }
}


#endif // COM_MPI_H