HILA
com_mpi.h
#ifndef COM_MPI_H
#define COM_MPI_H

#include "plumbing/defs.h"

#include "plumbing/lattice.h"

// let us house the partitions-struct here
namespace hila {
class partitions_struct {
  public:
    unsigned _number, _mylattice;
    bool _sync;

    unsigned number() {
        return _number;
    }
    unsigned mylattice() {
        return _mylattice;
    }
    bool sync() {
        return _sync;
    }
};

extern partitions_struct partitions;

} // namespace hila

// Pile of timers associated with MPI calls
// clang-format off
extern hila::timer
    start_send_timer,
    wait_send_timer,
    post_receive_timer,
    wait_receive_timer,
    synchronize_timer,
    reduction_timer,
    reduction_wait_timer,
    broadcast_timer,
    send_timer,
    drop_comms_timer,
    partition_sync_timer;
// clang-format on

///***********************************************************
/// Implementations of communication routines.
///

// The MPI tag generator
int get_next_msg_tag();

/// Obtain the MPI data type (MPI_XXX) corresponding to a native number type.
///
/// @brief Return MPI data type compatible with the native number type
///
/// If the boolean flag "with_int" is set, the corresponding "type + int" MPI pair type
/// (e.g. MPI_DOUBLE_INT) is returned instead, as used in MAXLOC/MINLOC reductions.
///
/// @tparam T hila type; hila::arithmetic_type<T> is converted into the MPI type
/// @param size set to the size of the MPI type in bytes
/// @param with_int if true, return the "type + int" pair type (default false)
/// @return MPI_Datatype (e.g. MPI_INT, MPI_DOUBLE etc.)
template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {

    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED; // MPI does not contain MPI_UNSIGNED_INT
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T; // need to use LONG_INT
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T; // ditto
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    }

    size = 1;
    return MPI_BYTE;
}


template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}
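
/// Example (illustrative sketch, not part of the interface): the returned type follows the
/// arithmetic base type of T, so e.g. Complex<double> maps to MPI_DOUBLE.
/// @code{.cpp}
/// size_t sz;
/// MPI_Datatype dt = get_MPI_number_type<Complex<double>>(sz); // MPI_DOUBLE, sz == sizeof(double)
/// MPI_Datatype dp = get_MPI_number_type<double>(sz, true);    // MPI_DOUBLE_INT pair for MAXLOC/MINLOC
/// @endcode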


/// @brief Return the MPI complex type equivalent to a hila type
///
/// For example, if T is Complex<double>, get_MPI_complex_type<T>() returns MPI_C_DOUBLE_COMPLEX
/// @tparam T type to be converted
/// @param siz size of the type in bytes
/// @return MPI_Datatype
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(Complex<double>);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(Complex<float>);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        static_assert(sizeof(T) > 0,
                      "get_MPI_complex_type<T>() called without T being a complex type");
        return MPI_BYTE;
    }
}
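
/// Example (illustrative sketch): only Complex<float> and Complex<double> are accepted here.
/// @code{.cpp}
/// size_t sz;
/// MPI_Datatype dc = get_MPI_complex_type<Complex<float>>(sz); // MPI_C_FLOAT_COMPLEX, sz == sizeof(Complex<float>)
/// @endcode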


namespace hila {

/**
 * @brief Broadcast the value of _var_ to all MPI ranks from _rank_ (default=0).
 *
 * NOTE: the function must be called by all MPI ranks, otherwise the program will deadlock.
 *
 * The type of the variable _var_ can be any standard plain datatype (trivial type),
 * std::string, std::vector or std::array.
 *
 * For trivial types, the input _var_ can be a non-modifiable value. In this case
 * the broadcast value is obtained from the return value of the function.
 *
 * Example:
 * @code{.cpp}
 * auto rnd = hila::broadcast(hila::random()); // all MPI ranks get the same random value
 * @endcode
 *
 * @param var variable to be synchronized across all MPI ranks
 * @param rank MPI rank from which the value is broadcast
 * @return the broadcast value of var
 */


template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");
    if (hila::check_input)
        return var;

    assert(0 <= rank && rank < hila::number_of_nodes() && "Invalid sender rank in broadcast()");

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}

/// Version of broadcast with non-modifiable var
template <typename T>
T broadcast(const T &var, int rank = 0) {
    T tmp = var;
    return broadcast(tmp, rank);
}

/// Broadcast for std::vector
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    int size = list.size();
    MPI_Bcast(&size, sizeof(int), MPI_BYTE, rank, lattice.mpi_comm_lat);
    if (hila::myrank() != rank) {
        list.resize(size);
    }

    // move vector contents directly to/from the storage
    MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}
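
/// Example (illustrative sketch, assuming an initialized hila program): fill a vector on
/// rank 0 only and broadcast it; the other ranks are resized automatically.
/// @code{.cpp}
/// std::vector<double> v;
/// if (hila::myrank() == 0)
///     v = {1.0, 2.0, 3.0};
/// hila::broadcast(v); // now v == {1.0, 2.0, 3.0} on every rank
/// @endcode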

/// And broadcast for std::array
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    // move the array contents directly to the storage
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}



///
/// Bare pointers cannot be broadcast

template <typename T>
void broadcast(T *var, int rank = 0) {
    static_assert(sizeof(T) > 0 &&
                      "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                      "int size)' to broadcast an array");
}

///
/// Broadcast for arrays where the size must be known and the same on all nodes

template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {

    if (hila::check_input)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}
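
/// Example (illustrative sketch): broadcast a raw C array of known, fixed length.
/// @code{.cpp}
/// double buf[4];
/// if (hila::myrank() == 0)
///     for (int i = 0; i < 4; i++)
///         buf[i] = i;
/// hila::broadcast_array(buf, 4); // every rank now holds the same 4 values
/// @endcode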

// Do string broadcasts separately (declared here, defined out-of-line)
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);

/// Broadcast two values (of possibly different types) in a single call
template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {

    if (hila::check_input)
        return;

    struct {
        T tv;
        U uv;
    } s = {t, u};

    hila::broadcast(s, rank);
    t = s.tv;
    u = s.uv;
}
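
/// Example (illustrative sketch): synchronize two related scalars with a single MPI call.
/// @code{.cpp}
/// int steps = 0;
/// double beta = 0.0;
/// if (hila::myrank() == 0) {
///     steps = 100;
///     beta = 5.2;
/// }
/// hila::broadcast2(steps, beta); // both values are now set on all ranks
/// @endcode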


template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}

template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);

    MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);

    MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}
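
/// Example (illustrative sketch): point-to-point transfer from rank 0 to rank 1. The MPI tag
/// is the sender's rank, so receive_from() must name the same rank that called send_to().
/// @code{.cpp}
/// std::vector<int> payload;
/// if (hila::myrank() == 0) {
///     payload = {1, 2, 3};
///     hila::send_to(1, payload);
/// } else if (hila::myrank() == 1) {
///     hila::receive_from(0, payload); // resized and filled on rank 1
/// }
/// @endcode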


///
/// Reduce an array across nodes

template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {

    if (hila::check_input)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;
    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}

///
/// Reduce a single variable across nodes.
/// Slightly suboptimal, because it allocates a std::vector internally.

template <typename T>
T reduce_node_sum(T &var, bool allreduce = true) {
    hila::reduce_node_sum(&var, 1, allreduce);
    return var;
}
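
/// Example (illustrative sketch; compute_local_part() is a hypothetical per-rank helper):
/// sum a per-rank partial result over all MPI ranks.
/// @code{.cpp}
/// double local_sum = compute_local_part();         // hypothetical per-rank value
/// double total = hila::reduce_node_sum(local_sum); // allreduce: same total on every rank
/// @endcode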

// Product reduction template - so far only for int, float, double

template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {
    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;

    if (hila::check_input)
        return;

    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}

template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}
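
/// Example (illustrative sketch): multiply per-rank factors into a global product.
/// @code{.cpp}
/// double local_factor = 0.99; // hypothetical per-rank value
/// double global_factor = hila::reduce_node_product(local_factor); // product over all ranks
/// @endcode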

void set_allreduce(bool on = true);
bool get_allreduce();


} // namespace hila


void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();


template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        hila::reduce_node_sum(value, 1, hila::get_allreduce());
    }
}


#endif // COM_MPI_H