6#include "../field_storage.h"
7#include "vector_types.h"
8#include "../coordinates.h"
16template <
typename A,
int vector_size,
class Enable =
void>
20template <
typename A,
int vector_size>
21struct vectorize_struct<A, vector_size, typename std::enable_if_t<hila::is_arithmetic<A>::value>> {
22 using type =
typename hila::vector_base_type<A, vector_size>::type;
26template <
template <
typename B>
class C,
typename B,
int vector_size>
29 using type = C<vectorized_B>;
32template <
template <
int a,
typename B>
class C,
int a,
typename B,
int vector_size>
35 using type = C<a, vectorized_B>;
38template <
template <
int a,
int b,
typename B>
class C,
int a,
int b,
typename B,
int vector_size>
41 using type = C<a, b, vectorized_B>;
67 fieldbuf = (T *)memalloc(
69 ->field_alloc_size() *
72 fieldbuf = (T *)memalloc(
sizeof(T) * lattice.field_alloc_size());
78#pragma acc exit data delete (fieldbuf)
79 if (fieldbuf !=
nullptr)
87template <
typename vecT>
89 using vectortype =
typename hila::vector_info<T>::type;
90 using basetype =
typename hila::vector_info<T>::base_type;
95 static_assert(
sizeof(vecT) ==
sizeof(T) * vector_size);
99 basetype *vp = (basetype *)(fieldbuf) + i * elements * vector_size;
100 vectortype *valuep = (vectortype *)(&value);
101 for (
unsigned e = 0; e < elements; e++) {
102 valuep[e].load_a(vp + e * vector_size);
110template <
typename vecT>
112 using vectortype =
typename hila::vector_info<T>::type;
113 using basetype =
typename hila::vector_info<T>::base_type;
117 static_assert(
sizeof(vecT) ==
sizeof(T) * vector_size);
119 basetype *vp = (basetype *)(fieldbuf) + i * elements * vector_size;
120 vectortype *valuep = (vectortype *)(&value);
121 for (
unsigned e = 0; e < elements; e++) {
122 valuep[e].store_a(vp + e * vector_size);
132 using basetype =
typename hila::vector_info<T>::base_type;
139 ((basetype *)(fieldbuf)) + (idx / vector_size) * vector_size * elements + idx % vector_size;
140 const basetype *
RESTRICT vp = (basetype *)(&value);
141 for (
unsigned e = 0; e < elements; e++) {
142 b[e * vector_size] = vp[e];
152 using basetype =
typename hila::vector_info<T>::base_type;
156 static_assert(
sizeof(T) ==
sizeof(basetype) * elements);
162 (basetype *)(fieldbuf) + (idx / vector_size) * vector_size * elements + idx % vector_size;
163 basetype *
RESTRICT vp = (basetype *)(&value);
164 for (
unsigned e = 0; e < elements; e++) {
165 vp[e] = b[e * vector_size];
175 for (
unsigned j = 0; j < n; j++) {
176 buffer[j] = get_element(index_list[j]);
180#ifdef SPECIAL_BOUNDARY_CONDITIONS
184 const unsigned *
RESTRICT index_list,
int n,
187 for (
unsigned j = 0; j < n; j++) {
188 buffer[j] = -get_element(index_list[j]);
192 assert(
sizeof(T) < 1 &&
"Antiperiodic boundary conditions require that unary - "
193 "-operator is defined!");
203 for (
unsigned j = 0; j < n; j++) {
204 set_element(buffer[j], index_list[j]);
213#ifndef SPECIAL_BOUNDARY_CONDITIONS
214 assert(!antiperiodic &&
"antiperiodic only with SPECIAL_BOUNDARY_CONDITIONS");
223 using vectortype =
typename hila::vector_info<T>::type;
224 using basetype =
typename hila::vector_info<T>::base_type;
229 const auto vector_lattice =
230 lattice.backend_lattice
231 ->template get_vectorized_lattice<hila::vector_info<T>::vector_size>();
234 if (vector_lattice->is_boundary_permutation[
abs(dir)] ||
235 vector_lattice->only_local_boundary_copy[dir]) {
238 unsigned end = vector_lattice->n_halo_vectors[dir];
240 start = vector_lattice->n_halo_vectors[dir] / 2;
242 end = vector_lattice->n_halo_vectors[dir] / 2;
243 unsigned offset = vector_lattice->halo_offset[dir];
248 if (vector_lattice->is_boundary_permutation[
abs(dir)]) {
251 const int *
RESTRICT perm = vector_lattice->boundary_permutation[dir];
253 basetype *fp =
static_cast<basetype *
>(
static_cast<void *
>(fieldbuf));
254 for (
unsigned idx = start; idx < end; idx++) {
256 basetype *
RESTRICT t = fp + (idx + offset) * (elements * vector_size);
258 fp + vector_lattice->halo_index[dir][idx] * (elements * vector_size);
261 for (
unsigned e = 0; e < elements * vector_size; e += vector_size)
262 for (
unsigned i = 0; i < vector_size; i++)
263 t[e + i] = s[e + perm[i]];
265#ifdef SPECIAL_BOUNDARY_CONDITIONS
266 for (
unsigned e = 0; e < elements * vector_size; e += vector_size)
267 for (
unsigned i = 0; i < vector_size; i++)
268 t[e + i] = -s[e + perm[i]];
277 for (
unsigned idx = start; idx < end; idx++) {
278 std::memcpy(fieldbuf + (idx + offset) * vector_size,
279 fieldbuf + vector_lattice->halo_index[dir][idx] * vector_size,
280 sizeof(T) * vector_size);
283#ifdef SPECIAL_BOUNDARY_CONDITIONS
284 basetype *fp =
static_cast<basetype *
>(
static_cast<void *
>(fieldbuf));
285 for (
unsigned idx = start; idx < end; idx++) {
287 basetype *
RESTRICT t = fp + (idx + offset) * (elements * vector_size);
289 fp + vector_lattice->halo_index[dir][idx] * (elements * vector_size);
290 for (
unsigned e = 0; e < elements * vector_size; e++)
302#ifdef SPECIAL_BOUNDARY_CONDITIONS
305 unsigned n, start = 0;
307 n = lattice.special_boundaries[dir].n_odd;
308 start = lattice.special_boundaries[dir].n_even;
311 n = lattice.special_boundaries[dir].n_even;
313 n = lattice.special_boundaries[dir].n_total;
315 unsigned offset = lattice.special_boundaries[dir].offset + start;
317 gather_elements_negated(fieldbuf + offset,
318 lattice.special_boundaries[dir].move_index + start, n, lattice);
329 bool antiperiodic)
const {
335 using basetype =
typename hila::vector_info<T>::base_type;
338 const unsigned *index_list = to_node.get_sitelist(par, n);
340 assert(n % vector_size == 0);
343 for (
unsigned i = 0; i < n; i += vector_size) {
344 std::memcpy(buffer + i, fieldbuf + index_list[i],
sizeof(T) * vector_size);
347 for (
unsigned j = 0; j < vector_size; j++)
348 assert(index_list[i] + j == index_list[i + j]);
352 for (
unsigned i = 0; i < n; i += vector_size) {
353 basetype *
RESTRICT t =
static_cast<basetype *
>(
static_cast<void *
>(buffer + i));
355 static_cast<basetype *
>(
static_cast<void *
>(fieldbuf + index_list[i]));
356 for (
unsigned e = 0; e < elements * vector_size; e++)
370 using basetype =
typename hila::vector_info<T>::base_type;
374 start = vlat->recv_list_size[d] / 2;
375 unsigned n = vlat->recv_list_size[d];
381 T *targetbuf =
const_cast<T *
>(fieldbuf);
383 for (
unsigned i = 0; i < n; i++) {
384 unsigned idx = vlat->recv_list[d][i + start];
386 basetype *
RESTRICT t = ((basetype *)targetbuf) +
387 (idx / vector_size) * vector_size * elements + idx % vector_size;
388 const basetype *
RESTRICT vp = (basetype *)(&buffer[i]);
390 for (
unsigned e = 0; e < elements; e++) {
391 t[e * vector_size] = vp[e];
403 return (T *)memalloc(n *
sizeof(T));
The field_storage struct contains minimal information for using the field in a loop....
void place_elements(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice)
CUDA implementation of place_elements without CUDA aware MPI.
void gather_elements(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice) const
CUDA implementation of gather_elements without CUDA aware MPI.
void gather_elements_negated(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice) const
CUDA implementation of gather_elements_negated without CUDA aware MPI.
void set_local_boundary_elements(Direction dir, Parity par, const lattice_struct &lattice, bool antiperiodic)
Place boundary elements from local lattice (used in vectorized version)
auto get_element(const unsigned i, const lattice_struct &lattice) const
Conditionally return bool type false if type T doesn't have a unary - operator.
T abs(const Complex< T > &a)
Return absolute value of Complex number.
Parity
Parity enum with values EVEN, ODD, ALL; refers to parity of the site. Parity of site (x,...
constexpr Parity EVEN
bit pattern: 001
constexpr Parity ODD
bit pattern: 010
Direction
Enumerator for direction that assigns integer to direction to be interpreted as unit vector.
constexpr Parity ALL
bit pattern: 011
This file defines all includes for HILA.
is_vectorizable_type<T>::value is always false if the target is not vectorizable
Information necessary to communicate with a node.
Replaces basetypes with vectors in a given templated class.