#include "plumbing/lattice.h"

class partitions_struct {
    unsigned _number, _mylattice;
    bool _sync;

  public:
    unsigned number() const {
        return _number;
    }
    void set_number(const unsigned u) {
        _number = u;
    }
    unsigned mylattice() const {
        return _mylattice;
    }
    void set_mylattice(const unsigned l) {
        _mylattice = l;
    }
    void set_sync(bool s) {
        _sync = s;
    }
};

extern partitions_struct partitions;

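// Usage sketch (illustrative, not part of the original header): the global 'partitions'
// object describes how the MPI job is split into independent lattice partitions. The
// semantics below are inferred from the accessor names.
//
//     if (partitions.number() > 1) {
//         std::cout << "This rank belongs to partition " << partitions.mylattice()
//                   << " of " << partitions.number() << '\n';
//     }
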
int get_next_msg_tag();

template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {
    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED;
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T;
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T;
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    }

    // arithmetic types not listed above are treated as a raw byte stream
    size = sizeof(T);
    return MPI_BYTE;
}

template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}

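// Usage sketch: map a HILA type to the MPI datatype of its arithmetic content, e.g. when
// writing a custom reduction. For composite types, hila::arithmetic_type<T> (for example,
// double for Complex<double>) selects the branch; 'size' receives the size of one element.
//
//     size_t elem_size;
//     MPI_Datatype dt = get_MPI_number_type<Complex<double>>(elem_size);
//     // dt == MPI_DOUBLE, elem_size == sizeof(double); the object contains
//     // sizeof(Complex<double>) / elem_size such elements.
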
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(T);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(T);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        static_assert(sizeof(T) > 0,
                      "get_MPI_complex_type<T>() called without T being a complex type");
    }
}

template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");
    if (hila::check_input)
        return var;

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}

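// Usage sketch: broadcast trivially copyable data from rank 0 (the default) to all ranks.
// The struct below is illustrative only.
//
//     struct run_params { double beta; int n_trajectories; };
//     run_params p;
//     if (hila::myrank() == 0) {
//         p.beta = 5.6;
//         p.n_trajectories = 1000;
//     }
//     broadcast(p);   // every rank now holds the rank-0 values
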
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");
    if (hila::check_input)
        return;

    broadcast_timer.start();

    size_t size = list.size();
    MPI_Bcast(&size, sizeof(size_t), MPI_BYTE, rank, lattice->mpi_comm_lat);

    // ranks other than the broadcaster must be resized to the incoming length
    if (hila::myrank() != rank) {
        list.resize(size);
    }

    MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice->mpi_comm_lat);

    broadcast_timer.stop();
}

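// Usage sketch: only the broadcasting rank needs to fill the vector; the other ranks are
// resized to the incoming length before the element data is broadcast.
//
//     std::vector<double> couplings;
//     if (hila::myrank() == 0)
//         couplings = {0.1, 0.2, 0.3};
//     broadcast(couplings);   // size 3 with identical contents on every rank
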
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");
    if (hila::check_input || n <= 0)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
}

template <typename T>
void broadcast(T *var, int rank = 0) {
    static_assert(sizeof(T) > 0 &&
                  "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                  "int size)' to broadcast an array");
}

template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {
    if (hila::check_input || n <= 0)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice->mpi_comm_lat);
    broadcast_timer.stop();
}

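// Usage sketch: raw arrays go through broadcast_array(); the pointer overload of broadcast()
// above exists only to steer callers here, since an element count cannot be deduced from a
// bare pointer.
//
//     double weights[8];
//     if (hila::myrank() == 0) {
//         // ... fill weights on the broadcasting rank ...
//     }
//     broadcast_array(weights, 8);
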
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);

template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {
    if (hila::check_input)
        return;

    // pack both values into one trivially copyable struct and broadcast it in one go
    struct {
        T tv;
        U uv;
    } s = {t, u};

    broadcast(s, rank);
    t = s.tv;
    u = s.uv;
}

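// Usage sketch: broadcast two variables of different type together.
//
//     int steps = 0;
//     double dt = 0;
//     if (hila::myrank() == 0) {
//         steps = 100;
//         dt = 0.01;
//     }
//     broadcast2(steps, dt);
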
template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice->mpi_comm_lat);
}

template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
             MPI_STATUS_IGNORE);
}

template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice->mpi_comm_lat);
    MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(),
             lattice->mpi_comm_lat);
}

template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);
    MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice->mpi_comm_lat,
             MPI_STATUS_IGNORE);
}

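// Usage sketch: blocking point-to-point exchange. The message tag is the sender's rank, so
// the receive_from() call must name the same rank that sent the data.
//
//     std::vector<int> payload;
//     if (hila::myrank() == 0) {
//         payload = {1, 2, 3};
//         send_to(1, payload);
//     } else if (hila::myrank() == 1) {
//         receive_from(0, payload);   // resized to 3 and filled
//     }
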
template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {
    if (hila::check_input || send_count == 0)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice->mpi_comm_lat);
        // only the root rank receives the result of MPI_Reduce
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}

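// Usage sketch: sum per-rank partial results in place. With allreduce = true (the default)
// every rank ends up with the total; with false only rank 0 does.
//
//     double partial[2] = {local_action, local_volume};   // illustrative names
//     reduce_node_sum(partial, 2);
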
template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {
    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;

    if (hila::check_input)
        return;

    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice->mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice->mpi_comm_lat);
        // only the root rank receives the result of MPI_Reduce
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}

template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}

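// Usage sketch: multiply per-rank factors; the scalar overload returns the reduced value.
//
//     double local_factor = 1.0;   // one factor per rank
//     double total = reduce_node_product(local_factor);
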
void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();
void reduce_node_sum_extended(ExtendedPrecision *value, int send_count, bool allreduce = true);

extern MPI_Datatype MPI_ExtendedPrecision_type;
extern MPI_Op MPI_ExtendedPrecision_sum_op;

void create_extended_MPI_type();
void create_extended_MPI_operation();
void extended_sum_op(void *in, void *inout, int *len, MPI_Datatype *datatype);

template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if constexpr (std::is_same<b_t, ExtendedPrecision>::value) {
        reduce_node_sum_extended((ExtendedPrecision *)value, sizeof(T) / sizeof(ExtendedPrecision),
                                 hila::get_allreduce());
    } else if constexpr (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if constexpr (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        // other arithmetic types are reduced immediately
        reduce_node_sum(value, 1, hila::get_allreduce());
    }
}

Brief reference from the generated documentation:

void broadcast_array(T *var, int n, int rank = 0)
    Broadcast for arrays where the size must be known and the same on all nodes.
int myrank()
    Rank of this node.
int number_of_nodes()
    How many nodes there are.
void set_allreduce(bool on = true)
    Set allreduce on (default) or off for the next reduction.
void reduce_node_sum(T *value, int send_count, bool allreduce = true)
    Reduce an array across nodes.
T broadcast(T &var, int rank = 0)
    Broadcast the value of var to all MPI ranks from rank (default = 0).
void broadcast2(T &t, U &u, int rank = 0)
    Broadcast two values together.