HILA
site_select.h
#ifndef SITE_SELECT_H_
#define SITE_SELECT_H_

// We insert the GPU code in the same file too
// hilapp should not read in .cuh, because it does not understand it

//#if (defined(CUDA) || defined(HIP)) && !defined(HILAPP)
#if !defined(HILAPP)
#if defined(CUDA)
#include <cub/cub.cuh>
namespace gpucub = cub;
#endif

#if defined(HIP)
#include <hipcub/hipcub.hpp>
namespace gpucub = hipcub;
#endif
#endif // HILAPP

#include "hila.h"

//////////////////////////////////////////////////////////////////////////////////
/// Site selection: special vector to accumulate chosen sites or sites + variable
///
///     SiteSelect s;
///     SiteValueSelect<T> sv;
///
/// To be used within site loops as
///     onsites(ALL) {
///         if (condition1)
///             s.select(X);
///         if (condition2)
///             sv.select(X, A[X]);
///     }
///

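// A fuller usage sketch (illustrative only, not part of this header): it assumes
// a Field<double> f exists, and reads out the result on rank 0, where the
// selected data resides after the (automatic) join:
//
//     SiteValueSelect<double> sv;
//     onsites(ALL) {
//         if (f[X] > 0)
//             sv.select(X, f[X]);
//     }
//     if (hila::myrank() == 0) {
//         for (size_t i = 0; i < sv.size(); i++)
//             hila::out << sv.coordinates(i) << "  " << sv.value(i) << '\n';
//     }
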
// just an empty class used to flag select operations
class site_select_type_ {};

class SiteSelect {

  protected:
    std::vector<SiteIndex> sites;


    /// status variables of reduction
    bool auto_join = true;
    bool joined = false;

    // max number of elements to collect - default volume
    size_t nmax = lattice.volume();

    size_t current_index = 0;
    size_t previous_site = SIZE_MAX;
    size_t n_overflow = 0;

  public:
    /// Default constructor: initialize to an empty selection
    /// auto_join = true by default
    explicit SiteSelect() {
        auto_join = true;
        joined = false;
        nmax = lattice.volume();
        current_index = 0;
        previous_site = SIZE_MAX;
        n_overflow = 0;
    }

    SiteSelect(const SiteSelect &a) = default;

    /// Default destructor
    ~SiteSelect() = default;

    /// Selection - use only inside loops

    site_select_type_ select(const X_index_type x) {
        return site_select_type_();
        // filled in by hilapp
    }

    // this makes sense only for CPU targets
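    // Note: if the same site is selected repeatedly in succession, only a single
    // copy of its index is kept.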
    void select_site(const SiteIndex s) {
        if (s.value == previous_site) {
            sites[current_index - 1] = s;
        } else {
            sites[current_index] = s;
            previous_site = s.value;
            current_index++;
        }
    }

    SiteSelect &no_join() {
        auto_join = false;
        return *this;
    }

    SiteSelect &max_size(size_t _max) {
        nmax = _max;
        return *this;
    }
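    // Illustrative note: both setters above return *this, so they can be chained
    // before the loop, e.g.  s.max_size(1000).no_join();
    // with no_join() the result must later be collected explicitly with s.join().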

    void setup() {
        sites.resize(lattice.mynode.volume());
        current_index = 0;
        previous_site = SIZE_MAX;
        n_overflow = 0;
        joined = false;
    }

    void clear() {
        sites.clear();
        current_index = 0;
        previous_site = SIZE_MAX;
    }

    size_t size() const {
        return sites.size();
    }

    const CoordinateVector coordinates(size_t i) const {
        return sites.at(i).coordinates();
    }

    const SiteIndex site_index(size_t i) const {
        return sites.at(i);
    }

    // Don't even implement assignments


    void join() {
        if (!joined) {
            std::vector<std::nullptr_t> v;
            join_data_vectors(v);
            joined = true;
        }
    }

    /// For delayed collection, joining gathers the selected sites (and values) to rank 0
    template <typename T>
    void join_data_vectors(std::vector<T> &dp) {
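        // Gathering protocol: rank 0 tells each rank in turn how many elements it
        // can still accept (nmax minus what it already holds).  If that is > 0,
        // the rank sends its (possibly truncated) site vector with its overflow
        // count appended as the last element, followed by the value vector if
        // there is one; otherwise the rank sends only its overflow count.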
        if (hila::myrank() == 0) {
            for (int n = 1; n < hila::number_of_nodes(); n++) {
                size_t nsend = nmax - sites.size();
                hila::send_to(n, nsend);

                if (nsend > 0) {
                    std::vector<SiteIndex> s;
                    hila::receive_from(n, s);

                    // last element of s contains the overflow number
                    n_overflow += s.back().value;
                    s.pop_back();

                    sites.reserve(sites.size() + s.size());
                    sites.insert(sites.end(), s.begin(), s.end());

                    if constexpr (!std::is_same<T, std::nullptr_t>::value) {
                        std::vector<T> recvdata;
                        hila::receive_from(n, recvdata);
                        dp.reserve(sites.size());
                        dp.insert(dp.end(), recvdata.begin(), recvdata.end());
                    }
                } else {
                    // get the overflow number in any case
                    size_t over;
                    hila::receive_from(n, over);
                    n_overflow += over;
                }
            }

        } else {
            // now rank != 0
            // wait for the number to be sent
            size_t nsend;
            hila::receive_from(0, nsend);
            if (nsend > 0) {
                if (nsend < sites.size()) {
                    n_overflow += sites.size() - nsend;
                    sites.resize(nsend);
                }

                // append overflow info
                sites.push_back(n_overflow);
                hila::send_to(0, sites);

                if constexpr (!std::is_same<T, std::nullptr_t>::value) {
                    dp.resize(sites.size() - 1);
                    hila::send_to(0, dp);
                }

            } else {
                // send overflow
                hila::send_to(0, sites.size() + n_overflow);
            }
            // empty data to release space
            clear();
        }
    }

    size_t overflow() {
        return n_overflow;
    }
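    // Note: after join(), overflow() on rank 0 gives the total number of selected
    // sites that did not fit within max_size().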

#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)

    void endloop_action() {
        if (current_index > nmax) {
            // too many elements, truncate
            n_overflow = current_index - nmax;
            current_index = nmax;
        }
        sites.resize(current_index);
        if (auto_join)
            join();
    }

#else

    // this is the GPU version of endloop_action
    // skip this for hilapp
    template <typename T>
    void copy_data_to_host_vector(std::vector<T> &dvec, const char *flag, const T *d_data) {
        void *d_temp_storage = nullptr;
        size_t temp_storage_bytes = 0;

        T *out;
        gpuMalloc(&out, lattice.mynode.volume() * sizeof(T));

        int *num_selected_d;
        gpuMalloc(&num_selected_d, sizeof(int));

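        // gpucub::DeviceSelect::Flagged follows the standard cub/hipcub pattern:
        // the first call with d_temp_storage == nullptr only computes the required
        // temporary storage size, the second call performs the actual flagged
        // selection.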
        GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
                                                out, num_selected_d, lattice.mynode.volume()));

        gpuMalloc(&d_temp_storage, temp_storage_bytes);

        GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
                                                out, num_selected_d, lattice.mynode.volume()));

        gpuFree(d_temp_storage);

        int num_selected;
        gpuMemcpy(&num_selected, num_selected_d, sizeof(int), gpuMemcpyDeviceToHost);
        gpuFree(num_selected_d);

        if (num_selected > nmax) {
            n_overflow = num_selected - nmax;
            num_selected = nmax;
        }
        dvec.resize(num_selected);

        gpuMemcpy(dvec.data(), out, sizeof(T) * num_selected, gpuMemcpyDeviceToHost);
        gpuFree(out);
    }

    // endloop_action for the GPU path
    void endloop_action(const char *flag, const SiteIndex *d_sites) {

        copy_data_to_host_vector(sites, flag, d_sites);

        if (auto_join)
            join();
    }

#endif // GPU
};

class site_value_select_type_ {};

template <typename T>
class SiteValueSelect : public SiteSelect {
  protected:
    std::vector<T> values;

  public:
    explicit SiteValueSelect() : SiteSelect() {
        values.clear();
    }
    ~SiteValueSelect() = default;
    SiteValueSelect(const SiteValueSelect &v) = default;

    void setup() {
        SiteSelect::setup();
        values.resize(lattice.mynode.volume());
    }

    void clear() {
        SiteSelect::clear();
        values.clear();
    }

    site_value_select_type_ select(const X_index_type x, const T &val) {
        return site_value_select_type_();
    }

    void select_site_value(const SiteIndex s, const T &val) {
        values[current_index] = val;
        SiteSelect::select_site(s);
    }


    T value(size_t i) {
        return values.at(i);
    }

    void join() {
        if (!joined)
            join_data_vectors(values);
        joined = true;
    }

#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)

    void endloop_action() {
        bool save = auto_join;
        auto_join = false;
        SiteSelect::endloop_action();
        values.resize(current_index);
        auto_join = save;
        if (auto_join)
            join();
    }

#else
    // skip this for hilapp
    void endloop_action(const char *flag, const SiteIndex *d_sites, const T *d_values) {
        copy_data_to_host_vector(sites, flag, d_sites);
        copy_data_to_host_vector(values, flag, d_values);

        if (auto_join)
            join();
    }

#endif // GPU
};


#ifdef HILAPP

// Make hilapp generate a __device__ version of the SiteIndex function - this is
// removed from the final program

inline void dummy_func_2() {
    onsites(ALL) {
        auto s = SiteIndex(X.coordinates());
    }
}

#endif


#endif