HILA
com_mpi.cpp

#include "plumbing/defs.h"
#include "plumbing/lattice.h"
#include "plumbing/field.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"

// declare MPI timers here too - these were externs

hila::timer start_send_timer("MPI start send");
hila::timer wait_send_timer("MPI wait send");
hila::timer post_receive_timer("MPI post receive");
hila::timer wait_receive_timer("MPI wait receive");
hila::timer synchronize_timer("MPI synchronize");
hila::timer reduction_timer("MPI reduction");
hila::timer reduction_wait_timer("MPI reduction wait");
hila::timer broadcast_timer("MPI broadcast");
hila::timer send_timer("MPI send field");
hila::timer cancel_send_timer("MPI cancel send");
hila::timer cancel_receive_timer("MPI cancel receive");
hila::timer partition_sync_timer("partition sync");

// let us house the partitions-struct here

hila::partitions_struct hila::partitions;

/* Keep track of whether MPI has been initialized */
static bool mpi_initialized = false;

////////////////////////////////////////////////////////////
/// Reductions: do automatic coalescing of reductions
/// if the type is float or double
/// These functions should not be called "by hand"

// buffers - the first vector holds the reduction buffer,
// the second holds pointers to where the results are distributed
static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

// static var holding the allreduce state
static bool allreduce_on = true;

void hila_reduce_double_setup(double *d, int n) {

    // ensure there's enough space
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}

void hila_reduce_float_setup(float *d, int n) {

    // ensure there's enough space
    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}

void hila_reduce_sums() {

    if (n_double > 0) {
        double work[n_double];

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), work, n_double,
                          MPI_DOUBLE, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work, n_double,
                       MPI_DOUBLE, MPI_SUM, 0, lattice.mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;

        reduction_timer.stop();
    }

    if (n_float > 0) {
        float work[n_float];

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work, n_float,
                          MPI_FLOAT, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work, n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice.mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;

        reduction_timer.stop();
    }
}

/// set allreduce on (default) or off on the next reduction
void hila::set_allreduce(bool on) {
    allreduce_on = on;
}

bool hila::get_allreduce() {
    return allreduce_on;
}
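
// Illustrative sketch (not part of the library source) of the call sequence that
// framework-generated code effectively performs when it coalesces reductions.
// hila::set_allreduce() selects whether the next coalesced reduction uses
// MPI_Allreduce (result on every rank) or a rank-0 MPI_Reduce; the setup calls
// stage values into the static buffers above and hila_reduce_sums() issues one
// combined MPI call for all of them. Variable names below are hypothetical.
#if 0
void reduction_example() {
    double a = 1.0, b[3] = {1.0, 2.0, 3.0};

    hila::set_allreduce(true);       // default: result distributed to every rank
    hila_reduce_double_setup(&a, 1); // stage a single double
    hila_reduce_double_setup(b, 3);  // stage an array of doubles
    hila_reduce_sums();              // one MPI_Allreduce covers all 4 values

    hila::set_allreduce(false);      // only rank 0 receives the next result
    hila_reduce_double_setup(&a, 1);
    hila_reduce_sums();              // on other ranks, a is left unchanged
}
#endif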

////////////////////////////////////////////////////////////////////////


/* Machine initialization */
#include <sys/types.h>
void initialize_communications(int &argc, char ***argv) {
    /* Init MPI */
    if (!mpi_initialized) {

#ifndef OPENMP
        MPI_Init(&argc, argv);

#else

        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            if (hila::myrank() == 0)
                hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            MPI_Finalize();
            exit(1);
        }

#endif

        mpi_initialized = true;

        // global var lattice exists, assign the mpi comms there
        lattice.mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
        MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
    }
}
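
// Illustrative startup sketch (hypothetical, not part of this file): in practice
// the framework's own initialization path reaches initialize_communications();
// the sketch only shows the order in which the communicator and the rank/size
// fields above become valid, and the matching shutdown call.
#if 0
int main(int argc, char **argv) {
    initialize_communications(argc, &argv); // MPI_Init / MPI_Init_thread
    // ... lattice setup, field operations ...
    hila::synchronize();                    // barrier over lattice.mpi_comm_lat
    finish_communications();                // MPI_Finalize
    return 0;
}
#endif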

// check if MPI is on
bool is_comm_initialized(void) {
    return mpi_initialized;
}

/* version of exit for multinode processes -- kill all nodes */
void abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattices[0]->mpi_comm_lat, 0);
    }
}

/* clean exit from all nodes */
void finish_communications() {
    // turn off mpi -- this is needed to avoid mpi calls in destructors
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}

// broadcast specialization
void hila::broadcast(std::string &var, int rank) {

    if (hila::check_input)
        return;

    int size = var.size();
    hila::broadcast(size, rank);

    if (hila::myrank() != rank) {
        var.resize(size, ' ');
    }
    // copy directly to the data() buffer
    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}

void hila::broadcast(std::vector<std::string> &list, int rank) {

    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}
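
// Illustrative usage sketch (hypothetical variable names): broadcasting a string
// and a list of strings from root rank 0 with the specializations above. Each
// specialization first broadcasts the size, then the contents.
#if 0
void broadcast_example() {
    std::string parameter_file;
    std::vector<std::string> file_list;

    if (hila::myrank() == 0) {
        parameter_file = "parameters.dat";
        file_list = {"conf_0001", "conf_0002"};
    }

    hila::broadcast(parameter_file, 0); // size first, then the character data
    hila::broadcast(file_list, 0);      // size first, then each string in turn
}
#endif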

/* BASIC COMMUNICATIONS FUNCTIONS */

/// Return my node number - take care to return
/// the previous node number if mpi is being
/// torn down (used in destructors)
int hila::myrank() {
    static int node = 0;

    if (!mpi_initialized || hila::check_input)
        return node;

    MPI_Comm_rank(lattice.mpi_comm_lat, &node);
    return node;
}

/// Return number of nodes or "pseudo-nodes"
int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;

    int nodes;
    MPI_Comm_size(lattice.mpi_comm_lat, &nodes);
    return (nodes);
}

void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice.mpi_comm_lat);
    synchronize_timer.stop();
}
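
// Illustrative usage sketch (hypothetical, not part of the library source):
// rank-gated output followed by a barrier, using the helpers above.
#if 0
void rank_report() {
    if (hila::myrank() == 0)
        hila::out << "running on " << hila::number_of_nodes() << " ranks\n";
    hila::synchronize();
}
#endif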


/// Get message tags cyclically -- defined outside classes, so that it is global and
/// unique

#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500) // the MPI standard guarantees at least 32767 tag values

int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}
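
// Illustrative sketch (hypothetical): as long as all ranks call get_next_msg_tag()
// in the same order, they agree on the tag value, so the matching send/receive
// pair is unambiguous. Tags wrap around within [MSG_TAG_MIN, MSG_TAG_MAX], which
// is safe as long as messages reusing a tag value are not simultaneously in flight.
#if 0
void tag_example(char *send_buf, char *recv_buf, int count, int partner) {
    int tag = get_next_msg_tag();
    MPI_Request req;
    MPI_Irecv(recv_buf, count, MPI_BYTE, partner, tag, lattice.mpi_comm_lat, &req);
    MPI_Send(send_buf, count, MPI_BYTE, partner, tag, lattice.mpi_comm_lat);
    MPI_Wait(&req, MPI_STATUS_IGNORE);
}
#endif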


// Split the communicator to subvolumes, using MPI_Comm_split
// New MPI_Comm is the global mpi_comm_lat
// NOTE: no attempt made here to reorder the nodes

void split_into_partitions(int this_lattice) {

    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.mpi_comm_lat)) !=
        MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }
    // also reset the rank and number fields
    MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
    MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
}
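
// Illustrative sketch (hypothetical): MPI_Comm_split groups ranks by the "color"
// argument, so ranks passing the same this_lattice value end up in the same
// sub-communicator. For example, splitting 8 world ranks into 2 partitions of 4:
#if 0
void partition_example() {
    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    int my_partition = world_rank / 4;   // ranks 0-3 -> partition 0, 4-7 -> partition 1
    split_into_partitions(my_partition); // lattice.mpi_comm_lat now spans 4 ranks
}
#endif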

#if 0

// Switch comm frame global-sublat
// for use in io_status_file

void reset_comm(bool global)
{
    static MPI_Comm mpi_comm_saved;
    static int set = 0;

    g_sync_partitions();
    if (global) {
        mpi_comm_saved = lattice.mpi_comm_lat;
        lattice.mpi_comm_lat = MPI_COMM_WORLD;
        set = 1;
    } else {
        if (!set) halt("Comm set error!");
        lattice.mpi_comm_lat = mpi_comm_saved;
        set = 0;
    }
    mynode = hila::myrank();
}

#endif