11namespace gpucub = cub;
15#include <hipcub/hipcub.hpp>
16namespace gpucub = hipcub;
// Selected sites; select_site stores each chosen SiteIndex here at position current_index.
46 std::vector<SiteIndex> sites;
// NOTE(review): presumably controls whether per-node results are joined automatically
// after a loop; only the declaration and a save in endloop_action() are visible here -- confirm.
50 bool auto_join =
true;
// Upper bound on the number of sites kept; starts at lattice.volume().
// endloop_action() clamps current_index to this value.
54 size_t nmax = lattice.volume();
// Position in `sites` where the next selected site is written.
56 size_t current_index = 0;
// `value` of the most recently stored site; reset to SIZE_MAX elsewhere in the class.
// A selection whose value equals this overwrites the previous entry instead.
57 size_t previous_site = SIZE_MAX;
// Number of selected sites that exceeded nmax (set to current_index - nmax in endloop_action()).
58 size_t n_overflow = 0;
63 explicit SiteSelect() {
66 nmax = lattice.volume();
68 previous_site = SIZE_MAX;
// Copy constructor: compiler-generated member-wise copy.
72 SiteSelect(
const SiteSelect &a) =
default;
// Destructor: compiler-generated.
75 ~SiteSelect() =
default;
86 if (s.value == previous_site) {
87 sites[current_index - 1] = s;
89 sites[current_index] = s;
90 previous_site = s.value;
95 SiteSelect &no_join() {
100 SiteSelect &max_size(
size_t _max) {
106 sites.resize(lattice.mynode.volume());
108 previous_site = SIZE_MAX;
116 previous_site = SIZE_MAX;
119 size_t size()
const {
124 return sites.at(i).coordinates();
127 const SiteIndex site_index(
size_t i)
const {
136 std::vector<std::nullptr_t> v;
137 join_data_vectors(v);
143 template <
typename T>
144 void join_data_vectors(std::vector<T> &dp) {
147 size_t nsend = nmax - sites.size();
148 hila::send_to(n, nsend);
151 std::vector<SiteIndex> s;
152 hila::receive_from(n, s);
155 n_overflow += s.back().value;
158 sites.reserve(sites.size() + s.size());
159 sites.insert(sites.end(), s.begin(), s.end());
161 if constexpr (!std::is_same<T, std::nullptr_t>::value) {
162 std::vector<T> recvdata;
163 hila::receive_from(n, recvdata);
164 dp.reserve(sites.size());
165 dp.insert(dp.end(), recvdata.begin(), recvdata.end());
170 hila::receive_from(n, over);
179 hila::receive_from(0, nsend);
181 if (nsend < sites.size()) {
182 n_overflow += sites.size() - nsend;
187 sites.push_back(n_overflow);
188 hila::send_to(0, sites);
190 if constexpr (!std::is_same<T, std::nullptr_t>::value) {
191 dp.resize(sites.size() - 1);
192 hila::send_to(0, dp);
197 hila::send_to(0, sites.size() + n_overflow);
208#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)
210 void endloop_action() {
211 if (current_index > nmax) {
213 n_overflow = current_index - nmax;
214 current_index = nmax;
216 sites.resize(current_index);
225 template <
typename T>
226 void copy_data_to_host_vector(std::vector<T> &dvec,
const char *flag,
const T *d_data) {
227 void *d_temp_storage =
nullptr;
228 size_t temp_storage_bytes = 0;
231 gpuMalloc(&out, lattice.mynode.volume() *
sizeof(T));
234 gpuMalloc(&num_selected_d,
sizeof(
int));
237 GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
238 out, num_selected_d, lattice.mynode.volume()));
240 gpuMalloc(&d_temp_storage, temp_storage_bytes);
242 GPU_CHECK(gpucub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, d_data, flag,
243 out, num_selected_d, lattice.mynode.volume()));
245 gpuFree(d_temp_storage);
248 gpuMemcpy(&num_selected, num_selected_d,
sizeof(
int), gpuMemcpyDeviceToHost);
249 gpuFree(num_selected_d);
251 if (num_selected > nmax) {
252 n_overflow = num_selected - nmax;
255 dvec.resize(num_selected);
257 gpuMemcpy(dvec.data(), out,
sizeof(T) * num_selected, gpuMemcpyDeviceToHost);
262 void endloop_action(
const char *flag,
const SiteIndex *d_sites) {
264 copy_data_to_host_vector(sites, flag, d_sites);
// Empty placeholder class; SiteValueSelect::select() returns a default-constructed instance of it.
273class site_value_select_type_ {};
276class SiteValueSelect :
public SiteSelect {
278 std::vector<T> values;
281 explicit SiteValueSelect() : SiteSelect() {
// Destructor: compiler-generated.
284 ~SiteValueSelect() =
default;
// Copy constructor: compiler-generated member-wise copy.
285 SiteValueSelect(
const SiteValueSelect &v) =
default;
289 values.resize(lattice.mynode.volume());
297 site_value_select_type_ select(
const X_index_type x,
const T &val) {
298 return site_value_select_type_();
301 void select_site_value(
const SiteIndex s,
const T &val) {
302 values[current_index] = val;
303 SiteSelect::select_site(s);
313 join_data_vectors(values);
317#if !(defined(CUDA) || defined(HIP)) || defined(HILAPP)
319 void endloop_action() {
320 bool save = auto_join;
322 SiteSelect::endloop_action();
323 values.resize(current_index);
331 void endloop_action(
const char *flag,
const SiteIndex *d_sites,
const T *d_values) {
332 copy_data_to_host_vector(sites, flag, d_sites);
333 copy_data_to_host_vector(values, flag, d_values);
347inline void dummy_func_2() {
Running index for locating sites on the lattice.
X-coordinate type - "dummy" class.
constexpr Parity ALL
bit pattern: 011
int myrank()
rank of this node
int number_of_nodes()
how many nodes there are
std::ostream out
this is our default output file stream