HILA
Loading...
Searching...
No Matches
scalar.h
1#ifndef SCALAR_H
2#define SCALAR_H
3
4///////////////////////////////////////////////////////////
5/// A scalar type. Templated, so that hilapp can convert
6/// to a vector type
7///////////////////////////////////////////////////////////
8template <typename T = double> struct scalar {
9 using base_type = hila::arithmetic_type<T>;
10 using argument_type = T;
11
12 T value;
13
14 // assignment is automatically OK, by c-standard
15 // scalar operator=(scalar rhs) {
16 // value = rhs.value;
17 // return *this;
18 // }
19 scalar<T>() = default;
20
21 scalar<T>(const scalar<T> &a) = default;
22
23 // constructor from single scalar value
24 template <typename scalar_t,
25 std::enable_if_t<std::is_arithmetic<scalar_t>::value, int> = 0>
26 constexpr scalar<T>(const scalar_t val) : value(static_cast<T>(val)) {}
27
28 ~scalar<T>() = default;
29
30 // automatic casting from scalar<T> -> scalar<A>
31 // TODO: ensure this works if A is vector type!
32 template <typename A> operator scalar<A>() const {
33 return scalar<A>({static_cast<A>(value)});
34 }
35
36 // Conversion to T
37 operator T() { return value; }
38
39 template <typename scalar_t,
40 std::enable_if_t<std::is_arithmetic<scalar_t>::value, int> = 0>
41 scalar<T> &operator=(scalar_t s) {
42 value = static_cast<T>(s);
43 return *this;
44 }
45
46 T real() const { return value; }
47 T imag() const { return 0; }
48
49 T squarenorm() const { return value * value; }
50 // TODO: make this work for vector type! Not double
51
52 // currently this gives a compilation error
53 double abs() const { return sqrt(static_cast<double>(squarenorm())); }
54
55 scalar<T> conj() const { return scalar<T>({value}); }
56
57 // unary + and -
58 scalar<T> operator+() const { return *this; }
59 scalar<T> operator-() const { return scalar<T>(-value); }
60
61 scalar<T> &operator+=(const scalar<T> &lhs) {
62 value += lhs.value;
63 return *this;
64 }
65
66 // TODO: for avx vector too -- #define new template macro
67 template <typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
68 scalar<T> &operator+=(const A &a) {
69 value += static_cast<T>(a);
70 return *this;
71 }
72
73 scalar<T> &operator-=(const scalar<T> &lhs) {
74 value -= lhs.value;
75 return *this;
76 }
77
78 // TODO: for vector too
79 template <typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
80 scalar<T> &operator-=(const A &a) {
81 value -= static_cast<T>(a);
82 return *this;
83 }
84
85 scalar<T> &operator*=(const scalar<T> &lhs) {
86 value = value * lhs.value;
87 return *this;
88 }
89
90 // TODO: for vector too
91 template <typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
92 scalar<T> &operator*=(const A &a) {
93 value *= static_cast<T>(a);
94 return *this;
95 }
96
97 // a/b
98 scalar<T> &operator/=(const scalar<T> &lhs) {
99 value = value / lhs.value;
100 return *this;
101 }
102
103 // TODO: for vector too
104 template <typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
105 scalar<T> &operator/=(const A &a) {
106 value /= static_cast<T>(a);
107 return *this;
108 }
109};
110
111template <typename T> scalar<T> operator+(const scalar<T> &a, const scalar<T> &b) {
112 return scalar<T>(a.value + b.value);
113}
114
115// TODO: for avx vector too -- #define new template macro
116template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
117scalar<T> operator+(const scalar<T> &c, const A &a) {
118 return scalar<T>(c.value + a);
119}
120
121template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
122scalar<T> operator+(const A &a, const scalar<T> &c) {
123 return scalar<T>(c.value + a);
124}
125
126// -
127template <typename T> scalar<T> operator-(const scalar<T> &a, const scalar<T> &b) {
128 return scalar<T>(a.value - b.value);
129}
130
131// TODO: for avx vector too -- #define new template macro
132template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
133scalar<T> operator-(const scalar<T> &c, const A &a) {
134 return scalar<T>(c.value - a);
135}
136
137template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
138scalar<T> operator-(const A &a, const scalar<T> &c) {
139 return scalar<T>(a - c.value);
140}
141
142// *
143template <typename T> scalar<T> operator*(const scalar<T> &a, const scalar<T> &b) {
144 return scalar<T>(a.value * b.value);
145}
146
147// TODO: for avx vector too -- #define new template macro
148template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
149scalar<T> operator*(const scalar<T> &c, const A &a) {
150 return scalar<T>(c.value * a);
151}
152
153template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
154scalar<T> operator*(const A &a, const scalar<T> &c) {
155 return scalar<T>(a * c.value);
156}
157
158// / a/b = ab*/|b|^2
159template <typename T> scalar<T> operator/(const scalar<T> &a, const scalar<T> &b) {
160 return scalar<T>(a.value / b.value);
161}
162
163// TODO: for avx vector too -- #define new template macro
164template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
165scalar<T> operator/(const scalar<T> &c, const A &a) {
166 return scalar<T>(c.value / a);
167}
168
169// a/c = ac*/|c|^2
170template <typename T, typename A, std::enable_if_t<std::is_arithmetic<A>::value, int> = 0>
171scalar<T> operator/(const A &a, const scalar<T> &c) {
172 return scalar<T>(a / c.value);
173}
174
175#endif
Definition scalar.h:8