#include "plumbing/lattice.h"

class partitions_struct {
  public:
    unsigned _number, _mylattice; // number of partitions, and the partition this rank belongs to

    unsigned mylattice() {
        return _mylattice;
    }
};

extern partitions_struct partitions;

/// Get a tag for the next MPI message
int get_next_msg_tag();
/// Return the MPI datatype matching the arithmetic base type of T, and set size to the
/// size of that base type. With with_int == true, return the corresponding value+index
/// pair type (as used in MAXLOC/MINLOC-style reductions).
template <typename T>
MPI_Datatype get_MPI_number_type(size_t &size, bool with_int = false) {

    if (std::is_same<hila::arithmetic_type<T>, int>::value) {
        size = sizeof(int);
        return with_int ? MPI_2INT : MPI_INT;
    } else if (std::is_same<hila::arithmetic_type<T>, unsigned>::value) {
        size = sizeof(unsigned);
        return with_int ? MPI_2INT : MPI_UNSIGNED;
    } else if (std::is_same<hila::arithmetic_type<T>, long>::value) {
        size = sizeof(long);
        return with_int ? MPI_LONG_INT : MPI_LONG;
    } else if (std::is_same<hila::arithmetic_type<T>, int64_t>::value) {
        size = sizeof(int64_t);
        return with_int ? MPI_LONG_INT : MPI_INT64_T;
    } else if (std::is_same<hila::arithmetic_type<T>, uint64_t>::value) {
        size = sizeof(uint64_t);
        return with_int ? MPI_LONG_INT : MPI_UINT64_T;
    } else if (std::is_same<hila::arithmetic_type<T>, float>::value) {
        size = sizeof(float);
        return with_int ? MPI_FLOAT_INT : MPI_FLOAT;
    } else if (std::is_same<hila::arithmetic_type<T>, double>::value) {
        size = sizeof(double);
        return with_int ? MPI_DOUBLE_INT : MPI_DOUBLE;
    } else if (std::is_same<hila::arithmetic_type<T>, long double>::value) {
        size = sizeof(long double);
        return with_int ? MPI_LONG_DOUBLE_INT : MPI_LONG_DOUBLE;
    } else {
        size = 1;
        return MPI_BYTE; // fallback for unrecognized element types
    }
}
/// Convenience overload when the element size is not needed.
template <typename T>
MPI_Datatype get_MPI_number_type() {
    size_t s;
    return get_MPI_number_type<T>(s);
}
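// Usage sketch (illustrative only; the function name example_query_mpi_type is hypothetical):
// querying the MPI datatype and element size that reductions of a given element type would
// use. Composite types map through hila::arithmetic_type to their scalar base type.
inline void example_query_mpi_type() {
    size_t elem_size = 0;
    MPI_Datatype dt = get_MPI_number_type<double>(elem_size);        // MPI_DOUBLE, elem_size == sizeof(double)
    MPI_Datatype dti = get_MPI_number_type<double>(elem_size, true); // MPI_DOUBLE_INT, for value+index reductions
    (void)dt;
    (void)dti;
}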
/// Return the MPI complex datatype matching T, and set siz to sizeof(T).
template <typename T>
MPI_Datatype get_MPI_complex_type(size_t &siz) {
    if constexpr (std::is_same<T, Complex<double>>::value) {
        siz = sizeof(T);
        return MPI_C_DOUBLE_COMPLEX;
    } else if constexpr (std::is_same<T, Complex<float>>::value) {
        siz = sizeof(T);
        return MPI_C_FLOAT_COMPLEX;
    } else {
        static_assert(sizeof(T) > 0,
                      "get_MPI_complex_type<T>() called without T being a complex type");
        return MPI_BYTE;
    }
}
/// Broadcast the value of var to all MPI ranks from rank (default=0). Must be called by all ranks.
template <typename T>
T broadcast(T &var, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(var) must use trivial type");

    if (hila::check_input)
        return var;

    broadcast_timer.start();
    MPI_Bcast(&var, sizeof(T), MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
    return var;
}
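// Usage sketch (illustrative, not part of the interface; the function name
// example_broadcast_seed is hypothetical): a value decided on rank 0 is made identical on
// every rank. Assumes the framework (and MPI) has already been initialized.
inline int example_broadcast_seed() {
    int seed = 0;
    if (hila::myrank() == 0)
        seed = 12345;   // value known only on rank 0
    broadcast(seed);    // collective call: every rank must reach this line
    return seed;        // now 12345 on all ranks
}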
/// Broadcast for std::vector: the size is sent first, and receiving ranks resize to match.
template <typename T>
void broadcast(std::vector<T> &list, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(std::vector<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();

    int size = list.size();
    MPI_Bcast(&size, sizeof(int), MPI_BYTE, rank, lattice.mpi_comm_lat);
    if (hila::myrank() != rank)
        list.resize(size);

    MPI_Bcast((void *)list.data(), sizeof(T) * size, MPI_BYTE, rank, lattice.mpi_comm_lat);

    broadcast_timer.stop();
}
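// Usage sketch (illustrative; the function name example_broadcast_vector is hypothetical):
// only the sending rank needs to know the vector length, receiving ranks are resized to match.
inline std::vector<double> example_broadcast_vector() {
    std::vector<double> v;
    if (hila::myrank() == 0)
        v = {1.0, 2.0, 3.0}; // contents exist only on rank 0 before the call
    broadcast(v);            // collective: resizes and fills v on all other ranks
    return v;
}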
/// Broadcast for std::array: the size is a compile-time constant, so only the data is sent.
template <typename T, int n>
void broadcast(std::array<T, n> &arr, int rank = 0) {
    static_assert(std::is_trivial<T>::value, "broadcast(std::array<T>) must have trivial T");

    if (hila::check_input)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)arr.data(), sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}
/// Pointers cannot be passed to broadcast(); use broadcast_array(T *arr, int size) instead.
template <typename T>
void broadcast(T *var, int rank = 0) {
    static_assert(sizeof(T) > 0 &&
                  "Do not use pointers to broadcast()-function. Use 'broadcast_array(T* arr, "
                  "int size)' to broadcast an array");
}
/// Broadcast for arrays where the size must be known and the same on all nodes.
template <typename T>
void broadcast_array(T *var, int n, int rank = 0) {
    if (hila::check_input)
        return;

    broadcast_timer.start();
    MPI_Bcast((void *)var, sizeof(T) * n, MPI_BYTE, rank, lattice.mpi_comm_lat);
    broadcast_timer.stop();
}
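// Usage sketch (illustrative; the function name example_broadcast_array is hypothetical):
// broadcasting a plain C array whose length is known and identical on every rank.
inline void example_broadcast_array() {
    double coeff[4] = {0, 0, 0, 0};
    if (hila::myrank() == 0) {
        coeff[0] = 1.5;
        coeff[1] = -0.5;
    }
    broadcast_array(coeff, 4); // all ranks now hold rank 0's values
}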
void broadcast(std::string &r, int rank = 0);
void broadcast(std::vector<std::string> &l, int rank = 0);
/// Broadcast two values (possibly of different types) with a single call.
template <typename T, typename U>
void broadcast2(T &t, U &u, int rank = 0) {
    if (hila::check_input)
        return;

    struct { T tv; U uv; } s = {t, u};
    broadcast(s, rank); // one collective carries both values
    t = s.tv;
    u = s.uv;
}
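// Usage sketch (illustrative; the function name example_broadcast_two is hypothetical):
// sending two related values of different types in one collective call instead of two
// separate broadcasts.
inline void example_broadcast_two(double &step, int &iterations) {
    if (hila::myrank() == 0) {
        step = 0.01;
        iterations = 1000;
    }
    broadcast2(step, iterations); // both values arrive together on every rank
}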
/// Send a trivially copyable value to another rank (message tag = sender's rank).
template <typename T>
void send_to(int to_rank, const T &data) {
    if (hila::check_input)
        return;

    MPI_Send(&data, sizeof(T), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
}
/// Receive a value sent with send_to() from from_rank.
template <typename T>
void receive_from(int from_rank, T &data) {
    if (hila::check_input)
        return;

    MPI_Recv(&data, sizeof(T), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
}
/// Send a std::vector to another rank: first the size, then the data.
template <typename T>
void send_to(int to_rank, const std::vector<T> &data) {
    if (hila::check_input)
        return;

    size_t s = data.size();
    MPI_Send(&s, sizeof(size_t), MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
    MPI_Send(data.data(), sizeof(T) * s, MPI_BYTE, to_rank, hila::myrank(), lattice.mpi_comm_lat);
}
/// Receive a std::vector sent with send_to(): the vector is resized to the incoming size.
template <typename T>
void receive_from(int from_rank, std::vector<T> &data) {
    if (hila::check_input)
        return;

    size_t s;
    MPI_Recv(&s, sizeof(size_t), MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
    data.resize(s);
    MPI_Recv(data.data(), sizeof(T) * s, MPI_BYTE, from_rank, from_rank, lattice.mpi_comm_lat,
             MPI_STATUS_IGNORE);
}
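// Usage sketch (illustrative; the function name example_point_to_point is hypothetical):
// point-to-point transfer of a vector from rank 0 to rank 1, requiring at least 2 ranks.
// send_to() and receive_from() must be issued as a matching pair; the sender's rank is
// used as the message tag on both sides.
inline void example_point_to_point() {
    std::vector<int> payload;
    if (hila::myrank() == 0) {
        payload = {1, 2, 3, 4};
        send_to(1, payload);      // blocking send to rank 1
    } else if (hila::myrank() == 1) {
        receive_from(0, payload); // resized and filled with rank 0's data
    }
}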
/// Reduce an array of type T across nodes with MPI_SUM. With allreduce (default) every
/// rank receives the result, otherwise only rank 0 does.
template <typename T>
void reduce_node_sum(T *value, int send_count, bool allreduce = true) {
    if (hila::check_input)
        return;

    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;
    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)value, (void *)recv_data.data(),
                      send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            value[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)value, (void *)recv_data.data(),
                   send_count * (sizeof(T) / sizeof(hila::arithmetic_type<T>)), dtype, MPI_SUM, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                value[i] = recv_data[i];
    }
    reduction_timer.stop();
}
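// Usage sketch (illustrative; the function name example_reduce_sum is hypothetical):
// summing a small per-rank array over all ranks. With allreduce (the default) every rank
// ends up with the totals; with allreduce=false only rank 0 holds valid results afterwards.
inline void example_reduce_sum() {
    double local[2] = {1.0 * hila::myrank(), 2.0};
    reduce_node_sum(local, 2, true); // in place: local[] now holds the sums over all ranks
}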
/// Reduce an array of type T across nodes with MPI_PROD. With allreduce (default) every
/// rank receives the result, otherwise only rank 0 does.
template <typename T>
void reduce_node_product(T *send_data, int send_count, bool allreduce = true) {
    std::vector<T> recv_data(send_count);
    MPI_Datatype dtype;

    if (hila::check_input)
        return;

    dtype = get_MPI_number_type<T>();

    reduction_timer.start();
    if (allreduce) {
        MPI_Allreduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD,
                      lattice.mpi_comm_lat);
        for (int i = 0; i < send_count; i++)
            send_data[i] = recv_data[i];
    } else {
        MPI_Reduce((void *)send_data, (void *)recv_data.data(), send_count, dtype, MPI_PROD, 0,
                   lattice.mpi_comm_lat);
        if (hila::myrank() == 0)
            for (int i = 0; i < send_count; i++)
                send_data[i] = recv_data[i];
    }
    reduction_timer.stop();
}
/// Scalar version of reduce_node_product().
template <typename T>
T reduce_node_product(T &var, bool allreduce = true) {
    reduce_node_product(&var, 1, allreduce);
    return var;
}
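// Usage sketch (illustrative; the function name example_reduce_product is hypothetical):
// multiplying one value per rank together, e.g. accumulating a product of local weights.
inline double example_reduce_product() {
    double local_factor = 1.0 + 0.1 * hila::myrank();
    return reduce_node_product(local_factor); // every rank receives the full product (allreduce)
}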
void hila_reduce_double_setup(double *d, int n);
void hila_reduce_float_setup(float *d, int n);
void hila_reduce_sums();
/// Set up a sum reduction of the value pointed to by value, dispatching on the arithmetic
/// base type of T.
template <typename T>
void hila_reduce_sum_setup(T *value) {

    using b_t = hila::arithmetic_type<T>;
    if (std::is_same<b_t, double>::value) {
        hila_reduce_double_setup((double *)value, sizeof(T) / sizeof(double));
    } else if (std::is_same<b_t, float>::value) {
        hila_reduce_float_setup((float *)value, sizeof(T) / sizeof(float));
    } else {
        // other arithmetic types are reduced immediately
        reduce_node_sum(value, 1, hila::get_allreduce());
    }
}
// Related interface declared elsewhere in HILA:
//   int myrank()                      - rank of this node
//   int number_of_nodes()             - how many nodes there are
//   void set_allreduce(bool on=true)  - set allreduce on (default) or off on the next reduction