HILA
com_mpi.cpp
1
2#include "plumbing/defs.h"
3#include "plumbing/lattice.h"
4#include "plumbing/field.h"
5#include "plumbing/com_mpi.h"
6#include "plumbing/timing.h"
7
8// define the MPI timers here - they are declared extern elsewhere
9
10hila::timer start_send_timer("MPI start send");
11hila::timer wait_send_timer("MPI wait send");
12hila::timer post_receive_timer("MPI post receive");
13hila::timer wait_receive_timer("MPI wait receive");
14hila::timer synchronize_timer("MPI synchronize");
15hila::timer reduction_timer("MPI reduction");
16hila::timer reduction_wait_timer("MPI reduction wait");
17hila::timer broadcast_timer("MPI broadcast");
18hila::timer send_timer("MPI send field");
19hila::timer drop_comms_timer("MPI wait drop_comms");
20hila::timer partition_sync_timer("partition sync");
21
22// the global partitions struct is defined here
23
24hila::partitions_struct hila::partitions;
25
26/* Keep track of whether MPI has been initialized */
27static bool mpi_initialized = false;
28
29////////////////////////////////////////////////////////////
30/// Reductions: do automatic coalescing of reductions
31/// if the type is float or double
32/// These functions should not be called "by hand"
33
34// buffers - first vector holds the reduction buffer,
35// second the pointers to where the results are distributed
36static std::vector<double> double_reduction_buffer;
37static std::vector<double *> double_reduction_ptrs;
38static int n_double = 0;
39
40static std::vector<float> float_reduction_buffer;
41static std::vector<float *> float_reduction_ptrs;
42static int n_float = 0;
43
44// static var holding the allreduce state
45static bool allreduce_on = true;
46
47void hila_reduce_double_setup(double *d, int n) {
48
49 // ensure there's enough space
50 if (n + n_double > double_reduction_buffer.size()) {
51 double_reduction_buffer.resize(n + n_double + 2);
52 double_reduction_ptrs.resize(n + n_double + 2);
53 }
54
55 for (int i = 0; i < n; i++) {
56 double_reduction_buffer[n_double + i] = d[i];
57 double_reduction_ptrs[n_double + i] = d + i;
58 }
59
60 n_double += n;
61}
62
63void hila_reduce_float_setup(float *d, int n) {
64
65 // ensure there's enough space
66 if (n + n_float > float_reduction_buffer.size()) {
67 float_reduction_buffer.resize(n + n_float + 2);
68 float_reduction_ptrs.resize(n + n_float + 2);
69 }
70
71 for (int i = 0; i < n; i++) {
72 float_reduction_buffer[n_float + i] = d[i];
73 float_reduction_ptrs[n_float + i] = d + i;
74 }
75
76 n_float += n;
77}
78
79void hila_reduce_sums() {
80
81 if (n_double > 0) {
82 std::vector<double> work(n_double);
83
84 reduction_timer.start();
85
86 if (allreduce_on) {
87 MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
88 MPI_DOUBLE, MPI_SUM, lattice.mpi_comm_lat);
89 for (int i = 0; i < n_double; i++)
90 *(double_reduction_ptrs[i]) = work[i];
91
92 } else {
93 MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double,
94 MPI_DOUBLE, MPI_SUM, 0, lattice.mpi_comm_lat);
95 if (hila::myrank() == 0)
96 for (int i = 0; i < n_double; i++)
97 *(double_reduction_ptrs[i]) = work[i];
98 }
99
100 n_double = 0;
101
102 reduction_timer.stop();
103 }
104
105 if (n_float > 0) {
106 std::vector<float> work(n_float);
107
108 reduction_timer.start();
109
110 if (allreduce_on) {
111 MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float,
112 MPI_FLOAT, MPI_SUM, lattice.mpi_comm_lat);
113 for (int i = 0; i < n_float; i++)
114 *(float_reduction_ptrs[i]) = work[i];
115
116 } else {
117 MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
118 MPI_SUM, 0, lattice.mpi_comm_lat);
119 if (hila::myrank() == 0)
120 for (int i = 0; i < n_float; i++)
121 *(float_reduction_ptrs[i]) = work[i];
122 }
123
124 n_float = 0;
125
126 reduction_timer.stop();
127 }
128}
129
130/// set allreduce on (default) or off on the next reduction
131void hila::set_allreduce(bool on) {
132 allreduce_on = on;
133}
134
135bool hila::get_allreduce() {
136 return allreduce_on;
137}
138
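To make the coalescing flow concrete, here is a minimal sketch (the function name and the values are made up, and in practice these calls are emitted by hila-generated code, not written by hand, as the comment above notes): pending reductions accumulate in the shared buffer, and a single collective inside hila_reduce_sums() sums them all and scatters the results back through the stored pointers.

// Illustrative sketch only - not part of com_mpi.cpp.
void coalesced_reduction_example() {
    double a = 1.0;
    double b[3] = {0.5, 0.5, 0.5};

    hila::set_allreduce(true);        // deliver the summed result on every rank
    hila_reduce_double_setup(&a, 1);  // queue one double
    hila_reduce_double_setup(b, 3);   // queue three more; four values are now pending

    hila_reduce_sums();               // a single MPI_Allreduce sums all four values
                                      // across ranks and writes them back through
                                      // the stored pointers (a and b[0..2])
}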
139////////////////////////////////////////////////////////////////////////
140
141
142/* Machine initialization */
143#include <sys/types.h>
144void initialize_communications(int &argc, char ***argv) {
145 /* Init MPI */
146 if (!mpi_initialized) {
147
148#ifndef OPENMP
149 MPI_Init(&argc, argv);
150
151#else
152
153 int provided;
154 MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
155 if (provided < MPI_THREAD_FUNNELED) {
156 if (hila::myrank() == 0)
157 hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
158 MPI_Finalize();
159 exit(1);
160 }
161
162#endif
163
164 mpi_initialized = true;
165
166 // the global lattice variable exists by now; attach the MPI communicator to it
167 lattice.mpi_comm_lat = MPI_COMM_WORLD;
168
169 MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
170 MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
171 }
172}
173
174// check if MPI is on
175bool is_comm_initialized(void) {
176 return mpi_initialized;
177}
178
179/* version of exit for multinode processes -- kill all nodes */
180void abort_communications(int status) {
181 if (mpi_initialized) {
182 mpi_initialized = false;
183 MPI_Abort(lattices[0]->mpi_comm_lat, status);
184 }
185}
186
187/* clean exit from all nodes */
188void finish_communications() {
189 // turn off mpi -- this is needed to avoid mpi calls in destructors
190 mpi_initialized = false;
191 hila::about_to_finish = true;
192
193 MPI_Finalize();
194}
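The routines above make up the MPI lifecycle of a hila program. Below is a minimal sketch of how they fit together; real applications normally go through higher-level wrappers such as hila::initialize() and hila::finishrun() rather than calling these entry points directly, so treat this only as an illustration.

// Minimal lifecycle sketch - not part of com_mpi.cpp.
#include "plumbing/defs.h"   // assumed to pull in the declarations used below

int main(int argc, char **argv) {
    initialize_communications(argc, &argv);   // MPI_Init / MPI_Init_thread

    if (hila::myrank() == 0)
        hila::out << "running on " << hila::number_of_nodes() << " ranks\n";

    // ... lattice setup, field operations, reductions ...

    finish_communications();                  // MPI_Finalize; no MPI calls after this
    return 0;
}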
195
196// broadcast specialization
197void hila::broadcast(std::string &var, int rank) {
198
199 if (hila::check_input)
200 return;
201
202 int size = var.size();
203 hila::broadcast(size,rank);
204
205 if (hila::myrank() != rank) {
206 var.resize(size, ' ');
207 }
208 // copy directly to the data() buffer
209 broadcast_timer.start();
210 MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice.mpi_comm_lat);
211 broadcast_timer.stop();
212}
213
214void hila::broadcast(std::vector<std::string> &list, int rank) {
215
216 if (hila::check_input)
217 return;
218
219 int size = list.size();
220 hila::broadcast(size,rank);
221 list.resize(size);
222
223 for (auto &s : list) {
224 hila::broadcast(s,rank);
225 }
226}
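A short usage sketch of these broadcast specializations (the function name, string contents and file names below are made up):

// Sketch only - not part of com_mpi.cpp.
void broadcast_example() {
    std::string name;
    if (hila::myrank() == 0)
        name = "ensemble_A";              // value known only on rank 0
    hila::broadcast(name, 0);             // every rank now holds "ensemble_A"

    std::vector<std::string> files;
    if (hila::myrank() == 0)
        files = {"run1.dat", "run2.dat"}; // hypothetical file list
    hila::broadcast(files, 0);            // resized on other ranks, then sent element by element
}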
227
228/* BASIC COMMUNICATIONS FUNCTIONS */
229
230/// Return my node number - take care to return
231/// the previous node number if mpi is being
232/// torn down (used in destructors)
233
234int hila::myrank() {
235 static int node = 0;
236
237 if (!mpi_initialized || hila::check_input)
238 return node;
239
240 MPI_Comm_rank(lattice.mpi_comm_lat, &node);
241 return node;
242}
243
244/// Return number of nodes or "pseudo-nodes"
245int hila::number_of_nodes() {
246 if (hila::check_input)
247 return hila::check_with_nodes;
248
249 int nodes;
250 MPI_Comm_size(lattice.mpi_comm_lat, &nodes);
251 return (nodes);
252}
253
254void hila::synchronize() {
255 synchronize_timer.start();
256 hila::synchronize_threads();
257 MPI_Barrier(lattice.mpi_comm_lat);
258 synchronize_timer.stop();
259}
260
261
262/// Get message tags cyclically -- defined outside classes, so that it is global and
263/// unique
264
265#define MSG_TAG_MIN 100
266#define MSG_TAG_MAX (500) // the MPI standard guarantees at least 32767 tag values
267
268int get_next_msg_tag() {
269 static int tag = MSG_TAG_MIN;
270 ++tag;
271 if (tag > MSG_TAG_MAX)
272 tag = MSG_TAG_MIN;
273 return tag;
274}
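Tags generated this way keep overlapping non-blocking exchanges from matching each other's messages, and they agree across ranks as long as every rank calls get_next_msg_tag() in the same sequence. A hedged sketch pairing one generated tag with a plain MPI ring exchange (the function name, buffer size and neighbour choice are made up):

// Sketch only - standard MPI point-to-point calls, not code from this file.
void ring_exchange_example() {
    int tag = get_next_msg_tag();   // each rank calls this once here, so the tags agree

    int nn = hila::number_of_nodes();
    int to_rank = (hila::myrank() + 1) % nn;
    int from_rank = (hila::myrank() + nn - 1) % nn;

    std::vector<double> sendbuf(128, 1.0), recvbuf(128);
    MPI_Request reqs[2];
    MPI_Irecv(recvbuf.data(), 128, MPI_DOUBLE, from_rank, tag, lattice.mpi_comm_lat, &reqs[0]);
    MPI_Isend(sendbuf.data(), 128, MPI_DOUBLE, to_rank, tag, lattice.mpi_comm_lat, &reqs[1]);
    MPI_Waitall(2, reqs, MPI_STATUSES_IGNORE);
}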
275
276
277// Split the communicator into subvolumes, using MPI_Comm_split
278// The new MPI_Comm becomes the global mpi_comm_lat
279// NOTE: no attempt made here to reorder the nodes
280
281void split_into_partitions(int this_lattice) {
282
283 if (hila::check_input)
284 return;
285
286 if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.mpi_comm_lat)) !=
287 MPI_SUCCESS) {
288 hila::out0 << "MPI_Comm_split() call failed!\n";
289 hila::finishrun();
290 }
291 // also reset the rank and node-number fields
292 MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
293 MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
294}
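Seen from the caller's side, the split looks like the sketch below (the partition sizes and the function name are made up); afterwards lattice.mpi_comm_lat is the sub-communicator, so hila::myrank() and hila::number_of_nodes() report values within the partition.

// Sketch only - assume 8 MPI ranks split into 2 partitions of 4.
void partition_example() {
    int my_partition = hila::myrank() / 4;   // ranks 0-3 -> partition 0, ranks 4-7 -> partition 1
    split_into_partitions(my_partition);

    // from here on hila::number_of_nodes() == 4 and hila::myrank() is 0..3 within the
    // partition; MPI_COMM_WORLD is still available for global communication
}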
295
296#if 0
297
298// Switch the communicator frame between global and sublattice
299// for use in io_status_file
300
301void reset_comm(bool global)
302{
303 static MPI_Comm mpi_comm_saved;
304 static int set = 0;
305
306 g_sync_partitions();
307 if (global) {
308 mpi_comm_saved = lattice.mpi_comm_lat;
309 lattice.mpi_comm_lat = MPI_COMM_WORLD;
310 set = 1;
311 } else {
312 if (!set) halt("Comm set error!");
313 lattice.mpi_comm_lat = mpi_comm_saved;
314 set = 0;
315 }
316 mynode = hila::myrank();
317}
318
319#endif
Referenced files and symbols

plumbing/defs.h - This file defines all includes for HILA.
plumbing/field.h - This file contains definitions for the Field class and the classes required to define it such as fi...
int myrank() - rank of this node. Definition: com_mpi.cpp:234
int number_of_nodes() - how many nodes there are. Definition: com_mpi.cpp:245
void synchronize() - synchronize MPI. Definition: com_mpi.cpp:254
void set_allreduce(bool on=true) - set allreduce on (default) or off on the next reduction. Definition: com_mpi.cpp:131
std::ostream out - our default output file stream
std::ostream out0 - writes output only from the main process (node 0)
T broadcast(T &var, int rank=0) - broadcast the value of var to all MPI ranks from rank (default 0). Definition: com_mpi.h:152
void finishrun() - normal, controlled exit; all nodes must call this. Prints timing information and information about c...