HILA
Loading...
Searching...
No Matches
backend_vector/field_storage_backend.h
1#ifndef VECTOR_BACKEND
2#define VECTOR_BACKEND
3
4#include "../defs.h"
5#include "../lattice.h"
6#include "../field_storage.h"
7#include "vector_types.h"
8#include "../coordinates.h"
9#include "defs.h"
10
11/// Replaces basetypes with vectors in a given templated class
12
13/// First base definition for replace_type, which recursively looks for the
14/// base type and replaces it in the end
15/// General template, never matched
16template <typename A, int vector_size, class Enable = void>
18
19/// A is a basic type, so just return the matching vector type
20template <typename A, int vector_size>
21struct vectorize_struct<A, vector_size, typename std::enable_if_t<hila::is_arithmetic<A>::value>> {
22 using type = typename hila::vector_base_type<A, vector_size>::type;
23};
24
25// B is a templated class, so construct a vectorized type
26template <template <typename B> class C, typename B, int vector_size>
27struct vectorize_struct<C<B>, vector_size> {
28 using vectorized_B = typename vectorize_struct<B, vector_size>::type;
29 using type = C<vectorized_B>;
30};
31
32template <template <int a, typename B> class C, int a, typename B, int vector_size>
33struct vectorize_struct<C<a, B>, vector_size> {
34 using vectorized_B = typename vectorize_struct<B, vector_size>::type;
35 using type = C<a, vectorized_B>;
36};
37
38template <template <int a, int b, typename B> class C, int a, int b, typename B, int vector_size>
39struct vectorize_struct<C<a, b, B>, vector_size> {
40 using vectorized_B = typename vectorize_struct<B, vector_size>::type;
41 using type = C<a, b, vectorized_B>;
42};
43
44/// Match coordinate vectors explicitly
45// template<>
46// struct vectorize_struct<CoordinateVector, 4> {
47// using type = std::array<Vec4i, NDIM>;
48// };
49
50// template<>
51// struct vectorize_struct<CoordinateVector, 8> {
52// using type = std::array<Vec8i, NDIM>;
53// };
54
55// template<>
56// struct vectorize_struct<CoordinateVector, 16> {
57// using type = std::array<Vec16i, NDIM>;
58// };
59
60/// Short version of mapping type to longest possible vector
61template <typename T>
62using vector_type = typename vectorize_struct<T, hila::vector_info<T>::vector_size>::type;
63
64template <typename T>
67 fieldbuf = (T *)memalloc(
69 ->field_alloc_size() *
70 sizeof(T));
71 } else {
72 fieldbuf = (T *)memalloc(sizeof(T) * lattice.field_alloc_size());
73 }
74}
75
76template <typename T>
78#pragma acc exit data delete (fieldbuf)
79 if (fieldbuf != nullptr)
80 free(fieldbuf);
81 fieldbuf = nullptr;
82}
83
84// get and set a full vector T
85
86template <typename T>
87template <typename vecT>
88inline vecT field_storage<T>::get_vector(const unsigned i) const {
89 using vectortype = typename hila::vector_info<T>::type;
90 using basetype = typename hila::vector_info<T>::base_type;
91 constexpr size_t elements = hila::vector_info<T>::elements;
92 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
93 // using vectorized_type = vector_type<T>;
94
95 static_assert(sizeof(vecT) == sizeof(T) * vector_size);
96 // assert (((int64_t)fieldbuf) % ((vector_size)*sizeof(basetype)) == 0);
97
98 vecT value;
99 basetype *vp = (basetype *)(fieldbuf) + i * elements * vector_size;
100 vectortype *valuep = (vectortype *)(&value);
101 for (unsigned e = 0; e < elements; e++) {
102 valuep[e].load_a(vp + e * vector_size);
103 }
104 return value;
105}
106
107// note: here i is the vector index
108
109template <typename T>
110template <typename vecT>
111inline void field_storage<T>::set_vector(const vecT &value, const unsigned i) {
112 using vectortype = typename hila::vector_info<T>::type;
113 using basetype = typename hila::vector_info<T>::base_type;
114 constexpr size_t elements = hila::vector_info<T>::elements;
115 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
116
117 static_assert(sizeof(vecT) == sizeof(T) * vector_size);
118
119 basetype *vp = (basetype *)(fieldbuf) + i * elements * vector_size;
120 vectortype *valuep = (vectortype *)(&value);
121 for (unsigned e = 0; e < elements; e++) {
122 valuep[e].store_a(vp + e * vector_size);
123 }
124}
125
126/// set_element scatters one individual T-element to vectorized store,
127/// using the "site" index idx.
128
129template <typename T>
130inline void field_storage<T>::set_element(const T &value, const unsigned idx) {
132 using basetype = typename hila::vector_info<T>::base_type;
133 constexpr size_t elements = hila::vector_info<T>::elements;
134 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
135
136 // "base" of the vector is (idx/vector_size)*elements; index in vector is idx %
137 // vector_size
138 basetype *RESTRICT b =
139 ((basetype *)(fieldbuf)) + (idx / vector_size) * vector_size * elements + idx % vector_size;
140 const basetype *RESTRICT vp = (basetype *)(&value);
141 for (unsigned e = 0; e < elements; e++) {
142 b[e * vector_size] = vp[e];
143 }
144}
145
146/// get_element gathers one T-element from vectorized store
147/// again, idx is the "site" index
148
149template <typename T>
150inline T field_storage<T>::get_element(const unsigned idx) const {
152 using basetype = typename hila::vector_info<T>::base_type;
153 constexpr size_t elements = hila::vector_info<T>::elements;
154 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
155
156 static_assert(sizeof(T) == sizeof(basetype) * elements);
157
158 T value;
159 // "base" of the vector is (idx/vector_size)*elements; index in vector is idx %
160 // vector_size
161 const basetype *RESTRICT b =
162 (basetype *)(fieldbuf) + (idx / vector_size) * vector_size * elements + idx % vector_size;
163 basetype *RESTRICT vp = (basetype *)(&value); // does going through address slow down?
164 for (unsigned e = 0; e < elements; e++) {
165 vp[e] = b[e * vector_size];
166 }
167 return value;
168}
169
170/// Fetch elements from the field to buffer using sites in index_list
171template <typename T>
172void field_storage<T>::gather_elements(T *RESTRICT buffer, const unsigned *RESTRICT index_list,
173 int n, const lattice_struct &lattice) const {
174
175 for (unsigned j = 0; j < n; j++) {
176 buffer[j] = get_element(index_list[j]);
177 }
178}
179
180#ifdef SPECIAL_BOUNDARY_CONDITIONS
181
182template <typename T>
184 const unsigned *RESTRICT index_list, int n,
185 const lattice_struct &lattice) const {
186 if constexpr (hila::has_unary_minus<T>::value) {
187 for (unsigned j = 0; j < n; j++) {
188 buffer[j] = -get_element(index_list[j]); /// requires unary - !!
189 }
190 } else {
191 // sizeof(T) here to prevent compile time evaluation of assert
192 assert(sizeof(T) < 1 && "Antiperiodic boundary conditions require that unary - "
193 "-operator is defined!");
194 }
195}
196
197#endif
198
199/// Vectorized implementation of setting elements
200template <typename T>
201void field_storage<T>::place_elements(T *RESTRICT buffer, const unsigned *RESTRICT index_list,
202 int n, const lattice_struct &lattice) {
203 for (unsigned j = 0; j < n; j++) {
204 set_element(buffer[j], index_list[j]);
205 }
206}
207
208template <typename T>
210 const lattice_struct &lattice,
211 bool antiperiodic) {
212
213#ifndef SPECIAL_BOUNDARY_CONDITIONS
214 assert(!antiperiodic && "antiperiodic only with SPECIAL_BOUNDARY_CONDITIONS");
215#endif
216
218
219 // do the boundary vectorized copy
220
221 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
222 constexpr size_t elements = hila::vector_info<T>::elements;
223 using vectortype = typename hila::vector_info<T>::type;
224 using basetype = typename hila::vector_info<T>::base_type;
225
226 // hila::out0 << "Vectorized boundary dir " << dir << " parity " << (int)par << "
227 // bc " << (int)antiperiodic << '\n';
228
229 const auto vector_lattice =
230 lattice.backend_lattice
231 ->template get_vectorized_lattice<hila::vector_info<T>::vector_size>();
232 // The halo copy and permutation is only necessary if vectorization
233 // splits the lattice in this Direction or local boundary is copied
234 if (vector_lattice->is_boundary_permutation[abs(dir)] ||
235 vector_lattice->only_local_boundary_copy[dir]) {
236
237 unsigned start = 0;
238 unsigned end = vector_lattice->n_halo_vectors[dir];
239 if (par == ODD)
240 start = vector_lattice->n_halo_vectors[dir] / 2;
241 if (par == EVEN)
242 end = vector_lattice->n_halo_vectors[dir] / 2;
243 unsigned offset = vector_lattice->halo_offset[dir];
244
245 /// Loop over the boundary sites - i is the vector index
246 /// location where the vectors are copied from are in halo_index
247
248 if (vector_lattice->is_boundary_permutation[abs(dir)]) {
249
250 // hila::out0 << "its permutation\n";
251 const int *RESTRICT perm = vector_lattice->boundary_permutation[dir];
252
253 basetype *fp = static_cast<basetype *>(static_cast<void *>(fieldbuf));
254 for (unsigned idx = start; idx < end; idx++) {
255 /// get ptrs to target and source vec elements
256 basetype *RESTRICT t = fp + (idx + offset) * (elements * vector_size);
257 basetype *RESTRICT s =
258 fp + vector_lattice->halo_index[dir][idx] * (elements * vector_size);
259
260 if (!antiperiodic) {
261 for (unsigned e = 0; e < elements * vector_size; e += vector_size)
262 for (unsigned i = 0; i < vector_size; i++)
263 t[e + i] = s[e + perm[i]];
264 } else {
265#ifdef SPECIAL_BOUNDARY_CONDITIONS
266 for (unsigned e = 0; e < elements * vector_size; e += vector_size)
267 for (unsigned i = 0; i < vector_size; i++)
268 t[e + i] = -s[e + perm[i]];
269#endif
270 }
271 }
272 } else {
273 // hila::out0 << "its not permutation, go for copy: bc " <<
274 // (int)antiperiodic << '\n';
275 if (!antiperiodic) {
276 // no boundary permutation, straight copy for all vectors
277 for (unsigned idx = start; idx < end; idx++) {
278 std::memcpy(fieldbuf + (idx + offset) * vector_size,
279 fieldbuf + vector_lattice->halo_index[dir][idx] * vector_size,
280 sizeof(T) * vector_size);
281 }
282 } else {
283#ifdef SPECIAL_BOUNDARY_CONDITIONS
284 basetype *fp = static_cast<basetype *>(static_cast<void *>(fieldbuf));
285 for (unsigned idx = start; idx < end; idx++) {
286 /// get ptrs to target and source vec elements
287 basetype *RESTRICT t = fp + (idx + offset) * (elements * vector_size);
288 basetype *RESTRICT s =
289 fp + vector_lattice->halo_index[dir][idx] * (elements * vector_size);
290 for (unsigned e = 0; e < elements * vector_size; e++)
291 t[e] = -s[e];
292 }
293#endif
294 }
295 }
296 }
297
298 } else {
299 // now the field is not vectorized. Std. access copy
300 // needed only if b.c. is not periodic
301
302#ifdef SPECIAL_BOUNDARY_CONDITIONS
303 if (antiperiodic) {
304 // need to copy or do something w. local boundary
305 unsigned n, start = 0;
306 if (par == ODD) {
307 n = lattice.special_boundaries[dir].n_odd;
308 start = lattice.special_boundaries[dir].n_even;
309 } else {
310 if (par == EVEN)
311 n = lattice.special_boundaries[dir].n_even;
312 else
313 n = lattice.special_boundaries[dir].n_total;
314 }
315 unsigned offset = lattice.special_boundaries[dir].offset + start;
316
317 gather_elements_negated(fieldbuf + offset,
318 lattice.special_boundaries[dir].move_index + start, n, lattice);
319 }
320#endif
321 }
322}
323
324// gather full vectors from fieldbuf to buffer, for communications
325template <typename T>
327 T *RESTRICT buffer, const lattice_struct::comm_node_struct &to_node, Parity par,
329 bool antiperiodic) const {
330
331 // Use sitelist in to_node, but use only every vector_size -index. These point to
332 // the beginning of the vector
333 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
334 constexpr size_t elements = hila::vector_info<T>::elements;
335 using basetype = typename hila::vector_info<T>::base_type;
336
337 int n;
338 const unsigned *index_list = to_node.get_sitelist(par, n);
339
340 assert(n % vector_size == 0);
341
342 if (!antiperiodic) {
343 for (unsigned i = 0; i < n; i += vector_size) {
344 std::memcpy(buffer + i, fieldbuf + index_list[i], sizeof(T) * vector_size);
345
346 // check that indices are really what they should -- REMOVE
347 for (unsigned j = 0; j < vector_size; j++)
348 assert(index_list[i] + j == index_list[i + j]);
349 }
350 } else {
351 // copy this as elements
352 for (unsigned i = 0; i < n; i += vector_size) {
353 basetype *RESTRICT t = static_cast<basetype *>(static_cast<void *>(buffer + i));
354 basetype *RESTRICT s =
355 static_cast<basetype *>(static_cast<void *>(fieldbuf + index_list[i]));
356 for (unsigned e = 0; e < elements * vector_size; e++)
357 t[e] = -s[e];
358 }
359 }
360}
361
362// Place the received MPI elements to halo (neighbour) buffer
363template <typename T>
365 const T *RESTRICT buffer, Direction d, Parity par,
367
368 constexpr size_t vector_size = hila::vector_info<T>::vector_size;
369 constexpr size_t elements = hila::vector_info<T>::elements;
370 using basetype = typename hila::vector_info<T>::base_type;
371
372 unsigned start = 0;
373 if (par == ODD)
374 start = vlat->recv_list_size[d] / 2;
375 unsigned n = vlat->recv_list_size[d];
376 if (par != ALL)
377 n /= 2;
378
379 // remove const -- the payload of the buffer remains const, but the halo bits are
380 // changed
381 T *targetbuf = const_cast<T *>(fieldbuf);
382
383 for (unsigned i = 0; i < n; i++) {
384 unsigned idx = vlat->recv_list[d][i + start];
385
386 basetype *RESTRICT t = ((basetype *)targetbuf) +
387 (idx / vector_size) * vector_size * elements + idx % vector_size;
388 const basetype *RESTRICT vp = (basetype *)(&buffer[i]);
389
390 for (unsigned e = 0; e < elements; e++) {
391 t[e * vector_size] = vp[e];
392 }
393 }
394}
395
396template <typename T>
397void field_storage<T>::free_mpi_buffer(T *buffer) {
398 std::free(buffer);
399}
400
401template <typename T>
403 return (T *)memalloc(n * sizeof(T));
404}
405
406#endif
The field_storage struct contains minimal information for using the field in a loop....
void place_elements(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice)
CUDA implementation of place_elements without CUDA aware MPI.
void gather_elements(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice) const
CUDA implementation of gather_elements without CUDA aware MPI.
void gather_elements_negated(T *__restrict__ buffer, const unsigned *__restrict__ index_list, int n, const lattice_struct &lattice) const
void set_local_boundary_elements(Direction dir, Parity par, const lattice_struct &lattice, bool antiperiodic)
Place boundary elements from local lattice (used in vectorized version)
auto get_element(const unsigned i, const lattice_struct &lattice) const
Conditionally return bool type false if type T doesn't have a unary - operator.
T abs(const Complex< T > &a)
Return absolute value of Complex number.
Definition cmplx.h:1187
Parity
Parity enum with values EVEN, ODD, ALL; refers to parity of the site. Parity of site (x,...
constexpr Parity EVEN
bit pattern: 001
constexpr Parity ODD
bit pattern: 010
Direction
Enumerator for direction that assigns integer to direction to be interpreted as unit vector.
Definition coordinates.h:34
constexpr Parity ALL
bit pattern: 011
This file defines all includes for HILA.
#define RESTRICT
Definition defs.h:51
vectorized_lattice_struct< vector_size > * get_vectorized_lattice()
Returns a vectorized lattice with given vector size.
is_vectorizable_type<T>::value is always false if the target is not vectorizable
Definition vector_types.h:8
Information necessary to communicate with a node.
Definition lattice.h:134
Replaces basetypes with vectors in a given templated class.