#include "plumbing/lattice.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"
// timers for MPI operations used in this file
hila::timer reduction_timer("MPI reduction");
hila::timer reduction_wait_timer("MPI reduction wait");
hila::timer broadcast_timer("MPI broadcast");
hila::timer synchronize_timer("MPI synchronize");
hila::timer cancel_receive_timer("MPI cancel receive");

hila::partitions_struct hila::partitions;
static bool mpi_initialized = false;

// buffers and target pointers for delayed (deferred) sum reductions
static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

// if true, reductions use MPI_Allreduce; if false, MPI_Reduce to rank 0
static bool allreduce_on = true;
// queue n doubles for a delayed reduction: buffer the current values and
// remember where the summed results should be written back
void hila_reduce_double_setup(double *d, int n) {

    // ensure there is enough space
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}
// queue n floats for a delayed reduction, as above
void hila_reduce_float_setup(float *d, int n) {

    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}
// carry out the queued delayed reductions, one collective call per type
void hila_reduce_sums() {

    if (n_double > 0) {
        std::vector<double> work(n_double);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
                          MPI_DOUBLE, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];
        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double,
                       MPI_DOUBLE, MPI_SUM, 0, lattice.mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;
        reduction_timer.stop();
    }

    if (n_float > 0) {
        std::vector<float> work(n_float);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float,
                          MPI_FLOAT, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];
        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice.mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;
        reduction_timer.stop();
    }

    // reductions are allreduce by default for the next round
    allreduce_on = true;
}
// set allreduce on (default) or off for the next reduction
void hila::set_allreduce(bool on) {
    allreduce_on = on;
}

bool hila::get_allreduce() {
    return allreduce_on;
}
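/*
   Usage sketch (illustrative, not part of the library): a delayed sum
   reduction queues its target variables with hila_reduce_double_setup(),
   and hila_reduce_sums() completes all queued sums with a single collective
   call.  The variable names 'partial', 'local_a' and 'local_b' are hypothetical.

       double partial[2] = {local_a, local_b};   // per-rank partial sums
       hila::set_allreduce(true);                // default: every rank gets the result
       hila_reduce_double_setup(partial, 2);     // buffer values, remember &partial[i]
       hila_reduce_sums();                       // one MPI_Allreduce; partial[] now holds
                                                 // the sums over all ranks
*/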
/* Machine initialization */
#include <sys/types.h>
#include <cstdlib>

void initialize_communications(int &argc, char ***argv) {
    // initialize MPI so that rank and node number queries work
    if (!mpi_initialized) {

#ifndef OPENMP
        MPI_Init(&argc, argv);
#else
        // with OpenMP threads, require at least MPI_THREAD_FUNNELED
        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            exit(1);
        }
#endif

        mpi_initialized = true;

        // use MPI_COMM_WORLD as the default lattice communicator
        lattice.mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
        MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
    }
}
// check if MPI has been initialized
bool is_comm_initialized(void) {
    return mpi_initialized;
}
// abort the whole multinode program on error
void abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattices[0]->mpi_comm_lat, 0);
    }
}
// normal, controlled shutdown of communications
void finish_communications() {
    // turn off communications
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}
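/*
   Lifecycle sketch (illustrative only): in a bare-bones test program the
   functions above would bracket all MPI activity roughly as follows.
   In normal hila applications they are expected to be reached through the
   framework's initialization and hila::finishrun() rather than called directly.

       int main(int argc, char **argv) {
           initialize_communications(argc, &argv);   // MPI_Init / MPI_Init_thread
           // ... lattice setup and field operations ...
           finish_communications();                  // MPI_Finalize
           return 0;
       }
*/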
// broadcast specialization for std::string
void hila::broadcast(std::string &var, int rank) {
    if (hila::check_input)
        return;

    int size = var.size();
    hila::broadcast(size, rank);
    if (hila::myrank() != rank)
        var.resize(size, ' ');

    // copy the contents directly from the string buffer
    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}

// broadcast a vector of strings: first the size, then each element
void hila::broadcast(std::vector<std::string> &list, int rank) {
    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}
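/*
   Usage sketch (illustrative only): broadcasting a string chosen on rank 0
   to every rank.  The variable name 'parameter_file' is hypothetical.

       std::string parameter_file;
       if (hila::myrank() == 0)
           parameter_file = "params.dat";
       hila::broadcast(parameter_file);   // default source rank 0; identical everywhere after
*/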
// rank of this node
int hila::myrank() {
    static int node = 0;

    if (!mpi_initialized || hila::check_input)
        return node;

    MPI_Comm_rank(lattice.mpi_comm_lat, &node);
    return node;
}

// how many nodes there are
int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;

    int nodes;
    MPI_Comm_size(lattice.mpi_comm_lat, &nodes);
    return nodes;
}

// synchronize mpi and threads
void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice.mpi_comm_lat);
    synchronize_timer.stop();
}
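/*
   Usage sketch (illustrative only): the common "rank 0 reports" pattern
   built from the helpers above.

       hila::synchronize();    // thread sync + MPI_Barrier
       if (hila::myrank() == 0)
           hila::out << "running on " << hila::number_of_nodes() << " MPI ranks\n";
*/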
// cycle through MPI message tags so that consecutive messages get distinct tags
#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500)

int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}
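/*
   Usage sketch (illustrative only): a send/receive pair can share a cyclic
   tag, assuming every rank calls get_next_msg_tag() in the same order so the
   tags match.  Buffer and request names here are hypothetical.

       int tag = get_next_msg_tag();
       MPI_Isend(sendbuf, n, MPI_DOUBLE, to_rank, tag, lattice.mpi_comm_lat, &send_req);
       MPI_Irecv(recvbuf, n, MPI_DOUBLE, from_rank, tag, lattice.mpi_comm_lat, &recv_req);
*/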
// split MPI_COMM_WORLD into sub-communicators, one per partition lattice
void split_into_partitions(int this_lattice) {
    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.mpi_comm_lat)) !=
        MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }
    // update the rank and node number for the new communicator
    MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
    MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
}
// switch the lattice communicator to MPI_COMM_WORLD (global == true),
// or restore the previously saved communicator (global == false)
void reset_comm(bool global) {
    static MPI_Comm mpi_comm_saved;
    static bool set = false;

    if (global) {
        mpi_comm_saved = lattice.mpi_comm_lat;
        lattice.mpi_comm_lat = MPI_COMM_WORLD;
        set = true;
    } else {
        if (!set)
            halt("Comm set error!");
        lattice.mpi_comm_lat = mpi_comm_saved;
        set = false;
    }
}
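/*
   Usage sketch (illustrative only): reset_comm() is used in matched pairs,
   temporarily switching the lattice communicator to MPI_COMM_WORLD and then
   restoring the saved partition communicator.

       reset_comm(true);    // save the current communicator, use MPI_COMM_WORLD
       // ... global (cross-partition) communication ...
       reset_comm(false);   // restore the saved communicator
*/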