#include "plumbing/lattice.h"
#include "plumbing/com_mpi.h"
#include "plumbing/timing.h"
hila::timer reduction_wait_timer("MPI reduction wait");

hila::partitions_struct hila::partitions;

static bool mpi_initialized = false;

// Buffers for deferred reductions: the values to be summed and the addresses
// where the reduced results are written back
static std::vector<double> double_reduction_buffer;
static std::vector<double *> double_reduction_ptrs;
static int n_double = 0;

static std::vector<float> float_reduction_buffer;
static std::vector<float *> float_reduction_ptrs;
static int n_float = 0;

// if true, the next reduction uses MPI_Allreduce; otherwise MPI_Reduce to rank 0
static bool allreduce_on = true;
/// Register n doubles starting at d for the next deferred reduction
void hila_reduce_double_setup(double *d, int n) {

    // ensure there is space for the new values
    if (n + n_double > double_reduction_buffer.size()) {
        double_reduction_buffer.resize(n + n_double + 2);
        double_reduction_ptrs.resize(n + n_double + 2);
    }

    for (int i = 0; i < n; i++) {
        double_reduction_buffer[n_double + i] = d[i];
        double_reduction_ptrs[n_double + i] = d + i;
    }

    n_double += n;
}
/// Register n floats starting at d for the next deferred reduction
void hila_reduce_float_setup(float *d, int n) {

    // ensure there is space for the new values
    if (n + n_float > float_reduction_buffer.size()) {
        float_reduction_buffer.resize(n + n_float + 2);
        float_reduction_ptrs.resize(n + n_float + 2);
    }

    for (int i = 0; i < n; i++) {
        float_reduction_buffer[n_float + i] = d[i];
        float_reduction_ptrs[n_float + i] = d + i;
    }

    n_float += n;
}
/// Perform the deferred reductions and write the results back through the stored pointers
void hila_reduce_sums() {

    if (n_double > 0) {
        std::vector<double> work(n_double);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)double_reduction_buffer.data(), (void *)work.data(), n_double,
                          MPI_DOUBLE, MPI_SUM, lattice->mpi_comm_lat);
            for (int i = 0; i < n_double; i++)
                *(double_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)double_reduction_buffer.data(), work.data(), n_double, MPI_DOUBLE,
                       MPI_SUM, 0, lattice->mpi_comm_lat);
            // MPI_Reduce delivers the result only on the root rank
            if (hila::myrank() == 0)
                for (int i = 0; i < n_double; i++)
                    *(double_reduction_ptrs[i]) = work[i];
        }

        n_double = 0;

        reduction_timer.stop();
    }

    if (n_float > 0) {
        std::vector<float> work(n_float);

        reduction_timer.start();

        if (allreduce_on) {
            MPI_Allreduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                          MPI_SUM, lattice->mpi_comm_lat);
            for (int i = 0; i < n_float; i++)
                *(float_reduction_ptrs[i]) = work[i];

        } else {
            MPI_Reduce((void *)float_reduction_buffer.data(), work.data(), n_float, MPI_FLOAT,
                       MPI_SUM, 0, lattice->mpi_comm_lat);
            if (hila::myrank() == 0)
                for (int i = 0; i < n_float; i++)
                    *(float_reduction_ptrs[i]) = work[i];
        }

        n_float = 0;

        reduction_timer.stop();
    }
}
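/*
 * Usage sketch (illustrative, not part of this file's interface): code with a
 * rank-local partial sum registers the target address with the setup call above,
 * and hila_reduce_sums() later performs one collective for all registered values,
 * writing the results back through the stored pointers.  local_partial_sum() is
 * a placeholder here:
 *
 *     double s = local_partial_sum();
 *     hila_reduce_double_setup(&s, 1);  // buffer the value, remember &s
 *     // ... further reductions may be registered here ...
 *     hila_reduce_sums();               // one MPI call sums everything registered
 *     // s now holds the global sum (on all ranks if allreduce_on, else on rank 0)
 */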
/// Is allreduce on for the next reduction?
bool hila::get_allreduce() {
    return allreduce_on;
}
#include <sys/types.h>

/// Initialize MPI; the global lattice object must already exist
void hila::initialize_communications(int &argc, char ***argv) {
    // Init MPI only once
    if (!mpi_initialized) {

#ifndef OPENMP
        MPI_Init(&argc, argv);
#else
        // threaded build: MPI must provide at least funneled thread support
        int provided;
        MPI_Init_thread(&argc, argv, MPI_THREAD_FUNNELED, &provided);
        if (provided < MPI_THREAD_FUNNELED) {
            hila::out << "MPI could not provide MPI_THREAD_FUNNELED, exiting\n";
            MPI_Finalize();
            exit(1);
        }
#endif

        mpi_initialized = true;

        // the global lattice object exists, assign the MPI communicator there
        lattice.ptr()->mpi_comm_lat = MPI_COMM_WORLD;

        MPI_Comm_rank(lattice->mpi_comm_lat, &lattice.ptr()->mynode.rank);
        MPI_Comm_size(lattice->mpi_comm_lat, &lattice.ptr()->nodes.number);
    }
}
bool hila::is_comm_initialized(void) {
    return mpi_initialized;
}
/// Abort the MPI run with the given status
void hila::abort_communications(int status) {
    if (mpi_initialized) {
        mpi_initialized = false;
        MPI_Abort(lattice->mpi_comm_lat, status);
    }
}
/// Normal, controlled shutdown of MPI
void hila::finish_communications() {
    // turn off mpi -- this is needed to avoid mpi calls in destructors
    mpi_initialized = false;
    hila::about_to_finish = true;

    MPI_Finalize();
}
/// Broadcast a std::string from the given rank to all other ranks
void hila::broadcast(std::string &var, int rank) {
    if (hila::check_input)
        return;

    int size = var.size();
    // the length must be known everywhere before the receiving ranks can resize
    hila::broadcast(size, rank);
    var.resize(size, ' ');

    broadcast_timer.start();
    MPI_Bcast((void *)var.data(), size, MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
}
/// Broadcast a vector of strings from the given rank to all other ranks
void hila::broadcast(std::vector<std::string> &list, int rank) {
    if (hila::check_input)
        return;

    int size = list.size();
    hila::broadcast(size, rank);
    list.resize(size);

    for (auto &s : list) {
        hila::broadcast(s, rank);
    }
}
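/*
 * Usage sketch (illustrative): typically rank 0 determines a string, e.g. a file
 * name read from input, and the broadcasts above make it identical on all ranks.
 * The file name below is a placeholder:
 *
 *     std::string fname;
 *     if (hila::myrank() == 0)
 *         fname = "parameters.dat";
 *     hila::broadcast(fname);   // default source rank 0; fname now set everywhere
 */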
/// rank of this node
int hila::myrank() {
    static int node = 0;

    if (!mpi_initialized || hila::check_input)
        return node;

    MPI_Comm_rank(lattice->mpi_comm_lat, &node);
    return node;
}
/// how many nodes there are
int hila::number_of_nodes() {
    if (hila::check_input)
        return hila::check_with_nodes;

    int nodes;
    MPI_Comm_size(lattice->mpi_comm_lat, &nodes);
    return nodes;
}
/// synchronize MPI ranks and the gpu/thread execution
void hila::synchronize() {
    synchronize_timer.start();
    hila::synchronize_threads();
    MPI_Barrier(lattice->mpi_comm_lat);
    synchronize_timer.stop();
}
// MPI barrier alone, without device/thread synchronization.  The enclosing
// function signature is not visible in this excerpt; hila::barrier() is assumed.
void hila::barrier() {
    synchronize_timer.start();
    MPI_Barrier(lattice->mpi_comm_lat);
    synchronize_timer.stop();
}
// Tags for user-level MPI messages are cycled within this window
#define MSG_TAG_MIN 100
#define MSG_TAG_MAX (500)

int get_next_msg_tag() {
    static int tag = MSG_TAG_MIN;
    ++tag;
    if (tag > MSG_TAG_MAX)
        tag = MSG_TAG_MIN;
    return tag;
}
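/*
 * Usage sketch (illustrative): a fresh tag from the cycling window keeps
 * overlapping point-to-point exchanges separate.  The buffers, ranks and
 * request objects below are placeholders:
 *
 *     int tag = get_next_msg_tag();
 *     MPI_Isend(sendbuf, count, MPI_BYTE, to_rank, tag, lattice->mpi_comm_lat, &send_req);
 *     MPI_Irecv(recvbuf, count, MPI_BYTE, from_rank, tag, lattice->mpi_comm_lat, &recv_req);
 */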
/// Split MPI_COMM_WORLD into a separate communicator for each partition
void hila::split_into_partitions(int this_lattice) {
    if (hila::check_input)
        return;

    if (MPI_Comm_split(MPI_COMM_WORLD, this_lattice, 0, &(lattice.ptr()->mpi_comm_lat)) !=
        MPI_SUCCESS) {
        hila::out0 << "MPI_Comm_split() call failed!\n";
        hila::finishrun();
    }
    // update rank and node count for the new communicator
    MPI_Comm_rank(lattice->mpi_comm_lat, &lattice.ptr()->mynode.rank);
    MPI_Comm_size(lattice->mpi_comm_lat, &lattice.ptr()->nodes.number);
}
/// Synchronize all partitions with a barrier over MPI_COMM_WORLD
void hila::synchronize_partitions() {
    if (partitions.number() > 1)
        MPI_Barrier(MPI_COMM_WORLD);
}
// MPI datatype and sum operation for the ExtendedPrecision class
MPI_Datatype MPI_ExtendedPrecision_type;
MPI_Op MPI_ExtendedPrecision_sum_op;

void create_extended_MPI_type() {
    ExtendedPrecision dummy;
    int block_lengths[3] = {1, 1, 1};
    MPI_Aint displacements[3];
    MPI_Datatype types[3] = {MPI_DOUBLE, MPI_DOUBLE, MPI_DOUBLE};

    // compute member displacements relative to the struct base address
    MPI_Aint base;
    MPI_Get_address(&dummy, &base);
    MPI_Get_address(&dummy.value, &displacements[0]);
    MPI_Get_address(&dummy.compensation, &displacements[1]);
    MPI_Get_address(&dummy.compensation2, &displacements[2]);
    displacements[0] -= base;
    displacements[1] -= base;
    displacements[2] -= base;

    MPI_Type_create_struct(3, block_lengths, displacements, types, &MPI_ExtendedPrecision_type);
    MPI_Type_commit(&MPI_ExtendedPrecision_type);
}
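/*
 * Note: the displacements are taken with MPI_Get_address relative to the struct
 * base rather than assumed from a packed layout, so any padding the compiler
 * inserts is described correctly to MPI.  Conceptually the committed datatype
 * matches a layout of the form below (the ExtendedPrecision class itself is
 * defined elsewhere; this is only an assumed sketch of its data members):
 *
 *     struct ExtendedPrecision {
 *         double value;          // running sum
 *         double compensation;   // compensation terms of the extended-precision sum
 *         double compensation2;
 *     };
 */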
// User-defined MPI reduction: element-wise accumulation of ExtendedPrecision values
void extended_sum_op(void *in, void *inout, int *len, MPI_Datatype *datatype) {
    ExtendedPrecision *in_data = (ExtendedPrecision *)in;
    ExtendedPrecision *inout_data = (ExtendedPrecision *)inout;

    for (int i = 0; i < *len; i++) {
        inout_data[i] += in_data[i];
    }
}
void create_extended_MPI_operation() {
    MPI_Op_create(&extended_sum_op, true, &MPI_ExtendedPrecision_sum_op);
}
/// Reduce an array of ExtendedPrecision values over all nodes
void reduce_node_sum_extended(ExtendedPrecision *value, int send_count, bool allreduce) {
    if (hila::check_input)
        return;

    // create the MPI datatype and sum operation on first use
    static bool init_extended_type_and_operation = true;
    if (init_extended_type_and_operation) {
        create_extended_MPI_type();
        create_extended_MPI_operation();
        init_extended_type_and_operation = false;
    }

    std::vector<ExtendedPrecision> recv_data(send_count);
    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(), send_count,
                      MPI_ExtendedPrecision_type, MPI_ExtendedPrecision_sum_op,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(), send_count, MPI_ExtendedPrecision_type,
                   MPI_ExtendedPrecision_sum_op, 0, lattice->mpi_comm_lat);
        // the reduced values are available only on the root rank
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}
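/*
 * Usage sketch (illustrative): an ExtendedPrecision accumulator built up locally
 * can be combined across ranks with the compensated sum operation above.
 * ep_sum and local_values are placeholders:
 *
 *     ExtendedPrecision ep_sum;
 *     for (double v : local_values)
 *         ep_sum += v;                            // assumes operator+=(double) exists
 *     reduce_node_sum_extended(&ep_sum, 1, true); // allreduce: result on every rank
 */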