HILA
Loading...
Searching...
No Matches
diagonal_matrix.h
1#ifndef DIAGONAL_MATRIX_H_
2#define DIAGONAL_MATRIX_H_
3
4#include "matrix.h"
5
6/// Define type DiagonalMatrix<n,T>
7
8
9/**
10 * @brief Class for diagonal matrix
11 *
12 * More efficient storage and algebra than a general square matrix
13 *
14 * @tparam n dimensionality
15 * @tparam T type
16 */
17template <int n, typename T>
19
20 public:
22 "DiagonalMatrix requires Complex or arithmetic type");
23
24 // std incantation for field types
// base_type: the underlying arithmetic type of T (hila::arithmetic_type<T>);
// argument_type: the element type T itself.
25 using base_type = hila::arithmetic_type<T>;
26 using argument_type = T;
27
// Element storage: only the n diagonal elements are stored.
// asArray()/asVector() reinterpret_cast this object to Array<n, 1, T> / Vector<n, T>,
// so they rely on c being the only data member - do not add or reorder members.
28 T c[n];
29
30 /// Define default constructors to ensure std::is_trivial
// (the class is then trivial whenever T is; copying is the defaulted memberwise copy)
31 DiagonalMatrix() = default;
32 ~DiagonalMatrix() = default;
33 DiagonalMatrix(const DiagonalMatrix &v) = default;
34
35 // constructor from scalar -- keep it explicit! Not good for auto use
36 template <typename S, std::enable_if_t<(hila::is_assignable<T &, S>::value), int> = 0>
37 explicit inline DiagonalMatrix(const S rhs) {
38 for (int i = 0; i < n; i++)
39 c[i] = rhs;
40 }
41
42
43 // Construct matrix automatically from right-size initializer list
44 // This does not seem to be dangerous, so keep non-explicit
45 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
46 inline DiagonalMatrix(std::initializer_list<S> rhs) {
47 assert(rhs.size() == n && "Matrix/Vector initializer list size must match variable size");
48 int i = 0;
49 for (auto it = rhs.begin(); it != rhs.end(); it++, i++) {
50 c[i] = *it;
51 }
52 }
53
54 /**
55 * @brief Returns matrix size - all give same result
56 */
57 static constexpr int rows() {
58 return n;
59 }
60 static constexpr int columns() {
61 return n;
62 }
63 static constexpr int size() {
64 return n;
65 }
66
67 /**
68 * @brief Element access - e(i) gives diagonal element i
69 *
70 */
71
72 inline T e(const int i) const {
73 return c[i];
74 }
75
76 // Same as above but with const_function, see const_function for details
77 inline T &e(const int i) const_function {
78 return c[i];
79 }
80
81 // For completeness, add e(i,j) - only for const
82 T e(const int i, const int j) const {
83 T ret(0);
84 if (i == j)
85 ret = c[i];
86 return ret;
87 }
88
89
90 /**
91 * @brief Return row from Diagonal matrix
92 *
93 */
94 RowVector<n, T> row(int i) const {
95 RowVector<n, T> res = 0;
96 res.e(i) = c[i];
97 return res;
98 }
99
100 /**
101 * @brief Returns column vector i
102 *
103 */
104 Vector<n, T> column(int i) const {
105 Vector<n, T> v = 0;
106 v.e(i) = c[i];
107 return v;
108 }
109
110
111 /**
112 * @brief Unary - operator
113 *
114 */
117 for (int i = 0; i < n; i++) {
118 res.e(i) = -c[i];
119 }
120 return res;
121 }
122
123 /**
124 * @brief Unary + operator
125 */
126 inline const auto &operator+() const {
127 return *this;
128 }
129
130 /**
131 * @brief Boolean operator == to determine if two matrices are exactly the same
132 */
133 template <typename S>
134 bool operator==(const DiagonalMatrix<n, S> &rhs) const {
135 for (int i = 0; i < n; i++) {
136 if (e(i) != rhs.e(i))
137 return false;
138 }
139 return true;
140 }
141
142 /**
143 * @brief Boolean operator == to compare with square matrix
144 */
145 template <typename S, typename Mtype>
146 bool operator==(const Matrix_t<n, n, S, Mtype> &rhs) const {
147 for (int i = 0; i < n; i++) {
148 for (int j = 0; j < n; j++) {
149 if (i == j) {
150 if (e(i) != rhs.e(i, i))
151 return false;
152 } else {
153 if (rhs.e(i, j) != 0)
154 return false;
155 }
156 }
157 }
158 return true;
159 }
160
161
162 /**
163 * @brief Boolean operator != to check if matrices are exactly different
164 */
165 template <typename S>
166 bool operator!=(const S &rhs) const {
167 return !(*this == rhs);
168 }
169
170
171 /**
172 * Assignment operators: assign from another DiagonalMatrix, scalar or initializer list
173 */
174
175
176#pragma hila loop_function
177 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
178 inline DiagonalMatrix &operator=(const DiagonalMatrix<n, S> &rhs) out_only {
179
180 for (int i = 0; i < n; i++) {
181 c[i] = rhs.e(i);
182 }
183 return *this;
184 }
185
186// Assign from "scalar"
187#pragma hila loop_function
188 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
189 inline DiagonalMatrix &operator=(const S &rhs) out_only {
190
191 for (int i = 0; i < n; i++) {
192 c[i] = rhs;
193 }
194 return *this;
195 }
196
197// Assign from initializer list
198#pragma hila loop_function
199 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
200 inline DiagonalMatrix &operator=(std::initializer_list<S> rhs) out_only {
201 assert(rhs.size() == n && "Initializer list has a wrong size in assignment");
202 int i = 0;
203 for (auto it = rhs.begin(); it != rhs.end(); it++, i++) {
204 c[i] = *it;
205 }
206 return *this;
207 }
208
209
210 // +=
211#pragma hila loop_function
212 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
213 DiagonalMatrix &operator+=(const DiagonalMatrix<n, S> &rhs) {
214 for (int i = 0; i < n; i++) {
215 c[i] += rhs.e(i);
216 }
217 return *this;
218 }
219
220#pragma hila loop_function
221 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
222 DiagonalMatrix &operator+=(const S &rhs) {
223 for (int i = 0; i < n; i++) {
224 c[i] += rhs;
225 }
226 return *this;
227 }
228
229 // -=
230#pragma hila loop_function
231 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
232 DiagonalMatrix &operator-=(const DiagonalMatrix<n, S> &rhs) {
233 for (int i = 0; i < n; i++) {
234 c[i] -= rhs.e(i);
235 }
236 return *this;
237 }
238
239#pragma hila loop_function
240 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
241 DiagonalMatrix &operator-=(const S &rhs) {
242 for (int i = 0; i < n; i++) {
243 c[i] -= rhs;
244 }
245 return *this;
246 }
247
248 // *=
249#pragma hila loop_function
250 template <typename S,
251 std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value, int> = 0>
252 DiagonalMatrix &operator*=(const DiagonalMatrix<n, S> &rhs) {
253 for (int i = 0; i < n; i++) {
254 c[i] *= rhs.e(i);
255 }
256 return *this;
257 }
258
259#pragma hila loop_function
260 template <typename S,
261 std::enable_if_t<hila::is_assignable<T &, hila::type_mul<T, S>>::value, int> = 0>
262 DiagonalMatrix &operator*=(const S &rhs) {
263 for (int i = 0; i < n; i++) {
264 c[i] *= rhs;
265 }
266 return *this;
267 }
268
269 // /=
270#pragma hila loop_function
271 template <typename S,
272 std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value, int> = 0>
273 DiagonalMatrix &operator/=(const DiagonalMatrix<n, S> &rhs) {
274 for (int i = 0; i < n; i++) {
275 c[i] /= rhs.e(i);
276 }
277 return *this;
278 }
279
280#pragma hila loop_function
281 template <typename S,
282 std::enable_if_t<hila::is_assignable<T &, hila::type_div<T, S>>::value, int> = 0>
283 DiagonalMatrix &operator/=(const S &rhs) {
284 for (int i = 0; i < n; i++) {
285 c[i] /= rhs;
286 }
287 return *this;
288 }
289
290
291 /**
292 * @brief fill with constant value - same as assignment
293 */
294 template <typename S, std::enable_if_t<hila::is_assignable<T &, S>::value, int> = 0>
295 DiagonalMatrix &fill(const S rhs) out_only {
296 for (int i = 0; i < n; i++)
297 c[i] = rhs;
298 return *this;
299 }
300
301 /**
302 * @brief transpose - leaves diagonal matrix as is
303 */
304 const DiagonalMatrix &transpose() const {
305 return *this;
306 }
307
308 /**
309 * @brief dagger - conjugate elements
310 */
313 for (int i = 0; i < n; i++)
314 ret.e(i) = ::conj(c[i]);
315 return ret;
316 }
317
318 DiagonalMatrix adjoint() const {
319 return this->dagger();
320 }
321
322 /**
323 * @brief conj - conjugate elements
324 */
326 return this->dagger();
327 }
328
329
330 /**
331 * @brief absolute value of all elements
332 *
333 * Returns a real DiagonalMatrix even if the original is complex
334 */
335 auto abs() const {
337 for (int i = 0; i < n; i++) {
338 res.c[i] = ::abs(c[i]);
339 }
340 return res;
341 }
342
343 /**
344 * @brief real and imaginary parts of diagonal matrix
345 *
346 * Returns a real DiagonalMatrix even if the original is complex
347 */
348 auto real() const {
350 for (int i = 0; i < n; i++) {
351 res.c[i] = ::real(c[i]);
352 }
353 return res;
354 }
355
356 auto imag() const {
358 for (int i = 0; i < n; i++) {
359 res.c[i] = ::imag(c[i]);
360 }
361 return res;
362 }
363
364 /**
365 * @brief Find max or min value - only for arithmetic types
366 */
367
368 template <typename S = T, std::enable_if_t<hila::is_arithmetic<S>::value, int> = 0>
369 T max() const {
370 T res = c[0];
371 for (int i = 1; i < n; i++) {
372 if (res < c[i])
373 res = c[i];
374 }
375 return res;
376 }
377
378 template <typename S = T, std::enable_if_t<hila::is_arithmetic<S>::value, int> = 0>
379 T min() const {
380 T res = c[0];
381 for (int i = 1; i < n; i++) {
382 if (res > c[i])
383 res = c[i];
384 }
385 return res;
386 }
387
388 T trace() const {
389 T res = c[0];
390 for (int i = 1; i < n; i++)
391 res += c[i];
392 return res;
393 }
394
395 T det() const {
396 T res = c[0];
397 for (int i = 1; i < n; i++)
398 res *= c[i];
399 return res;
400 }
401
402 auto squarenorm() const {
403 hila::arithmetic_type<T> res(0);
404 for (int i = 0; i < n; i++)
405 res += ::squarenorm(c[i]);
406 return res;
407 }
408
409 hila::arithmetic_type<T> norm() const {
410 return sqrt(squarenorm());
411 }
412
413 /**
414 * @brief Fills Matrix with random elements
415 * @details Works only for non-integer valued elements
416 */
417 DiagonalMatrix &random() out_only {
418
419 static_assert(hila::is_floating_point<hila::arithmetic_type<T>>::value,
420 "DiagonalMatrix random() requires non-integral type elements");
421 for (int i = 0; i < n; i++) {
422 hila::random(c[i]);
423 }
424 return *this;
425 }
426
427 DiagonalMatrix &gaussian_random(double width = 1.0) out_only {
428
429 static_assert(hila::is_floating_point<hila::arithmetic_type<T>>::value,
430 "DiagonaMatrix gaussian_random() requires non-integral type elements");
431 for (int i = 0; i < n; i++) {
432 hila::gaussian_random(c[i], width);
433 }
434 return *this;
435 }
436
437
438 /**
439 * @brief convert to string for printing
440 */
441 std::string str(int prec = 8, char separator = ' ') const {
442 return this->asArray().str(prec, separator);
443 }
444
445 /**
446 * @brief convert to generic matrix
447 */
449 Matrix<n, n, T> res;
450 for (int i = 0; i < n; i++) {
451 for (int j = 0; j < n; j++) {
452 if (i == j)
453 res.e(i, j) = c[i];
454 else
455 res.e(i, j) = 0;
456 }
457 }
458 return res;
459 }
460
461
462 // temp cast to Array, for some arithmetic ops
463
464 Array<n, 1, T> &asArray() const_function {
465 return *(reinterpret_cast<Array<n, 1, T> *>(this));
466 }
467
468 const Array<n, 1, T> &asArray() const {
469 return *(reinterpret_cast<const Array<n, 1, T> *>(this));
470 }
471
472 Vector<n, T> &asVector() const_function {
473 return *(reinterpret_cast<Vector<n,T> *>(this));
474 }
475
476 const Vector<n, T> &asVector() const {
477 return *(reinterpret_cast<const Vector<n,T> *>(this));
478 }
479
480 /// implement sort as casting to matrix
481#pragma hila novector
482 template <int N>
484 hila::sort order = hila::sort::ascending) const {
485 return this->asArray().sort(permutation, order).asDiagonalMatrix();
486 }
487
488#pragma hila novector
489 DiagonalMatrix<n, T> sort(hila::sort order = hila::sort::ascending) const {
490 return this->asArray().sort(order).asDiagonalMatrix();
491 }
492};
493
494///////////////////////////////////////////////////////////////////////////////////////
495
496template <int n, typename T>
497inline const auto &transpose(const DiagonalMatrix<n, T> &arg) {
498 return arg;
499}
500
501template <int n, typename T>
502inline auto dagger(const DiagonalMatrix<n, T> &arg) {
503 return arg.dagger();
504}
505
506template <int n, typename T>
507inline auto conj(const DiagonalMatrix<n, T> &arg) {
508 return arg.conj();
509}
510
511template <int n, typename T>
512inline auto adjoint(const DiagonalMatrix<n, T> &arg) {
513 return arg.adjoint();
514}
515
516template <int n, typename T>
517inline auto abs(const DiagonalMatrix<n, T> &arg) {
518 return arg.abs();
519}
520
521template <int n, typename T>
522inline auto real(const DiagonalMatrix<n, T> &arg) {
523 return arg.real();
524}
525
526template <int n, typename T>
527inline auto imag(const DiagonalMatrix<n, T> &arg) {
528 return arg.imag();
529}
530
531template <int n, typename T>
532inline auto trace(const DiagonalMatrix<n, T> &arg) {
533 return arg.trace();
534}
535
536template <int n, typename T>
537inline auto squarenorm(const DiagonalMatrix<n, T> &arg) {
538 return arg.squarenorm();
539}
540
541template <int n, typename T>
542inline auto norm(const DiagonalMatrix<n, T> &arg) {
543 return arg.norm();
544}
545
546template <int n, typename T>
547inline auto det(const DiagonalMatrix<n, T> &arg) {
548 return arg.det();
549}
550
551
552namespace hila {
553
554////////////////////////////////////////////////////////////////////////////////
555// DiagonalMatrix + scalar result type:
556// hila::diagonalmatrix_scalar_type<Mt,S>
557// - if number_type<Mt> + S is convertible to number_type<Mt> (specialization below):
558//   return Mt when Mt has floating point elements, otherwise (integer elements)
559//   DiagonalMatrix<type_plus(arithmetic_type(Mt),arithmetic_type(S))>, e.g. int + double -> double
560// - otherwise (primary template) return DiagonalMatrix<Complex<type_plus(...)>>
561
// Primary template: used when the element/scalar sum does not convert back to
// Mt's number type (presumably a real Mt combined with a complex S - TODO confirm).
// The result has Complex elements of the summed arithmetic types.
562template <typename Mt, typename S, typename Enable = void>
563struct diagonalmatrix_scalar_op_s {
564 using type =
565 DiagonalMatrix<Mt::rows(),
566 Complex<hila::type_plus<hila::arithmetic_type<Mt>, hila::arithmetic_type<S>>>>;
567};
568
// Specialization: selected when type_plus<number_type<Mt>, S> is convertible to
// number_type<Mt>. Floating point elements keep the type Mt unchanged; integer
// elements are promoted to type_plus of the two arithmetic types.
// NOTE(review): for Mt with Complex<integer> elements, arithmetic_type<Mt> is presumably
// the integer type, so this branch would yield *real* elements - confirm that case is
// unreachable or intended.
569template <typename Mt, typename S>
570struct diagonalmatrix_scalar_op_s<
571 Mt, S,
572 typename std::enable_if_t<std::is_convertible<hila::type_plus<hila::number_type<Mt>, S>,
573 hila::number_type<Mt>>::value>> {
574 // using type = Mt;
575 using type = typename std::conditional<
576 hila::is_floating_point<hila::arithmetic_type<Mt>>::value, Mt,
577 DiagonalMatrix<Mt::rows(),
578 hila::type_plus<hila::arithmetic_type<Mt>, hila::arithmetic_type<S>>>>::type;
579};
580
// Convenience alias used as the result type of DiagonalMatrix-scalar operators.
581template <typename Mt, typename S>
582using diagonalmatrix_scalar_type = typename diagonalmatrix_scalar_op_s<Mt, S>::type;
583
584} // namespace hila
585
586
587/**
588 * operators: can add scalar, diagonal matrix or square matrix.
589 */
590
591// diagonal + scalar
592template <int n, typename T, typename S,
593 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
594 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
595inline Rtype operator+(const DiagonalMatrix<n, T> &a, const S &b) {
596 Rtype res;
597 for (int i = 0; i < n; i++)
598 res.e(i) = a.e(i) + b;
599 return res;
600}
601
602// diagonal + scalar
603template <int n, typename T, typename S,
604 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
605 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
606inline Rtype operator+(const S &b, const DiagonalMatrix<n, T> &a) {
607 return a + b;
608}
609
610// diagonal - scalar
611template <int n, typename T, typename S,
612 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
613 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
614inline Rtype operator-(const DiagonalMatrix<n, T> &a, const S &b) {
615 Rtype res;
616 for (int i = 0; i < n; i++)
617 res.e(i) = a.e(i) - b;
618 return res;
619}
620
621// scalar - diagonal
622template <int n, typename T, typename S,
623 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
624 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
625inline Rtype operator-(const S &b, const DiagonalMatrix<n, T> &a) {
626 Rtype res;
627 for (int i = 0; i < n; i++)
628 res.e(i) = b - a.e(i);
629 return res;
630}
631
632// diagonal * scalar
633template <int n, typename T, typename S,
634 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
635 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
636inline Rtype operator*(const DiagonalMatrix<n, T> &a, const S &b) {
637 Rtype res;
638 for (int i = 0; i < n; i++)
639 res.e(i) = a.e(i) * b;
640 return res;
641}
642
643// scalar * diagonal
644template <int n, typename T, typename S,
645 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
646 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
647inline Rtype operator*(const S &b, const DiagonalMatrix<n, T> &a) {
648 return a * b;
649}
650
651// diagonal / scalar
652template <int n, typename T, typename S,
653 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
654 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
655inline Rtype operator/(const DiagonalMatrix<n, T> &a, const S &b) {
656 Rtype res;
657 for (int i = 0; i < n; i++)
658 res.e(i) = a.e(i) / b;
659 return res;
660}
661
662// scalar / diagonal
663template <int n, typename T, typename S,
664 std::enable_if_t<hila::is_complex_or_arithmetic<S>::value, int> = 0,
665 typename Rtype = hila::diagonalmatrix_scalar_type<DiagonalMatrix<n, T>, S>>
666inline Rtype operator/(const S &b, const DiagonalMatrix<n, T> &a) {
667 Rtype res;
668 for (int i = 0; i < n; i++)
669 res.e(i) = b / a.e(i);
670 return res;
671}
672
673/////
674// diagonal X diagonal
675template <int n, typename A, typename B, typename R = hila::type_plus<A, B>>
676inline auto operator+(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
678 for (int i = 0; i < n; i++)
679 res.e(i) = a.e(i) + b.e(i);
680 return res;
681}
682
683template <int n, typename A, typename B, typename R = hila::type_minus<A, B>>
684inline auto operator-(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
686 for (int i = 0; i < n; i++)
687 res.e(i) = a.e(i) - b.e(i);
688 return res;
689}
690
691template <int n, typename A, typename B, typename R = hila::type_mul<A, B>>
692inline auto operator*(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
694 for (int i = 0; i < n; i++)
695 res.e(i) = a.e(i) * b.e(i);
696 return res;
697}
698
699template <int n, typename A, typename B, typename R = hila::type_div<A, B>>
700inline auto operator/(const DiagonalMatrix<n, A> &a, const DiagonalMatrix<n, B> &b) {
702 for (int i = 0; i < n; i++)
703 res.e(i) = a.e(i) / b.e(i);
704 return res;
705}
706
707//// Finally, diagonal X Matrix ops - gives Matrix
708
709template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
710 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
711inline Rtype operator+(const DiagonalMatrix<n, T> &a, const Mtype &b) {
712
713 constexpr int mr = Mtype::rows();
714 constexpr int mc = Mtype::columns();
715
716 static_assert(mc == n && mr == n, "Matrix sizes do not match");
717
718 Rtype r;
719 r = b;
720 for (int i = 0; i < n; i++)
721 r.e(i, i) += a.e(i);
722 return r;
723}
724
725template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
726 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
727inline Rtype operator+(const Mtype &b, const DiagonalMatrix<n, T> &a) {
728 return a + b;
729}
730
731template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
732 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
733inline Rtype operator-(const DiagonalMatrix<n, T> &a, const Mtype &b) {
734
735 constexpr int mr = Mtype::rows();
736 constexpr int mc = Mtype::columns();
737
738 static_assert(mc == n && mr == n, "Matrix sizes do not match");
739
740 Rtype r = -b;
741 for (int i = 0; i < n; i++)
742 r.e(i, i) += a.e(i);
743 return r;
744}
745
746template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
747 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
748inline Rtype operator-(const Mtype &b, const DiagonalMatrix<n, T> &a) {
749 constexpr int mr = Mtype::rows();
750 constexpr int mc = Mtype::columns();
751
752 static_assert(mc == n && mr == n, "Matrix sizes do not match");
753
754 Rtype r = b;
755 for (int i = 0; i < n; i++)
756 r.e(i, i) -= a.e(i);
757 return r;
758}
759
760
761// multiply by matrix
762
763template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
764 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
765inline Rtype operator*(const DiagonalMatrix<n, T> &a, const Mtype &b) {
766
767 constexpr int mr = Mtype::rows();
768 constexpr int mc = Mtype::columns();
769
770 static_assert(mr == n, "Matrix sizes do not match");
771
772 Rtype r;
773 for (int i = 0; i < n; i++)
774 for (int j = 0; j < mc; j++)
775 r.e(i, j) = a.e(i) * b.e(i, j);
776 return r;
777}
778
779template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
780 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
781inline Rtype operator*(const Mtype &b, const DiagonalMatrix<n, T> &a) {
782
783 constexpr int mr = Mtype::rows();
784 constexpr int mc = Mtype::columns();
785
786 static_assert(mc == n, "Matrix sizes do not match");
787
788 Rtype r;
789 for (int i = 0; i < mr; i++)
790 for (int j = 0; j < n; j++)
791 r.e(i, j) = b.e(i, j) * a.e(j);
792 return r;
793}
794
795// division
796template <int n, typename T, typename Mtype, std::enable_if_t<Mtype::is_matrix(), int> = 0,
797 typename Rtype = hila::mat_x_mat_type<Matrix<n, n, T>, Mtype>>
798inline Rtype operator/(const Mtype &b, const DiagonalMatrix<n, T> &a) {
799
800 constexpr int mr = Mtype::rows();
801 constexpr int mc = Mtype::columns();
802
803 static_assert(mc == n, "Matrix sizes do not match");
804
805 Rtype r;
806 for (int i = 0; i < mr; i++)
807 for (int j = 0; j < n; j++)
808 r.e(i, j) = b.e(i, j) / a.e(j);
809 return r;
810}
811
812////////////////////////////////////////////////////////////////////////////////
813/// Standard arithmetic functions - do element by element
814////////////////////////////////////////////////////////////////////////////////
815
816template <int n, typename T>
818 for (int i = 0; i < n; i++)
819 a.c[i] = sqrt(a.c[i]);
820 return a;
821}
822
823template <int n, typename T>
825 for (int i = 0; i < n; i++)
826 a.c[i] = cbrt(a.c[i]);
827 return a;
828}
829
830template <int n, typename T>
832 for (int i = 0; i < n; i++)
833 a.c[i] = exp(a.c[i]);
834 return a;
835}
836
837template <int n, typename T>
839 for (int i = 0; i < n; i++)
840 a.c[i] = log(a.c[i]);
841 return a;
842}
843
844template <int n, typename T>
846 for (int i = 0; i < n; i++)
847 a.c[i] = sin(a.c[i]);
848 return a;
849}
850
851template <int n, typename T>
853 for (int i = 0; i < n; i++)
854 a.c[i] = cos(a.c[i]);
855 return a;
856}
857
858template <int n, typename T>
860 for (int i = 0; i < n; i++)
861 a.c[i] = tan(a.c[i]);
862 return a;
863}
864
865template <int n, typename T>
867 for (int i = 0; i < n; i++)
868 a.c[i] = asin(a.c[i]);
869 return a;
870}
871
872template <int n, typename T>
874 for (int i = 0; i < n; i++)
875 a.c[i] = acos(a.c[i]);
876 return a;
877}
878
879template <int n, typename T>
881 for (int i = 0; i < n; i++)
882 a.c[i] = atan(a.c[i]);
883 return a;
884}
885
886template <int n, typename T>
888 for (int i = 0; i < n; i++)
889 a.c[i] = sinh(a.c[i]);
890 return a;
891}
892
893template <int n, typename T>
895 for (int i = 0; i < n; i++)
896 a.c[i] = cosh(a.c[i]);
897 return a;
898}
899
900template <int n, typename T>
902 for (int i = 0; i < n; i++)
903 a.c[i] = tanh(a.c[i]);
904 return a;
905}
906
907template <int n, typename T>
909 for (int i = 0; i < n; i++)
910 a.c[i] = asinh(a.c[i]);
911 return a;
912}
913
914template <int n, typename T>
916 for (int i = 0; i < n; i++)
917 a.c[i] = acosh(a.c[i]);
918 return a;
919}
920
921template <int n, typename T>
923 for (int i = 0; i < n; i++)
924 a.c[i] = atanh(a.c[i]);
925 return a;
926}
927
928
929// return pow of diagonalMatrix as original type if power is scalar or diagonal is complex
930template <
931 int n, typename T, typename S,
932 std::enable_if_t<hila::is_arithmetic<S>::value || hila::contains_complex<T>::value, int> = 0>
934 for (int i = 0; i < n; i++)
935 a.c[i] = pow(a.c[i], p);
936 return a;
937}
938
939// if power is complex but matrix is scalar need to upgrade return type
940template <int n, typename T, typename S,
941 std::enable_if_t<!hila::contains_complex<T>::value, int> = 0>
942inline auto pow(const DiagonalMatrix<n, T> &a, const Complex<S> &p) {
944 for (int i = 0; i < n; i++)
945 res.e(i) = pow(a.e(i), p);
946 return res;
947}
948
949// Cast operators to different number or Complex type
950// cast_to<double>(a);
951// Cast from number->number, number->Complex, Complex->Complex OK,
952// Complex->number not.
953
954template <typename Ntype, typename T, int n,
955 std::enable_if_t<hila::is_arithmetic<T>::value, int> = 0>
958 for (int i = 0; i < n; i++)
959 res.c[i] = mat.c[i];
960 return res;
961}
962
963template <typename Ntype, typename T, int n, std::enable_if_t<hila::is_complex<T>::value, int> = 0>
966 for (int i = 0; i < n; i++)
967 res.c[i] = cast_to<Ntype>(mat.c[i]);
968 return res;
969}
970
971/// Stream operator
972template <int n, typename T>
973std::ostream &operator<<(std::ostream &strm, const DiagonalMatrix<n, T> &A) {
974 return operator<<(strm, A.asArray());
975}
976
977namespace hila {
978
979template <int n, typename T>
980std::string to_string(const DiagonalMatrix<n, T> &A, int prec = 8, char separator = ' ') {
981 return to_string(A.asArray(), prec, separator);
982}
983
984template <int n, typename T>
985std::string prettyprint(const DiagonalMatrix<n, T> &A, int prec = 8) {
986 return prettyprint(A.toMatrix(), prec);
987}
988
989} // namespace hila
990
991
992#endif
Array< n, m, T > conj(const Array< n, m, T > &arg)
Return conjugate Array.
Definition array.h:648
std::ostream & operator<<(std::ostream &strm, const Array< n, m, T > &A)
Stream operator.
Definition array.h:916
Array< n, m, hila::arithmetic_type< T > > imag(const Array< n, m, T > &arg)
Return imaginary part of Array.
Definition array.h:676
hila::arithmetic_type< T > squarenorm(const Array< n, m, T > &rhs)
Return square norm of Array.
Definition array.h:957
Array< n, m, hila::arithmetic_type< T > > real(const Array< n, m, T > &arg)
Return real part of Array.
Definition array.h:662
Array type
Definition array.h:43
Complex definition.
Definition cmplx.h:50
Define type DiagonalMatrix<n,T>
bool operator==(const Matrix_t< n, n, S, Mtype > &rhs) const
Boolean operator == to compare with square matrix.
Matrix< n, n, T > toMatrix() const
convert to generic matrix
const DiagonalMatrix & transpose() const
transpose - leaves diagonal matrix as is
DiagonalMatrix & operator=(const DiagonalMatrix< n, S > &rhs)
bool operator!=(const S &rhs) const
Boolean operator != to check if matrices are exactly different.
Vector< n, T > column(int i) const
Returns column vector i.
T e(const int i) const
Element access - e(i) gives diagonal element i.
DiagonalMatrix conj() const
conj - conjugate elements
DiagonalMatrix & random()
Fills Matrix with random elements.
std::string str(int prec=8, char separator=' ') const
convert to string for printing
RowVector< n, T > row(int i) const
Return row from Diagonal matrix.
DiagonalMatrix dagger() const
dagger - conjugate elements
DiagonalMatrix()=default
Define default constructors to ensure std::is_trivial.
DiagonalMatrix< n, T > operator-() const
Unary - operator.
DiagonalMatrix & fill(const S rhs)
fill with constant value - same as assignment
T max() const
Find max or min value - only for arithmetic types.
static constexpr int rows()
Returns matrix size - all give same result.
const auto & operator+() const
Unary + operator.
DiagonalMatrix< n, T > sort(Vector< N, int > &permutation, hila::sort order=hila::sort::ascending) const
implement sort as casting to matrix
bool operator==(const DiagonalMatrix< n, S > &rhs) const
Boolean operator == to determine if two matrices are exactly the same.
auto real() const
real and imaginary parts of diagonal matrix
auto abs() const
absolute value of all elements
The main matrix type template Matrix_t. This is a base class type for "useful" types which are deriv...
Definition matrix.h:102
static constexpr int rows()
Define constant methods rows(), columns() - may be useful in template code.
Definition matrix.h:220
static constexpr int columns()
Returns column length.
Definition matrix.h:228
T e(const int i, const int j) const
Standard array indexing operation for matrices.
Definition matrix.h:272
Matrix class which defines matrix operations.
Definition matrix.h:1679
Complex< T > dagger(const Complex< T > &val)
Return dagger of Complex number.
Definition cmplx.h:1358
T abs(const Complex< T > &a)
Return absolute value of Complex number.
Definition cmplx.h:1322
T arg(const Complex< T > &a)
Return argument of Complex number.
Definition cmplx.h:1334
Definition of Matrix types.
Invert diagonal + const. matrix using Sherman-Morrison formula.
Definition array.h:920
T gaussian_random()
Template function T hila::gaussian_random<T>(),generates gaussian random value of type T,...
Definition random.h:183
logger_class log
Now declare the logger.
double random()
Real valued uniform random number generator.
Definition hila_gpu.cpp:120
decltype(std::declval< A >()+std::declval< B >()) type_plus
Definition type_tools.h:103
std::string to_string(const Array< n, m, T > &A, int prec=8, char separator=' ')
Converts Array object to string.
Definition array.h:934
std::swap() for Fields
Definition field.h:1847
hila::is_complex_or_arithmetic<T>::value
Definition cmplx.h:750