com_mpi.h
#ifndef COM_MPI_H
#define COM_MPI_H

#include "plumbing/defs.h"

#include "plumbing/lattice.h"

// let us house the partitions-struct here
namespace hila {
class partitions_struct {
  public:
    unsigned _number, _mylattice;
    bool _sync;

    unsigned number() {
        return _number;
    }
    unsigned mylattice() {
        return _mylattice;
    }
    bool sync() {
        return _sync;
    }
};

extern partitions_struct partitions;

} // namespace hila

// Pile of timers associated with MPI calls
// clang-format off
extern hila::timer
    start_send_timer,
    wait_send_timer,
    post_receive_timer,
    wait_receive_timer,
    synchronize_timer,
    reduction_timer,
    reduction_wait_timer,
    broadcast_timer,
    send_timer,
    cancel_send_timer,
    cancel_receive_timer,
    partition_sync_timer;
// clang-format on

///***********************************************************
/// Implementations of communication routines.
///

// The MPI tag generator
int get_next_msg_tag();

/// Obtain the MPI data type (MPI_XXX) for a particular native number type.
///
/// @brief Return MPI data type compatible with native number type
///
/// The boolean flag "with_int" requests the corresponding "type + int" MPI datatype,
/// as used in maxloc/minloc reductions.
///
/// @tparam T hila type, hila::arithmetic_type<T> is converted into the MPI type
/// @param size (output) size of the MPI type in bytes
/// @param with_int if true, return the "type + int" pair datatype
/// @return MPI_Datatype (e.g. MPI_INT, MPI_DOUBLE etc.)
template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {

    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED; // MPI does not contain MPI_UNSIGNED_INT
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T; // need to use LONG_INT
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T; // ditto
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    }

    size = 1;
    return MPI_BYTE;
}
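
// Usage sketch: the "type + int" datatype returned with with_int=true has the layout
// MPI's MAXLOC/MINLOC reductions expect.  The struct and the variable name 'local_max'
// below are illustrative assumptions, not part of the HILA API.
//
//   struct { double val; int rank; } in, out;
//   in.val = local_max;                 // this rank's candidate value
//   in.rank = hila::myrank();
//   size_t sz;
//   MPI_Datatype dt = get_MPI_number_type<double>(sz, true);   // MPI_DOUBLE_INT
//   MPI_Allreduce(&in, &out, 1, dt, MPI_MAXLOC, lattice.mpi_comm_lat);
//   // out.val is the global maximum, out.rank is the rank that holds it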


template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}


/// @brief Return the MPI complex type equivalent to a hila type
///
/// For example, if T is Complex<double>, get_MPI_complex_type<T>() returns MPI_C_DOUBLE_COMPLEX
/// @tparam T type to be converted
/// @param siz size of the type in bytes
/// @return MPI_Datatype
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(Complex<double>);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(Complex<float>);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        // dependent-false condition: the assert fires only if this branch is instantiated
        static_assert(!std::is_same<T, T>::value,
                      "get_MPI_complex_type<T>() called without T being a complex type");
        return MPI_BYTE;
    }
}
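
// Usage sketch: obtaining the matching MPI datatype for a Complex<double> buffer.
// 'cbuf' and 'n' are illustrative names, not part of this header.
//
//   size_t csize;
//   MPI_Datatype cdt = get_MPI_complex_type<Complex<double>>(csize);  // MPI_C_DOUBLE_COMPLEX
//   MPI_Bcast((void *)cbuf, n, cdt, 0, lattice.mpi_comm_lat);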


namespace hila {

/**
 * @brief Broadcast the value of _var_ to all MPI ranks from _rank_ (default=0).
 *
 * NOTE: the function must be called by all MPI ranks, otherwise the program will deadlock.
 *
 * The type of the variable _var_ can be any standard plain datatype (trivial type),
 * std::string, std::vector or std::array.
 *
 * For trivial types, the input _var_ can be a non-modifiable value. In this case
 * the broadcast value is obtained from the return value of hila::broadcast().
 *
 * Example:
 * @code{.cpp}
 * auto rnd = hila::broadcast(hila::random()); // all MPI ranks get the same random value
 * @endcode
 *
 * @param var variable to be synchronized across all MPI ranks
 * @param rank MPI rank from which the value is broadcast
 * @return the broadcast value of var
 */


template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");
    if (hila::check_input)
        return var;

    assert(0 <= rank && rank < hila::number_of_nodes() && "Invalid sender rank in broadcast()");

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}

/// Version of broadcast with non-modifiable var
template <typename T>
T broadcast(const T &var, int rank = 0) {
    T tmp = var;
    return broadcast(tmp, rank);
}

/// Broadcast for std::vector
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    int size = list.size();
    MPI_Bcast(&size, sizeof(int), MPI_BYTE, rank, lattice.mpi_comm_lat);
    if (hila::myrank() != rank) {
        list.resize(size);
    }

    // move the vector contents directly to the storage
    MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}
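
// Usage sketch: rank 0 typically fills the vector (e.g. from an input file) and the
// other ranks receive it resized and filled.  'read_parameters()' is a hypothetical
// helper used only for illustration.
//
//   std::vector<double> params;
//   if (hila::myrank() == 0)
//       params = read_parameters();
//   hila::broadcast(params);        // after this, params is identical on every rank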

/// And broadcast for std::array
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {

    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    // move the array contents directly to the storage
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}




///
/// Bare pointers cannot be broadcast

template <typename T>
void broadcast(T *var, int rank = 0) {
    static_assert(!std::is_same<T, T>::value,
                  "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                  "int size)' to broadcast an array");
}

///
/// Broadcast for arrays where size must be known and same for all nodes

template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {

    if (hila::check_input)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}

// Do string broadcasts separately
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);

/// Broadcast two values of (possibly) different types in a single MPI call
template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {

    if (hila::check_input)
        return;

    struct {
        T tv;
        U uv;
    } s = {t, u};

    hila::broadcast(s, rank);
    t = s.tv;
    u = s.uv;
}
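
// Usage sketch: broadcast two related scalars packed together.  The variable names
// and the values assigned on rank 0 are illustrative only.
//
//   int n_iterations = 0;
//   double tolerance = 0.0;
//   if (hila::myrank() == 0) {
//       n_iterations = 100;     // e.g. values read from an input file on rank 0
//       tolerance = 1e-8;
//   }
//   hila::broadcast2(n_iterations, tolerance);  // one MPI_Bcast instead of two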


template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}

template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);

    MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    send_timer.stop();
}

template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    send_timer.start();
    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);

    MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    send_timer.stop();
}
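
// Usage sketch: a matched point-to-point transfer of a vector from rank 0 to rank 1.
// Both calls must be executed (each on its own rank), otherwise the program deadlocks.
//
//   std::vector<double> v;
//   if (hila::myrank() == 0) {
//       v = {1.0, 2.0, 3.0};
//       hila::send_to(1, v);
//   } else if (hila::myrank() == 1) {
//       hila::receive_from(0, v);   // v is resized to the received length
//   }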


///
/// Reduce an array across nodes

template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {

    if (hila::check_input)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;
    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}

///
/// Reduce a single variable across nodes.
/// A bit suboptimal, because it uses a std::vector internally

template <typename T>
T reduce_node_sum(T &var, bool allreduce = true) {
    hila::reduce_node_sum(&var, 1, allreduce);
    return var;
}
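
// Usage sketch: summing per-rank partial results.  'local_sum' is an illustrative
// variable name; reductions inside field loops normally use HILA's reduction variables
// rather than this call.
//
//   double local_sum = 0.0;   // assume this holds the rank-local partial result
//   double total = hila::reduce_node_sum(local_sum);           // every rank gets the full sum
//   double on_root = hila::reduce_node_sum(local_sum, false);  // result valid on rank 0 only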

// Product reduction template - so far only for int, float, double

template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {
    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;

    if (hila::check_input)
        return;

    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}

template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}

void set_allreduce(bool on = true);
bool get_allreduce();


} // namespace hila


void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();


template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        hila::reduce_node_sum(value, 1, hila::get_allreduce());
    }
}


#endif // COM_MPI_H