HILA
Loading...
Searching...
No Matches
diagonal_matrix.h
1#ifndef DIAGONAL_MATRIX_H_
2#define DIAGONAL_MATRIX_H_
3
4#include "matrix.h"
5
6/// Define type DiagonalMatrix<n,T>
7
8
9/**
10 * @brief Class for diagonal matrix
11 *
12 * More efficient storage and algebra than a general square matrix
13 *
14 * @tparam n dimensionality
15 * @tparam T type
16 */
17template <int n, typename T>
19
20 public:
22 "DiagonalMatrix requires Complex or arithmetic type");
23
24 // std incantation for field types
25 using base_type = hila::arithmetic_type<T>;
26 using argument_type = T;
27
28 T c[n];
29
30 /// Define default constructors to ensure std::is_trivial
31 DiagonalMatrix() = default;
32 ~DiagonalMatrix() = default;
33 DiagonalMatrix(const DiagonalMatrix &v) = default;
34
35 // constructor from scalar -- keep it explicit! Not good for auto use
36 template <typename S, std::enable_if_t<(hila::is_assignable<T &, S>::value), int> = 0>
37 explicit inline DiagonalMatrix(const S rhs) {
38 for (int i = 0; i < n; i++)
39 c[i] = rhs;
40 }
41
42
43 // Construct matrix automatically from right-size initializer list
44 // This does not seem to be dangerous, so keep non-explicit
45 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
46 inline DiagonalMatrix(std::initializer_list<S> rhs) {
47 assert(rhs.size() == n && "Matrix/Vector initializer list size must match variable size");
48 int i = 0;
49 for (auto it = rhs.begin(); it != rhs.end(); it++, i++) {
50 c[i] = *it;
51 }
52 }
53
54 /**
55 * @brief Returns matrix size - all give same result
56 */
57 static constexpr int rows() {
58 return n;
59 }
60 static constexpr int columns() {
61 return n;
62 }
63 static constexpr int size() {
64 return n;
65 }
66
67 /**
68 * @brief Element access - e(i) gives diagonal element i
69 *
70 */
71
72 #pragma hila loop_function
73 inline T e(const int i) const {
74 return c[i];
75 }
76
77 // Same as above but with const_function, see const_function for details
78 inline T &e(const int i) const_function {
79 return c[i];
80 }
81
82 // For completeness, add e(i,j) - only for const
83 T e(const int i, const int j) const {
84 T ret(0);
85 if (i == j)
86 ret = c[i];
87 return ret;
88 }
89
90
91 /**
92 * @brief Return row from Diagonal matrix
93 *
94 */
95 RowVector<n, T> row(int i) const {
96 RowVector<n, T> res = 0;
97 res.e(i) = c[i];
98 return res;
99 }
100
101 /**
102 * @brief Returns column vector i
103 *
104 */
105 Vector<n, T> column(int i) const {
106 Vector<n, T> v = 0;
107 v.e(i) = c[i];
108 return v;
109 }
110
111
112 /**
113 * @brief Unary - operator
114 *
115 */
118 for (int i = 0; i < n; i++) {
119 res.e(i) = -c[i];
120 }
121 return res;
122 }
123
124 /**
125 * @brief Unary + operator
126 */
127 inline const auto &operator+() const {
128 return *this;
129 }
130
131
132 /**
133 * @brief Boolean operator != to check if matrices are exactly different
134 */
135 template <typename S>
136 bool operator!=(const S &rhs) const {
137 return !(*this == rhs);
138 }
139
140
141 /**
142 * Assignment operators: assign from another DiagonalMatrix, scalar or initializer list
143 */
144
145 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
146 inline DiagonalMatrix &operator=(const DiagonalMatrix<n, S> &rhs) out_only & {
147
148 for (int i = 0; i < n; i++) {
149 c[i] = rhs.e(i);
150 }
151 return *this;
152 }
153
154 // Assign from "scalar"
155
156 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
157 inline DiagonalMatrix &operator=(const S &rhs) out_only & {
158
159 for (int i = 0; i < n; i++) {
160 c[i] = rhs;
161 }
162 return *this;
163 }
164
165 // Assign from initializer list
166
167 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
168 inline DiagonalMatrix &operator=(std::initializer_list<S> rhs) out_only & {
169 assert(rhs.size() == n && "Initializer list has a wrong size in assignment");
170 int i = 0;
171 for (auto it = rhs.begin(); it != rhs.end(); it++, i++) {
172 c[i] = *it;
173 }
174 return *this;
175 }
176
177
178 // +=
179
180 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
181 DiagonalMatrix &operator+=(const DiagonalMatrix<n, S> &rhs) & {
182 for (int i = 0; i < n; i++) {
183 c[i] += rhs.e(i);
184 }
185 return *this;
186 }
187
188
189 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
190 DiagonalMatrix &operator+=(const S &rhs) & {
191 for (int i = 0; i < n; i++) {
192 c[i] += rhs;
193 }
194 return *this;
195 }
196
197 // -=
198
199 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
200 DiagonalMatrix &operator-=(const DiagonalMatrix<n, S> &rhs) & {
201 for (int i = 0; i < n; i++) {
202 c[i] -= rhs.e(i);
203 }
204 return *this;
205 }
206
207
208 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
209 DiagonalMatrix &operator-=(const S &rhs) & {
210 for (int i = 0; i < n; i++) {
211 c[i] -= rhs;
212 }
213 return *this;
214 }
215
216 // *=
217
218 template <typename S,
219 std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value, int> = 0>
220 DiagonalMatrix &operator*=(const DiagonalMatrix<n, S> &rhs) & {
221 for (int i = 0; i < n; i++) {
222 c[i] *= rhs.e(i);
223 }
224 return *this;
225 }
226
227
228 template <typename S,
229 std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value, int> = 0>
230 DiagonalMatrix &operator*=(const S &rhs) & {
231 for (int i = 0; i < n; i++) {
232 c[i] *= rhs;
233 }
234 return *this;
235 }
236
237 // /=
238
239 template <typename S,
240 std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value, int> = 0>
241 DiagonalMatrix &operator/=(const DiagonalMatrix<n, S> &rhs) & {
242 for (int i = 0; i < n; i++) {
243 c[i] /= rhs.e(i);
244 }
245 return *this;
246 }
247
248
249 template <typename S,
250 std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value, int> = 0>
251 DiagonalMatrix &operator/=(const S &rhs) & {
252 for (int i = 0; i < n; i++) {
253 c[i] /= rhs;
254 }
255 return *this;
256 }
257
258
259 /**
260 * @brief fill with constant value - same as assignment
261 */
262 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
263 DiagonalMatrix &fill(const S rhs) out_only {
264 for (int i = 0; i < n; i++)
265 c[i] = rhs;
266 return *this;
267 }
268
269 /**
270 * @brief transpose - leaves diagonal matrix as is
271 */
273 return *this;
274 }
275
276 /**
277 * @brief dagger - conjugate elements
278 */
281 for (int i = 0; i < n; i++)
282 ret.e(i) = ::conj(c[i]);
283 return ret;
284 }
285
286 DiagonalMatrix adjoint() const {
287 return this->dagger();
288 }
289
290 /**
291 * @brief conj - conjugate elements
292 */
294 return this->dagger();
295 }
296
297
298 /**
299 * @brief absolute value of all elements
300 *
301 * Returns a real DiagonalMatrix even if the original is complex
302 */
303 auto abs() const {
305 for (int i = 0; i < n; i++) {
306 res.c[i] = ::abs(c[i]);
307 }
308 return res;
309 }
310
311 /**
312 * @brief real and imaginary parts of diagonal matrix
313 *
314 * Returns a real DiagonalMatrix even if the original is complex
315 */
316 auto real() const {
318 for (int i = 0; i < n; i++) {
319 res.c[i] = ::real(c[i]);
320 }
321 return res;
322 }
323
324 auto imag() const {
326 for (int i = 0; i < n; i++) {
327 res.c[i] = ::imag(c[i]);
328 }
329 return res;
330 }
331
332 /**
333 * @brief Find max or min value - only for arithmetic types
334 */
335
336 template <typename S = T, std::enable_if_t<hila::is_arithmetic<S>::value, int> = 0>
337 T max() const {
338 T res = c[0];
339 for (int i = 1; i < n; i++) {
340 if (res < c[i])
341 res = c[i];
342 }
343 return res;
344 }
345
346 template <typename S = T, std::enable_if_t<hila::is_arithmetic<S>::value, int> = 0>
347 T min() const {
348 T res = c[0];
349 for (int i = 1; i < n; i++) {
350 if (res > c[i])
351 res = c[i];
352 }
353 return res;
354 }
355
356 T trace() const {
357 T res = c[0];
358 for (int i = 1; i < n; i++)
359 res += c[i];
360 return res;
361 }
362
363 T det() const {
364 T res = c[0];
365 for (int i = 1; i < n; i++)
366 res *= c[i];
367 return res;
368 }
369
370 auto squarenorm() const {
371 hila::arithmetic_type<T> res(0);
372 for (int i = 0; i < n; i++)
373 res += ::squarenorm(c[i]);
374 return res;
375 }
376
377 hila::arithmetic_type<T> norm() const {
378 return sqrt(squarenorm());
379 }
380
381 /**
382 * @brief Fills Matrix with random elements
383 * @details Works only for non-integer valued elements
384 */
385 DiagonalMatrix &random() out_only {
386
387 static_assert(hila::is_floating_point<hila::arithmetic_type<T>>::value,
388 "DiagonalMatrix random() requires non-integral type elements");
389 for (int i = 0; i < n; i++) {
390 hila::random(c[i]);
391 }
392 return *this;
393 }
394
395 DiagonalMatrix &gaussian_random(double width = 1.0) out_only {
396
397 static_assert(hila::is_floating_point<hila::arithmetic_type<T>>::value,
398 "DiagonaMatrix gaussian_random() requires non-integral type elements");
399 for (int i = 0; i < n; i++) {
400 hila::gaussian_random(c[i], width);
401 }
402 return *this;
403 }
404
405
406 /**
407 * @brief convert to string for printing
408 */
409 std::string str(int prec = 8, char separator = ' ') const {
410 return this->asArray().str(prec, separator);
411 }
412
413 /**
414 * @brief convert to generic matrix
415 */
417 Matrix<n, n, T> res;
418 for (int i = 0; i < n; i++) {
419 for (int j = 0; j < n; j++) {
420 if (i == j)
421 res.e(i, j) = c[i];
422 else
423 res.e(i, j) = 0;
424 }
425 }
426 return res;
427 }
428
429
430 // temp cast to Array, for some arithmetic ops
431
432 Array<n, 1, T> &asArray() const_function {
433 return *(reinterpret_cast<Array<n, 1, T> *>(this));
434 }
435
436 const Array<n, 1, T> &asArray() const {
437 return *(reinterpret_cast<const Array<n, 1, T> *>(this));
438 }
439
440 Vector<n, T> &asVector() const_function {
441 return *(reinterpret_cast<Vector<n, T> *>(this));
442 }
443
444 const Vector<n, T> &asVector() const {
445 return *(reinterpret_cast<const Vector<n, T> *>(this));
446 }
447
448 /// implement sort by viewing the diagonal as an Array (asArray()) and sorting that
449#pragma hila novector
450 template <int N>
452 hila::sort order = hila::sort::ascending) const {
453 return this->asArray().sort(permutation, order).asDiagonalMatrix();
454 }
455
456#pragma hila novector
457 DiagonalMatrix<n, T> sort(hila::sort order = hila::sort::ascending) const {
458 return this->asArray().sort(order).asDiagonalMatrix();
459 }
460};
461
462///////////////////////////////////////////////////////////////////////////////////////
463
464/**
465 * @brief Boolean operator == to determine if two diagonal matrices are exactly the same
466 */
467template <typename A, typename B, int n, int m>
468inline bool operator==(const DiagonalMatrix<n, A> &lhs, const DiagonalMatrix<m, B> &rhs) {
469 if constexpr (m != n)
470 return false;
471
472 for (int i = 0; i < n; i++) {
473 if (lhs.e(i) != rhs.e(i))
474 return false;
475 }
476 return true;
477}
478
479/**
480 * @brief Boolean operator == to compare with square matrix
481 * Matrices are equal if diagonals are equal and off-diag is zero
482 */
483template <typename A, typename S, typename Mtype, int n, int m1, int m2>
484inline bool operator==(const DiagonalMatrix<n, A> &lhs, const Matrix_t<m1, m2, S, Mtype> &rhs) {
485 if constexpr (m1 != n || m2 != n)
486 return false;
487
488 for (int i = 0; i < n; i++) {
489 for (int j = 0; j < n; j++) {
490 if (i == j) {
491 if (lhs.e(i) != rhs.e(i, i))
492 return false;
493 } else {
494 if (rhs.e(i, j) != 0)
495 return false;
496 }
497 }
498 }
499 return true;
500}
501
502template <typename A, typename S, typename Mtype, int n, int m1, int m2>
503inline bool operator==(const Matrix_t<m1, m2, S, Mtype> &lhs, const DiagonalMatrix<n, A> &rhs) {
504 return rhs == lhs;
505}
506
507// Note: only C++20 and later synthesize operator!= from operator==; DiagonalMatrix also defines its own member operator!= template
508
509template <int n, typename T>
510inline const auto &transpose(const DiagonalMatrix<n, T> &arg) {
511 return arg;
512}
513
514template <int n, typename T>
515inline auto dagger(const DiagonalMatrix<n, T> &arg) {
516 return arg.dagger();
517}
518
519template <int n, typename T>
520inline auto conj(const DiagonalMatrix<n, T> &arg) {
521 return arg.conj();
522}
523
524template <int n, typename T>
525inline auto adjoint(const DiagonalMatrix<n, T> &arg) {
526 return arg.adjoint();
527}
528
529template <int n, typename T>
530inline auto abs(const DiagonalMatrix<n, T> &arg) {
531 return arg.abs();
532}
533
534template <int n, typename T>
535inline auto real(const DiagonalMatrix<n, T> &arg) {
536 return arg.real();
537}
538
539template <int n, typename T>
540inline auto imag(const DiagonalMatrix<n, T> &arg) {
541 return arg.imag();
542}
543
544template <int n, typename T>
545inline auto trace(const DiagonalMatrix<n, T> &arg) {
546 return arg.trace();
547}
548
549template <int n, typename T>
550inline auto squarenorm(const DiagonalMatrix<n, T> &arg) {
551 return arg.squarenorm();
552}
553
554template <int n, typename T>
555inline auto norm(const DiagonalMatrix<n, T> &arg) {
556 return arg.norm();
557}
558
559template <int n, typename T>
560inline auto det(const DiagonalMatrix<n, T> &arg) {
561 return arg.det();
562}
563
564
565namespace hila {
566
567////////////////////////////////////////////////////////////////////////////////
568// DiagonalMatrix + scalar result type:
569// hila::diagonalmatrix_scalar_type<Mt,S>
570// - if result is convertible to Mt, return Mt
571// - if Mt is not complex and S is, return
572// DiagonalMatrix<Complex<type_sum(scalar_type(Mt),scalar_type(S))>>
573// - otherwise return DiagonalMatrix<type_sum>
574
575template <typename Mt, typename S, typename Enable = void>
576struct diagonalmatrix_scalar_op_s {
577 using type = DiagonalMatrix<
578 Mt::rows(), Complex<hila::type_plus<hila::arithmetic_type<Mt>, hila::arithmetic_type<S>>>>;
579};
580
581template <typename Mt, typename S>
582struct diagonalmatrix_scalar_op_s<
583 Mt, S,
584 typename std::enable_if_t<std::is_convertible<hila::type_plus<hila::number_type<Mt>, S>,
585 hila::number_type<Mt>>::value>> {
586 // using type = Mt;
587 using type = typename std::conditional<
588 hila::is_floating_point<hila::arithmetic_type<Mt>>::value, Mt,
589 DiagonalMatrix<Mt::rows(),
590 hila::type_plus<hila::arithmetic_type<Mt>, hila::arithmetic_type<S>>>>::type;
591};
592
593template <typename Mt, typename S>
594using diagonalmatrix_scalar_type = typename diagonalmatrix_scalar_op_s<Mt, S>::type;
595
596} // namespace hila
597
598
599/**
600 * operators: can add scalar, diagonal matrix or square matrix.
601 */
602
603// diagonal + scalar
604template <int n, typename T, typename S,
605 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
606 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
607inline Rtype operator+(const DiagonalMatrix<n, T> &a, const S &b) {
608 Rtype res;
609 for (int i = 0; i < n; i++)
610 res.e(i) = a.e(i) + b;
611 return res;
612}
613
614// diagonal + scalar
615template <int n, typename T, typename S,
616 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
617 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
618inline Rtype operator+(const S &b, const DiagonalMatrix<n, T> &a) {
619 return a + b;
620}
621
622// diagonal - scalar
623template <int n, typename T, typename S,
624 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
625 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
626inline Rtype operator-(const DiagonalMatrix<n, T> &a, const S &b) {
627 Rtype res;
628 for (int i = 0; i < n; i++)
629 res.e(i) = a.e(i) - b;
630 return res;
631}
632
633// scalar - diagonal
634template <int n, typename T, typename S,
635 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
636 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
637inline Rtype operator-(const S &b, const DiagonalMatrix<n, T> &a) {
638 Rtype res;
639 for (int i = 0; i < n; i++)
640 res.e(i) = b - a.e(i);
641 return res;
642}
643
644// diagonal * scalar
645template <int n, typename T, typename S,
646 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
647 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
648inline Rtype operator*(const DiagonalMatrix<n, T> &a, const S &b) {
649 Rtype res;
650 for (int i = 0; i < n; i++)
651 res.e(i) = a.e(i) * b;
652 return res;
653}
654
655// scalar * diagonal
656template <int n, typename T, typename S,
657 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
658 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
659inline Rtype operator*(const S &b, const DiagonalMatrix<n, T> &a) {
660 return a * b;
661}
662
663// diagonal / scalar
664template <int n, typename T, typename S,
665 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
666 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
667inline Rtype operator/(const DiagonalMatrix<n, T> &a, const S &b) {
668 Rtype res;
669 for (int i = 0; i < n; i++)
670 res.e(i) = a.e(i) / b;
671 return res;
672}
673
674// scalar / diagonal
675template <int n, typename T, typename S,
676 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
677 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
678inline Rtype operator/(const S &b, const DiagonalMatrix<n, T> &a) {
679 Rtype res;
680 for (int i = 0; i < n; i++)
681 res.e(i) = b / a.e(i);
682 return res;
683}
684
685/////
686// diagonal X diagonal
687template <int n, typename A, typename B, typename R = hila::type_plus<A, B>>
688inline auto operator+(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
690 for (int i = 0; i < n; i++)
691 res.e(i) = a.e(i) + b.e(i);
692 return res;
693}
694
695template <int n, typename A, typename B, typename R = hila::type_minus<A, B>>
696inline auto operator-(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
698 for (int i = 0; i < n; i++)
699 res.e(i) = a.e(i) - b.e(i);
700 return res;
701}
702
703template <int n, typename A, typename B, typename R = hila::type_mul<A, B>>
704inline auto operator*(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
706 for (int i = 0; i < n; i++)
707 res.e(i) = a.e(i) * b.e(i);
708 return res;
709}
710
711template <int n, typename A, typename B, typename R = hila::type_div<A, B>>
712inline auto operator/(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
714 for (int i = 0; i < n; i++)
715 res.e(i) = a.e(i) / b.e(i);
716 return res;
717}
718
719//// Finally, diagonal X Matrix ops - gives Matrix
720
721template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
722 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
723inline Rtype operator+(const DiagonalMatrix<n, T> &a, const Mtype &b) {
724
725 constexpr int mr = Mtype::rows();
726 constexpr int mc = Mtype::columns();
727
728 static_assert(mc == n && mr == n, "Matrix sizes do not match");
729
730 Rtype r;
731 r = b;
732 for (int i = 0; i < n; i++)
733 r.e(i, i) += a.e(i);
734 return r;
735}
736
737template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
738 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
739inline Rtype operator+(const Mtype &b, const DiagonalMatrix<n, T> &a) {
740 return a + b;
741}
742
743template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
744 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
745inline Rtype operator-(const DiagonalMatrix<n, T> &a, const Mtype &b) {
746
747 constexpr int mr = Mtype::rows();
748 constexpr int mc = Mtype::columns();
749
750 static_assert(mc == n && mr == n, "Matrix sizes do not match");
751
752 Rtype r = -b;
753 for (int i = 0; i < n; i++)
754 r.e(i, i) += a.e(i);
755 return r;
756}
757
758template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
759 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
760inline Rtype operator-(const Mtype &b, const DiagonalMatrix<n, T> &a) {
761 constexpr int mr = Mtype::rows();
762 constexpr int mc = Mtype::columns();
763
764 static_assert(mc == n && mr == n, "Matrix sizes do not match");
765
766 Rtype r = b;
767 for (int i = 0; i < n; i++)
768 r.e(i, i) -= a.e(i);
769 return r;
770}
771
772
773// multiply by matrix
774
775template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
776 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
777inline Rtype operator*(const DiagonalMatrix<n, T> &a, const Mtype &b) {
778
779 constexpr int mr = Mtype::rows();
780 constexpr int mc = Mtype::columns();
781
782 static_assert(mr == n, "Matrix sizes do not match");
783
784 Rtype r;
785 for (int i = 0; i < n; i++)
786 for (int j = 0; j < mc; j++)
787 r.e(i, j) = a.e(i) * b.e(i, j);
788 return r;
789}
790
791template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
792 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
793inline Rtype operator*(const Mtype &b, const DiagonalMatrix<n, T> &a) {
794
795 constexpr int mr = Mtype::rows();
796 constexpr int mc = Mtype::columns();
797
798 static_assert(mc == n, "Matrix sizes do not match");
799
800 Rtype r;
801 for (int i = 0; i < mr; i++)
802 for (int j = 0; j < n; j++)
803 r.e(i, j) = b.e(i, j) * a.e(j);
804 return r;
805}
806
807// division
808template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
809 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
810inline Rtype operator/(const Mtype &b, const DiagonalMatrix<n, T> &a) {
811
812 constexpr int mr = Mtype::rows();
813 constexpr int mc = Mtype::columns();
814
815 static_assert(mc == n, "Matrix sizes do not match");
816
817 Rtype r;
818 for (int i = 0; i < mr; i++)
819 for (int j = 0; j < n; j++)
820 r.e(i, j) = b.e(i, j) / a.e(j);
821 return r;
822}
823
824////////////////////////////////////////////////////////////////////////////////
825/// Standard arithmetic functions - do element by element
826////////////////////////////////////////////////////////////////////////////////
827
828template <int n, typename T>
830 for (int i = 0; i < n; i++)
831 a.c[i] = sqrt(a.c[i]);
832 return a;
833}
834
835template <int n, typename T>
837 for (int i = 0; i < n; i++)
838 a.c[i] = cbrt(a.c[i]);
839 return a;
840}
841
842template <int n, typename T>
844 for (int i = 0; i < n; i++)
845 a.c[i] = exp(a.c[i]);
846 return a;
847}
848
849template <int n, typename T>
851 for (int i = 0; i < n; i++)
852 a.c[i] = log(a.c[i]);
853 return a;
854}
855
856template <int n, typename T>
858 for (int i = 0; i < n; i++)
859 a.c[i] = sin(a.c[i]);
860 return a;
861}
862
863template <int n, typename T>
865 for (int i = 0; i < n; i++)
866 a.c[i] = cos(a.c[i]);
867 return a;
868}
869
870template <int n, typename T>
872 for (int i = 0; i < n; i++)
873 a.c[i] = tan(a.c[i]);
874 return a;
875}
876
877template <int n, typename T>
879 for (int i = 0; i < n; i++)
880 a.c[i] = asin(a.c[i]);
881 return a;
882}
883
884template <int n, typename T>
886 for (int i = 0; i < n; i++)
887 a.c[i] = acos(a.c[i]);
888 return a;
889}
890
891template <int n, typename T>
893 for (int i = 0; i < n; i++)
894 a.c[i] = atan(a.c[i]);
895 return a;
896}
897
898template <int n, typename T>
900 for (int i = 0; i < n; i++)
901 a.c[i] = sinh(a.c[i]);
902 return a;
903}
904
905template <int n, typename T>
907 for (int i = 0; i < n; i++)
908 a.c[i] = cosh(a.c[i]);
909 return a;
910}
911
912template <int n, typename T>
914 for (int i = 0; i < n; i++)
915 a.c[i] = tanh(a.c[i]);
916 return a;
917}
918
919template <int n, typename T>
921 for (int i = 0; i < n; i++)
922 a.c[i] = asinh(a.c[i]);
923 return a;
924}
925
926template <int n, typename T>
928 for (int i = 0; i < n; i++)
929 a.c[i] = acosh(a.c[i]);
930 return a;
931}
932
933template <int n, typename T>
935 for (int i = 0; i < n; i++)
936 a.c[i] = atanh(a.c[i]);
937 return a;
938}
939
940
941// return pow of diagonalMatrix as original type if power is scalar or diagonal is complex
942template <
943 int n, typename T, typename S,
944 std::enable_if_t<hila::is_arithmetic<S>::value || hila::contains_complex<T>::value, int> = 0>
946 for (int i = 0; i < n; i++)
947 a.c[i] = pow(a.c[i], p);
948 return a;
949}
950
951// if power is complex but matrix is scalar need to upgrade return type
952template <int n, typename T, typename S,
953 std::enable_if_t<!hila::contains_complex<T>::value, int> = 0>
954inline auto pow(const DiagonalMatrix<n, T> &a, const Complex<S> &p) {
956 for (int i = 0; i < n; i++)
957 res.e(i) = pow(a.e(i), p);
958 return res;
959}
960
961// Cast operators to different number or Complex type
962// cast_to<double>(a);
963// Cast from number->number, number->Complex, Complex->Complex OK,
964// Complex->number not.
965
966template <typename Ntype, typename T, int n,
967 std::enable_if_t<hila::is_arithmetic<T>::value, int> = 0>
970 for (int i = 0; i < n; i++)
971 res.c[i] = mat.c[i];
972 return res;
973}
974
975template <typename Ntype, typename T, int n, std::enable_if_t<hila::is_complex<T>::value, int> = 0>
978 for (int i = 0; i < n; i++)
979 res.c[i] = cast_to<Ntype>(mat.c[i]);
980 return res;
981}
982
983/// Stream operator
984template <int n, typename T>
985std::ostream &operator<<(std::ostream &strm, const DiagonalMatrix<n, T> &A) {
986 return operator<<(strm, A.asArray());
987}
988
989namespace hila {
990
991template <int n, typename T>
992std::string to_string(const DiagonalMatrix<n, T> &A, int prec = 8, char separator = ' ') {
993 return to_string(A.asArray(), prec, separator);
994}
995
996template <int n, typename T>
997std::string prettyprint(const DiagonalMatrix<n, T> &A, int prec = 8) {
998 return prettyprint(A.toMatrix(), prec);
999}
1000
1001} // namespace hila
1002
1003
1004#endif
Array< n, m, T > conj(const Array< n, m, T > &arg)
Return conjugate Array.
Definition array.h:674
auto operator/(const Array< n, m, A > &a, const Array< n, m, B > &b)
Division operator.
Definition array.h:916
std::ostream & operator<<(std::ostream &strm, const Array< n, m, T > &A)
Stream operator.
Definition array.h:977
Array< n, m, T > asin(Array< n, m, T > a)
Inverse Sine.
Definition array.h:1109
auto operator-(const Array< n, m, A > &a, const Array< n, m, B > &b)
Subtraction operator.
Definition array.h:781
Array< n, m, T > exp(Array< n, m, T > a)
Exponential.
Definition array.h:1059
Array< n, m, T > tanh(Array< n, m, T > a)
Hyperbolic tangent.
Definition array.h:1159
Array< n, m, hila::arithmetic_type< T > > imag(const Array< n, m, T > &arg)
Return imaginary part of Array.
Definition array.h:702
hila::arithmetic_type< T > squarenorm(const Array< n, m, T > &rhs)
Return square norm of Array.
Definition array.h:1018
Array< n, m, T > cosh(Array< n, m, T > a)
Hyperbolic Cosine.
Definition array.h:1149
Array< n, m, T > tan(Array< n, m, T > a)
Tangent.
Definition array.h:1099
auto operator+(const Array< n, m, A > &a, const Array< n, m, B > &b)
Addition operator.
Definition array.h:740
Array< n, m, T > cos(Array< n, m, T > a)
Cosine.
Definition array.h:1089
Array< n, m, T > cbrt(Array< n, m, T > a)
Cuberoot.
Definition array.h:1049
Array< n, m, T > acos(Array< n, m, T > a)
Inverse Cosine.
Definition array.h:1119
Array< n, m, T > acosh(Array< n, m, T > a)
Inverse Hyperbolic Cosine.
Definition array.h:1179
Array< n, m, T > sqrt(Array< n, m, T > a)
Square root.
Definition array.h:1039
Array< n, m, T > asinh(Array< n, m, T > a)
Inverse Hyperbolic Sine.
Definition array.h:1169
Array< n, m, T > atan(Array< n, m, T > a)
Inverse Tangent.
Definition array.h:1129
Array< n, m, T > pow(Array< n, m, T > a, int b)
Power.
Definition array.h:1204
Array< n, m, T > operator*(Array< n, m, T > a, const Array< n, m, T > &b)
Multiplication operator.
Definition array.h:861
Array< n, m, Ntype > cast_to(const Array< n, m, T > &mat)
Array casting operation.
Definition array.h:1289
Array< n, m, T > sin(Array< n, m, T > a)
Sine.
Definition array.h:1079
Array< n, m, T > sinh(Array< n, m, T > a)
Hyperbolic Sine.
Definition array.h:1139
Array< n, m, hila::arithmetic_type< T > > real(const Array< n, m, T > &arg)
Return real part of Array.
Definition array.h:688
Array< n, m, T > atanh(Array< n, m, T > a)
Inverse Hyperbolic Tangent.
Definition array.h:1189
Array type
Definition array.h:43
Complex definition.
Definition cmplx.h:56
Define type DiagonalMatrix<n,T>
Matrix< n, n, T > toMatrix() const
convert to generic matrix
bool operator!=(const S &rhs) const
Boolean operator != to check if matrices are exactly different.
Vector< n, T > column(int i) const
Returns column vector i.
DiagonalMatrix & operator=(const DiagonalMatrix< n, S > &rhs) &
T e(const int i) const
Element access - e(i) gives diagonal element i.
DiagonalMatrix & transpose() const
transpose - leaves diagonal matrix as is
DiagonalMatrix conj() const
conj - conjugate elements
DiagonalMatrix & random()
Fills Matrix with random elements.
std::string str(int prec=8, char separator=' ') const
convert to string for printing
RowVector< n, T > row(int i) const
Return row from Diagonal matrix.
DiagonalMatrix dagger() const
dagger - conjugate elements
DiagonalMatrix()=default
Define default constructors to ensure std::is_trivial.
DiagonalMatrix< n, T > operator-() const
Unary - operator.
DiagonalMatrix & fill(const S rhs)
fill with constant value - same as assignment
T max() const
Find max or min value - only for arithmetic types.
static constexpr int rows()
Returns matrix size - all give same result.
const auto & operator+() const
Unary + operator.
DiagonalMatrix< n, T > sort(Vector< N, int > &permutation, hila::sort order=hila::sort::ascending) const
implement sort as casting to matrix
auto real() const
real and imaginary parts of diagonal matrix
auto abs() const
absolute value of all elements
The main matrix type template Matrix_t. This is a base class type for "useful" types which are deriv...
Definition matrix.h:106
static constexpr int rows()
Define constant methods rows(), columns() - may be useful in template code.
Definition matrix.h:226
static constexpr int columns()
Returns column length.
Definition matrix.h:234
T e(const int i, const int j) const
Standard array indexing operation for matrices and vectors.
Definition matrix.h:286
Matrix class which defines matrix operations.
Definition matrix.h:1620
bool operator==(const Complex< A > &a, const Complex< B > &b)
Compare equality of two complex numbers.
Definition cmplx.h:1072
Complex< T > dagger(const Complex< T > &val)
Return dagger of Complex number.
Definition cmplx.h:1223
T abs(const Complex< T > &a)
Return absolute value of Complex number.
Definition cmplx.h:1187
T arg(const Complex< T > &a)
Return argument of Complex number.
Definition cmplx.h:1199
Definition of Matrix types.
Implement hila::swap for gauge fields.
Definition array.h:981
T gaussian_random()
Template function T hila::gaussian_random<T>(), generates gaussian random value of type T,...
Definition random.h:183
logger_class log
Now declare the logger.
double random()
Real valued uniform random number generator.
Definition hila_gpu.cpp:121
decltype(std::declval< A >()+std::declval< B >()) type_plus
Definition type_tools.h:103
std::string to_string(const Array< n, m, T > &A, int prec=8, char separator=' ')
Converts Array object to string.
Definition array.h:995
hila::is_complex_or_arithmetic<T>::value
Definition cmplx.h:665