HILA
Loading...
Searching...
No Matches
site_select.h
1#ifndef HILA_SITE_SELECT_H_
2#define HILA_SITE_SELECT_H_
3
// The GPU code is kept in this same file: hilapp must not read .cuh files,
// because it cannot parse them.
6
7
8#include "hila.h"
9
10#include "gpucub.h"
11
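//////////////////////////////////////////////////////////////////////////////////
/// Site selection: special vectors that accumulate the chosen sites (SiteSelect),
/// or the chosen sites together with one value per site (SiteValueSelect<T>).
///
///     SiteSelect s;
///     SiteValueSelect<T> sv;
///
/// To be used within site loops as
///
///     onsites(ALL) {
///         if (condition1)
///             s.select(X);
///         if (condition2)
///             sv.select(X, A[X]);
///     }
///
/// After the loop the results are read with s.size(), s.coordinates(i),
/// s.site_index(i) and sv.value(i). By default the selections of all nodes are
/// joined to rank 0 at the end of the loop; see no_join() and max_size().
///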
28
// Just an empty tag class used to flag select operations: it is the return type of
// SiteSelect::select(), whose call inside a site loop is filled in by hilapp.
class site_select_type_ {};

32class SiteSelect {
33
34 protected:
35 std::vector<SiteIndex> sites;
36
37
38 /// status variables of reduction
39 bool auto_join = true;
40 bool joined = false;
41
42 // max number of elements to collect - default volume
43 size_t nmax = lattice.volume();
44
45 size_t current_index = 0;
46 size_t previous_site = SIZE_MAX;
47 size_t n_overflow = 0;
48
49 public:
50 /// Initialize to zero by default (? exception to other variables)
51 /// allreduce = true by default
52 explicit SiteSelect() {
53 auto_join = true;
54 joined = false;
55 nmax = lattice.volume();
56 current_index = 0;
57 previous_site = SIZE_MAX;
58 n_overflow = 0;
59 }
60
61 SiteSelect(const SiteSelect &a) = default;
62
63 /// Destructor cleans up communications if they are in progress
64 ~SiteSelect() = default;
65
66 /// Selection - use only inside loops
67
68 site_select_type_ select(const X_index_type x) {
69 return site_select_type_();
70 // filled in by hilapp
71 }
72
73 // this makes sense only for cpu targets
74 void select_site(const SiteIndex s) {
75 if (s.value == previous_site) {
76 sites[current_index - 1] = s;
77 } else {
78 sites[current_index] = s;
79 previous_site = s.value;
80 current_index++;
81 }
82 }
83
84 SiteSelect &no_join() {
85 auto_join = false;
86 return *this;
87 }
88
89 SiteSelect &max_size(size_t _max) {
90 nmax = _max;
91 return *this;
92 }
93
94 void setup() {
95 sites.resize(lattice->mynode.volume);
96 current_index = 0;
97 previous_site = SIZE_MAX;
98 n_overflow = 0;
99 joined = false;
100 }
101
102 void clear() {
103 sites.clear();
104 current_index = 0;
105 previous_site = SIZE_MAX;
106 }
107
108 size_t size() const {
109 return sites.size();
110 }
111
112 const CoordinateVector coordinates(size_t i) const {
113 return sites.at(i).coordinates();
114 }
115
116 const SiteIndex site_index(size_t i) const {
117 return sites.at(i);
118 }
119
120 // Don't even implement assignments
121
122 /// @brief std::move SiteIndex vector of selected sites, invalidating this variable
123 std::vector<SiteIndex> move_sites() {
124 return std::move(sites);
125 }
126
127 void join() {
128 if (!joined) {
129 std::vector<std::nullptr_t> v;
130 join_data_vectors(v);
131 joined = true;
132 }
133 }
134
135 /// For delayed collect, joining starts or completes the reduction operation
136 template <typename T>
137 void join_data_vectors(std::vector<T> &dp) {
138 if (hila::myrank() == 0) {
139 for (int n = 1; n < hila::number_of_nodes(); n++) {
140 size_t nsend = nmax - sites.size();
141 hila::send_to(n, nsend);
142
143 if (nsend > 0) {
144 std::vector<SiteIndex> s;
145 hila::receive_from(n, s);
146
147 // last element of s contains the overflow number
148 n_overflow += s.back().value;
149 s.pop_back();
150
151 sites.reserve(sites.size() + s.size());
152 sites.insert(sites.end(), s.begin(), s.end());
153
154 if constexpr (!std::is_same<T, std::nullptr_t>::value) {
155 std::vector<T> recvdata;
156 hila::receive_from(n, recvdata);
157 dp.reserve(sites.size());
158 dp.insert(dp.end(), recvdata.begin(), recvdata.end());
159 }
160 } else {
161 // get the overflow number in any case
162 size_t over;
163 hila::receive_from(n, over);
164 n_overflow += over;
165 }
166 }
167
168 } else {
169 // now rank /= 0
170 // wait for the number to be sent
171 size_t nsend;
172 hila::receive_from(0, nsend);
173 if (nsend > 0) {
174 if (nsend < sites.size()) {
175 n_overflow += sites.size() - nsend;
176 sites.resize(nsend);
177 }
178
179 // append overflow info
180 sites.push_back(n_overflow);
181 hila::send_to(0, sites);
182
183 if constexpr (!std::is_same<T, std::nullptr_t>::value) {
184 dp.resize(sites.size() - 1);
185 hila::send_to(0, dp);
186 }
187
188 } else {
189 // send overflow
190 hila::send_to(0, sites.size() + n_overflow);
191 }
192 // empty data to release space
193 clear();
194 }
195 }
196
197 size_t overflow() {
198 return n_overflow;
199 }
200
201#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)
202
203 void endloop_action() {
204 if (current_index > nmax) {
205 // too many elements, trunc
206 n_overflow = current_index - nmax;
207 current_index = nmax;
208 }
209 sites.resize(current_index);
210 if (auto_join)
211 join();
212 }
213
214#else
215
216 // this is GPU version of endloop_action
217 // skip this for hilapp
218 template <typename T>
219 void copy_data_to_host_vector(std::vector<T> &dvec, const char *flag, const T *d_data) {
220 void *d_temp_storage = nullptr;
221 size_t temp_storage_bytes = 0;
222
223 T *out;
224 gpuMalloc(&out, lattice->mynode.volume * sizeof(T));
225
226 int *num_selected_d;
227 gpuMalloc(&num_selected_d, sizeof(int));
228
229
230 GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
231 out, num_selected_d, lattice->mynode.volume));
232
233 gpuMalloc(&d_temp_storage, temp_storage_bytes);
234
235 GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
236 out, num_selected_d, lattice->mynode.volume));
237
238 gpuFree(d_temp_storage);
239
240 int num_selected;
241 gpuMemcpy(&num_selected, num_selected_d, sizeof(int), gpuMemcpyDeviceToHost);
242 gpuFree(num_selected_d);
243
244 if (num_selected > nmax) {
245 n_overflow = num_selected - nmax;
246 num_selected = nmax;
247 }
248 dvec.resize(num_selected);
249
250 gpuMemcpy(dvec.data(), out, sizeof(T) * num_selected, gpuMemcpyDeviceToHost);
251 gpuFree(out);
252 }
253
254 // endloop action for this
255 void endloop_action(const char *flag, const SiteIndex *d_sites) {
256
257 copy_data_to_host_vector(sites, flag, d_sites);
258
259 if (auto_join)
260 join();
261 }
262
263#endif // GPU
264};
265
266class site_value_select_type_ {};
267
268template <typename T>
269class SiteValueSelect : public SiteSelect {
270 protected:
271 std::vector<T> values;
272
273 public:
274 explicit SiteValueSelect() : SiteSelect() {
275 values.clear();
276 }
277 ~SiteValueSelect() = default;
278 SiteValueSelect(const SiteValueSelect &v) = default;
279
280 void setup() {
281 SiteSelect::setup();
282 values.resize(lattice->mynode.volume);
283 }
284
285 void clear() {
286 SiteSelect::clear();
287 values.clear();
288 }
289
290 site_value_select_type_ select(const X_index_type x, const T &val) {
291 return site_value_select_type_();
292 }
293
294 void select_site_value(const SiteIndex s, const T &val) {
295 values[current_index] = val;
296 SiteSelect::select_site(s);
297 }
298
299
300 T value(size_t i) {
301 return values.at(i);
302 }
303
304 void join() {
305 if (!joined)
306 join_data_vectors(values);
307 joined = true;
308 }
309
310#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)
311
312 void endloop_action() {
313 bool save = auto_join;
314 auto_join = false;
315 SiteSelect::endloop_action();
316 values.resize(current_index);
317 auto_join = save;
318 if (auto_join)
319 join();
320 }
321
322#else
323 // skip this for hilapp
324 void endloop_action(const char *flag, const SiteIndex *d_sites, const T *d_values) {
325 copy_data_to_host_vector(sites, flag, d_sites);
326 copy_data_to_host_vector(values, flag, d_values);
327
328 if (auto_join)
329 join();
330 }
331
332#endif // GPU
333};
334
335
#ifdef HILAPP

// Make hilapp generate __device__ versions of SiteIndex function - this is removed in final program
//
// Only compiled when HILAPP is defined, and not meant to be called. The site loop
// below calls the SiteIndex constructor from X.coordinates() so that hilapp sees it
// inside a loop body and emits the device version. The unused local `s` is intentional.
inline void dummy_func_2() {
    onsites(ALL) {
        auto s = SiteIndex(X.coordinates());
    }
}

#endif
347
348
349#endif
int64_t volume() const
lattice.volume() returns lattice volume Can be used inside onsites()-loops
Definition lattice.h:424
Running index for locating sites on the lattice.
Definition site_index.h:17
X-coordinate type - "dummy" class.
constexpr Parity ALL
bit pattern: 011
int myrank()
rank of this node
Definition com_mpi.cpp:237
int number_of_nodes()
how many nodes there are
Definition com_mpi.cpp:248
std::ostream out
this is our default output file stream