#include "plumbing/lattice.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"

hila::timer reduction_wait_timer("MPI reduction wait");

hila::partitions_struct hila::partitions;

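// Besides reduction_wait_timer above, the code below also uses reduction_timer,
// broadcast_timer and synchronize_timer. They are defined here so the fragment
// is self-contained; the label strings are assumptions, not copied from the source:
hila::timer reduction_timer("MPI reduction");
hila::timer broadcast_timer("MPI broadcast");
hila::timer synchronize_timer("MPI synchronize");
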
static bool mpi_initialized = false;

static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

static bool allreduce_on = true;

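// Deferred reduction machinery: hila_reduce_double_setup() / hila_reduce_float_setup()
// copy values into the staging buffers above and remember the destination addresses;
// hila_reduce_sums() then performs a single MPI reduction over the whole buffer and
// writes the results back through the stored pointers. allreduce_on selects
// MPI_Allreduce (default) or a rank-0 MPI_Reduce for the next reduction only.
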
void hila_reduce_double_setup(double *d, int n) {

    // ensure there is enough space in the staging buffers
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    // copy the values and store the addresses where the results are delivered
    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}

void hila_reduce_float_setup(float *d, int n) {

    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}

void hila_reduce_sums() {

    if (n_double > 0) {
        std::vector<double> work(n_double);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
                          MPI_DOUBLE, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double,
                       MPI_DOUBLE, MPI_SUM, 0, lattice.mpi_comm_lat);
            // with plain MPI_Reduce only rank 0 holds the result
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;
        reduction_timer.stop();
    }

    if (n_float > 0) {
        std::vector<float> work(n_float);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float,
                          MPI_FLOAT, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice.mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;
        reduction_timer.stop();
    }

    // allreduce reverts to the default after each reduction
    allreduce_on = true;
}

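// Usage sketch (illustrative, not part of this file): code computing a sum
// reduction stages its per-rank partial result and then triggers the collective:
//
//     double s = local_sum();           // hypothetical per-rank partial result
//     hila::set_allreduce(true);        // every rank should receive the sum
//     hila_reduce_double_setup(&s, 1);  // stage the value
//     hila_reduce_sums();               // after this, s holds the global sum
//
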
void hila::set_allreduce(bool on) {
    allreduce_on = on;
}

bool hila::get_allreduce() {
    return allreduce_on;
}

#include <sys/types.h>

void initialize_communications(int &argc, char ***argv) {
    // initialize MPI so that sensible error reporting etc. is possible
    if (!mpi_initialized) {
#ifndef OPENMP
        MPI_Init(&argc, argv);
#else
        // with OpenMP a threaded MPI implementation is required
        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            exit(1);
        }
#endif

        mpi_initialized = true;

        // the global lattice variable exists at this point: assign the communicator
        lattice.mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
        MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
    }
}

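// Note (assumption, not stated in this file): applications built on HILA normally
// reach this function through hila::initialize(argc, argv) rather than calling it
// directly.
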
bool is_comm_initialized(void) {
    return mpi_initialized;
}

void abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattices[0]->mpi_comm_lat, 0);
    }
}

void finish_communications() {
    // turn off MPI -- this is needed to avoid MPI calls in destructors
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}

void hila::broadcast(std::string &var, int rank) {

    if (hila::check_input)
        return;

    int size = var.size();
    hila::broadcast(size, rank);

    // non-root ranks must make room for the incoming characters
    if (hila::myrank() != rank)
        var.resize(size, ' ');

    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}

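// Usage sketch (illustrative): read a value on rank 0 and share it, e.g.
//
//     std::string config_name;              // hypothetical variable
//     if (hila::myrank() == 0)
//         config_name = "parameters.dat";
//     hila::broadcast(config_name);         // identical on every rank afterwards
//
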
void hila::broadcast(std::vector<std::string> &list, int rank) {

    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}

int hila::myrank() {
    int node = 0;
    if (!mpi_initialized || hila::check_input)
        return node;
    MPI_Comm_rank(lattice.mpi_comm_lat, &node);
    return node;
}

int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;
    int nodes;
    MPI_Comm_size(lattice.mpi_comm_lat, &nodes);
    return nodes;
}

void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice.mpi_comm_lat);
    synchronize_timer.stop();
}

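// Usage sketch (illustrative): the common rank-guard plus barrier pattern, e.g.
//
//     if (hila::myrank() == 0)
//         hila::out << "running on " << hila::number_of_nodes() << " ranks\n";
//     hila::synchronize();   // all ranks wait here before continuing
//
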
#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500)

// cycle through message tags, so that separate communications do not collide
int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}

void split_into_partitions(int this_lattice) {

    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.mpi_comm_lat)) !=
        MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }

    // update the rank and node-number fields for the new communicator
    MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
    MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
}

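// Note (assumption, not stated in this file): this is used when one MPI job is
// split into several independent "partitions", each running its own lattice;
// this_lattice selects the communicator group this rank joins.
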
void reset_comm(bool global) {

    static MPI_Comm mpi_comm_saved;
    static bool set = false;

    if (global) {
        mpi_comm_saved = lattice.mpi_comm_lat;
        lattice.mpi_comm_lat = MPI_COMM_WORLD;
        set = true;
    } else {
        if (!set)
            halt("Comm set error!");
        lattice.mpi_comm_lat = mpi_comm_saved;
        set = false;
    }
}

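// Usage sketch (illustrative): temporarily switch communications to all ranks
// and then restore the partition communicator:
//
//     reset_comm(true);    // lattice.mpi_comm_lat becomes MPI_COMM_WORLD
//     // ... operations spanning every partition ...
//     reset_comm(false);   // restore the saved communicator
//
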