HILA
com_mpi.cpp
#include "plumbing/defs.h"
#include "plumbing/lattice.h"
#include "plumbing/field.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"


// define the MPI timers here - these are declared extern elsewhere

hila::timer start_send_timer("MPI start send");
hila::timer wait_send_timer("MPI wait send");
hila::timer post_receive_timer("MPI post receive");
hila::timer wait_receive_timer("MPI wait receive");
hila::timer synchronize_timer("MPI synchronize");
hila::timer reduction_timer("MPI reduction");
hila::timer reduction_wait_timer("MPI reduction wait");
hila::timer broadcast_timer("MPI broadcast");
hila::timer send_timer("MPI send field");
hila::timer drop_comms_timer("MPI wait drop_comms");
hila::timer partition_sync_timer("partition sync");

// the partitions struct is housed here

hila::partitions_struct hila::partitions;

/* Keep track of whether MPI has been initialized */
static bool mpi_initialized = false;

////////////////////////////////////////////////////////////
/// Reductions: do automatic coalescing of reductions
/// if the type is float or double
/// These functions should not be called "by hand"

// buffers - the first vector holds the reduction values,
// the second the pointers to where the results are distributed
static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

// static var holding the allreduce state
static bool allreduce_on = true;

void hila_reduce_double_setup(double *d, int n) {

    // ensure there's enough space
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}

void hila_reduce_float_setup(float *d, int n) {

    // ensure there's enough space
    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}

void hila_reduce_sums() {

    if (n_double > 0) {
        std::vector<double> work(n_double);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
                          MPI_DOUBLE, MPI_SUM, lattice->mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double, MPI_DOUBLE,
                       MPI_SUM, 0, lattice->mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;

        reduction_timer.stop();
    }

    if (n_float > 0) {
        std::vector<float> work(n_float);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                          MPI_SUM, lattice->mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice->mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;

        reduction_timer.stop();
    }
}
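
// Illustrative sketch (not part of the library source): these helpers are meant
// to be driven by hila's generated reduction code, but conceptually a coalesced
// scalar reduction looks like
//
//     double a = ..., b = ...;             // per-rank partial results
//     hila_reduce_double_setup(&a, 1);     // value copied to the common buffer, pointer stored
//     hila_reduce_double_setup(&b, 1);     // coalesced into the same buffer
//     hila_reduce_sums();                  // one collective call; a and b now hold the sums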

/// set allreduce on (default) or off on the next reduction
void hila::set_allreduce(bool on) {
    allreduce_on = on;
}

bool hila::get_allreduce() {
    return allreduce_on;
}
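
// Usage sketch (illustrative): with allreduce switched off, only rank 0 receives
// the reduced values in the following hila_reduce_sums() call; the flag stays in
// effect until it is changed again:
//
//     hila::set_allreduce(false);   // next reduction uses MPI_Reduce to rank 0
//     // ... register and perform the reduction ...
//     hila::set_allreduce();        // restore the default (allreduce on)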

////////////////////////////////////////////////////////////////////////


/* Machine initialization */
#include <sys/types.h>
void hila::initialize_communications(int &argc, char ***argv) {
    /* Init MPI */
    if (!mpi_initialized) {

#ifndef OPENMP
        MPI_Init(&argc, argv);

#else

        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            if (hila::myrank() == 0)
                hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            MPI_Finalize();
            exit(1);
        }

#endif

        mpi_initialized = true;

        // the global lattice variable exists already; assign the MPI communicator there
        lattice.ptr()->mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice->mpi_comm_lat, &lattice.ptr()->mynode.rank);
        MPI_Comm_size(lattice->mpi_comm_lat, &lattice.ptr()->nodes.number);
    }
}

// check if MPI is on
bool hila::is_comm_initialized(void) {
    return mpi_initialized;
}

/* version of exit for multinode processes -- kill all nodes */
void hila::abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattice->mpi_comm_lat, status);
    }
}

/* clean exit from all nodes */
void hila::finish_communications() {
    // turn off mpi -- this is needed to avoid mpi calls in destructors
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}
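
// Lifecycle sketch (illustrative; these wrappers are presumably invoked by hila's
// higher-level setup and teardown rather than by user code directly):
//
//     int main(int argc, char **argv) {
//         hila::initialize_communications(argc, &argv);
//         // ... set up the lattice and run the program ...
//         hila::finish_communications();   // or hila::abort_communications(status) on error
//     }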

// broadcast specialization
void hila::broadcast(std::string &var, int rank) {

    if (hila::check_input)
        return;

    int size = var.size();
    hila::broadcast(size, rank);

    if (hila::myrank() != rank) {
        var.resize(size, ' ');
    }
    // copy directly to the data() buffer
    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
}

void hila::broadcast(std::vector<std::string> &list, int rank) {

    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}
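
// Usage sketch (illustrative): broadcasting a string that only rank 0 has filled
// in, e.g. a file name - the receiving ranks are resized and overwritten:
//
//     std::string fname;
//     if (hila::myrank() == 0)
//         fname = "config.dat";      // hypothetical value, known on rank 0 only
//     hila::broadcast(fname, 0);     // now every rank holds the same string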

/* BASIC COMMUNICATIONS FUNCTIONS */

/// Return my node number - take care to return
/// the previous node number if mpi is being
/// torn down (used in destructors)
int hila::myrank() {
    static int node = 0;

    if (!mpi_initialized || hila::check_input)
        return node;

    MPI_Comm_rank(lattice->mpi_comm_lat, &node);
    return node;
}

/// Return number of nodes or "pseudo-nodes"
int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;

    int nodes;
    MPI_Comm_size(lattice->mpi_comm_lat, &nodes);
    return (nodes);
}

/// synchronize threads and MPI
void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice->mpi_comm_lat);
    synchronize_timer.stop();
}

/// MPI barrier only
void hila::barrier() {
    synchronize_timer.start();
    MPI_Barrier(lattice->mpi_comm_lat);
    synchronize_timer.stop();
}


/// Get message tags cyclically -- defined outside classes, so that it is global and
/// unique

#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500) // the MPI standard guarantees at least 32767 tags

int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}
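
// Illustrative sketch (hypothetical buffers, counts and ranks): the tag returned
// here is meant to be shared by the matching send and receive of one message, so
// both sides obtain it from get_next_msg_tag() in the same order:
//
//     MPI_Request req_in, req_out;
//     int tag = get_next_msg_tag();
//     MPI_Irecv(recv_buf, n_bytes, MPI_BYTE, from_rank, tag, lattice->mpi_comm_lat, &req_in);
//     MPI_Isend(send_buf, n_bytes, MPI_BYTE, to_rank, tag, lattice->mpi_comm_lat, &req_out);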

/// Split the communicator into subvolumes, using MPI_Comm_split
/// The new MPI_Comm becomes the global mpi_comm_lat
/// NOTE: no attempt made here to reorder the nodes

void hila::split_into_partitions(int this_lattice) {

    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.ptr()->mpi_comm_lat)) != MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }
    // reset also the rank and node number fields
    MPI_Comm_rank(lattice->mpi_comm_lat, &lattice.ptr()->mynode.rank);
    MPI_Comm_size(lattice->mpi_comm_lat, &lattice.ptr()->nodes.number);
}
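
// Illustrative sketch (the partition index computation here is hypothetical, not
// how hila assigns partitions): every rank passing the same this_lattice value to
// MPI_Comm_split ends up in the same sub-communicator, e.g.
//
//     int ranks_per_partition = hila::number_of_nodes() / hila::partitions.number();
//     int my_partition = hila::myrank() / ranks_per_partition;
//     hila::split_into_partitions(my_partition);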

void hila::synchronize_partitions() {
    if (partitions.number() > 1)
        MPI_Barrier(MPI_COMM_WORLD);
}

MPI_Datatype MPI_ExtendedPrecision_type;
MPI_Op MPI_ExtendedPrecision_sum_op;

void create_extended_MPI_type() {
    ExtendedPrecision dummy;
    int block_lengths[3] = {1, 1, 1};
    MPI_Aint displacements[3];
    MPI_Datatype types[3] = {MPI_DOUBLE, MPI_DOUBLE, MPI_DOUBLE};

    MPI_Aint base;
    MPI_Get_address(&dummy, &base);
    MPI_Get_address(&dummy.value, &displacements[0]);
    MPI_Get_address(&dummy.compensation, &displacements[1]);
    MPI_Get_address(&dummy.compensation2, &displacements[2]);
    displacements[0] -= base;
    displacements[1] -= base;
    displacements[2] -= base;

    MPI_Type_create_struct(3, block_lengths, displacements, types, &MPI_ExtendedPrecision_type);
    MPI_Type_commit(&MPI_ExtendedPrecision_type);
}

void extended_sum_op(void *in, void *inout, int *len, MPI_Datatype *datatype) {
    ExtendedPrecision *in_data = (ExtendedPrecision *)in;
    ExtendedPrecision *inout_data = (ExtendedPrecision *)inout;

    for (int i = 0; i < *len; i++) {
        inout_data[i] += in_data[i];
    }
}
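
// Note on the operator above: the per-element += is ExtendedPrecision's operator
// (defined elsewhere), which per the documentation below performs Kahan-compensated
// summation. A minimal single-compensation sketch of the idea (the actual class
// carries two compensation terms, value/compensation/compensation2):
//
//     // adding a plain double x to the pair (value, compensation):
//     double y = x - compensation;
//     double t = value + y;
//     compensation = (t - value) - y;   // low-order bits lost when forming t
//     value = t;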

void create_extended_MPI_operation() {
    MPI_Op_create(&extended_sum_op, true, &MPI_ExtendedPrecision_sum_op);
}


/**
 * @brief Custom MPI reduction for the extended type that performs Kahan summation
 *
 * @param value Input extended data
 * @param send_count Number of extended variables to reduce (default 1)
 * @param allreduce If true, performs MPI_Allreduce, otherwise MPI_Reduce
 */
void reduce_node_sum_extended(ExtendedPrecision *value, int send_count, bool allreduce) {

    if (hila::check_input)
        return;

    static bool init_extended_type_and_operation = true;
    if (init_extended_type_and_operation) {
        create_extended_MPI_type();
        create_extended_MPI_operation();
        init_extended_type_and_operation = false;
    }

    std::vector<ExtendedPrecision> recv_data(send_count);
    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(), send_count,
                      MPI_ExtendedPrecision_type, MPI_ExtendedPrecision_sum_op,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(), send_count, MPI_ExtendedPrecision_type,
                   MPI_ExtendedPrecision_sum_op, 0, lattice->mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}
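
// Usage sketch (illustrative; assumes ExtendedPrecision is default-constructible
// and accumulated per rank by other code):
//
//     ExtendedPrecision local;                     // hypothetical per-rank accumulator
//     // ... accumulate this rank's contributions into 'local' ...
//     reduce_node_sum_extended(&local, 1, true);   // compensated allreduce over all ranks
//     // 'local' now holds the globally reduced value on every rank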