#include "plumbing/lattice.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"
hila::timer reduction_wait_timer("MPI reduction wait");

// house the partitions-struct here
hila::partitions_struct hila::partitions;

// keep track of whether MPI has been initialized
static bool mpi_initialized = false;

// Reduction coalescing: pending double/float reductions are collected into
// these buffers and summed with a single MPI call in hila_reduce_sums().
static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

// if true, the next reduction is an MPI_Allreduce; otherwise MPI_Reduce to rank 0
static bool allreduce_on = true;
void hila_reduce_double_setup(double *d, int n) {

    // make sure there is enough space in the coalescing buffers
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    // copy the values and remember where the results should be written back
    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}
void hila_reduce_float_setup(float *d, int n) {

    // make sure there is enough space in the coalescing buffers
    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    // copy the values and remember where the results should be written back
    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}
void hila_reduce_sums() {

    if (n_double > 0) {
        std::vector<double> work(n_double);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
                          MPI_DOUBLE, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double,
                       MPI_DOUBLE, MPI_SUM, 0, lattice.mpi_comm_lat);
            // only rank 0 holds the result of MPI_Reduce
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;
        reduction_timer.stop();
    }

    if (n_float > 0) {
        std::vector<float> work(n_float);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float,
                          MPI_FLOAT, MPI_SUM, lattice.mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice.mpi_comm_lat);
            // only rank 0 holds the result of MPI_Reduce
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;
        reduction_timer.stop();
    }
}
bool hila::get_allreduce() {
    return allreduce_on;
}
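// Usage sketch (illustrative only, placeholder values): these helpers are normally
// driven by the framework's reduction machinery rather than called by hand, but
// the intended flow is roughly:
//
//   double a = 1.0, b = 2.0;            // per-rank partial results
//   hila_reduce_double_setup(&a, 1);    // register a in the coalescing buffer
//   hila_reduce_double_setup(&b, 1);    // register b; both share one MPI call
//   hila_reduce_sums();                 // with allreduce on (the default), a and b
//                                       // now hold the global sums on every rank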
#include <sys/types.h>

void hila::initialize_communications(int &argc, char ***argv) {
    // initialize MPI, unless it has been done already
    if (!mpi_initialized) {

#ifndef _OPENMP
        MPI_Init(&argc, argv);
#else
        // when threading with OpenMP, MPI must provide at least funneled thread support
        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            exit(1);
        }
#endif

        mpi_initialized = true;

        // the global lattice object exists; hang the MPI communicator on it
        lattice.mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
        MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
    }
}
bool hila::is_comm_initialized(void) {
    return mpi_initialized;
}
void hila::abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattices[0]->mpi_comm_lat, 0);
    }
}
void hila::finish_communications() {
    // turn off MPI first, so that destructors running during teardown
    // do not attempt further MPI calls
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}
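// Lifecycle sketch (illustrative only): these low-level calls are normally invoked
// from the framework's own setup and controlled exit (hila::finishrun); the
// underlying sequence is approximately:
//
//   int main(int argc, char **argv) {
//       hila::initialize_communications(argc, &argv);  // MPI_Init / MPI_Init_thread
//       // ... set up the lattice and run ...
//       hila::finish_communications();                 // MPI_Finalize
//       return 0;
//   }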
// broadcast specialization for std::string: send the length first, then the data
void hila::broadcast(std::string &var, int rank) {
    if (hila::check_input)
        return;

    int size = var.size();
    hila::broadcast(size, rank);
    if (hila::myrank() != rank)
        var.resize(size, ' ');

    // copy directly from the data() buffer
    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}
// same for a vector of strings: broadcast the count, then each element in turn
void hila::broadcast(std::vector<std::string> &list, int rank) {
    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}
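// Example (sketch): broadcasting a string read on rank 0 to every rank; the
// value assigned to name here is a placeholder.
//
//   std::string name;
//   if (hila::myrank() == 0)
//       name = "parameter value read from input";
//   hila::broadcast(name, 0);   // all ranks now hold the same string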
// rank of this process; return the cached value if MPI is not (or no longer) available
int hila::myrank() {
    static int node = 0;

    if (!mpi_initialized || hila::check_input)
        return node;

    MPI_Comm_rank(lattice.mpi_comm_lat, &node);
    return node;
}
// number of MPI ranks ("nodes")
int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;

    int nodes;
    MPI_Comm_size(lattice.mpi_comm_lat, &nodes);
    return nodes;
}
// synchronize all MPI ranks (and threads)
void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice.mpi_comm_lat);
    synchronize_timer.stop();
}
// hand out MPI message tags cyclically from a fixed range
#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500)

int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}
// split MPI_COMM_WORLD into partitions; the new communicator becomes
// the lattice communicator of this partition
void hila::split_into_partitions(int this_lattice) {
    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.mpi_comm_lat)) !=
        MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }

    // refresh the rank and node count for the new communicator
    MPI_Comm_rank(lattice.mpi_comm_lat, &lattice.mynode.rank);
    MPI_Comm_size(lattice.mpi_comm_lat, &lattice.nodes.number);
}
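// Example (sketch, hypothetical values): with 8 MPI ranks split into 2 partitions,
// calling split_into_partitions(hila::myrank() / 4) on every rank would place
// ranks 0-3 in one lattice communicator and ranks 4-7 in another, while
// MPI_COMM_WORLD still spans all ranks (used e.g. by synchronize_partitions()).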
void hila::synchronize_partitions() {
    if (partitions.number() > 1)
        MPI_Barrier(MPI_COMM_WORLD);
}