00001 #ifndef RIVET_MATH_MATRIXN
00002 #define RIVET_MATH_MATRIXN
00003
00004 #include "Rivet/Math/MathHeader.hh"
00005 #include "Rivet/Math/MathUtils.hh"
00006 #include "Rivet/Math/Vectors.hh"
00007
00008 #include "Rivet/Math/eigen/matrix.h"
00009
00010 namespace Rivet {
00011
00012
00013 template <size_t N>
00014 class Matrix;
00015 typedef Matrix<4> Matrix4;
00016
00017 template <size_t N>
00018 Matrix<N> multiply(const Matrix<N>& a, const Matrix<N>& b);
00019 template <size_t N>
00020 Matrix<N> divide(const Matrix<N>&, const double);
00021 template <size_t N>
00022 Matrix<N> operator*(const Matrix<N>& a, const Matrix<N>& b);
00023
00024
00025
00026
00027
00028 template <size_t N>
00029 class Matrix {
00030 template <size_t M>
00031 friend Matrix<M> add(const Matrix<M>&, const Matrix<M>&);
00032 template <size_t M>
00033 friend Matrix<M> multiply(const double, const Matrix<M>&);
00034 template <size_t M>
00035 friend Matrix<M> multiply(const Matrix<M>&, const Matrix<M>&);
00036 template <size_t M>
00037 friend Vector<M> multiply(const Matrix<M>&, const Vector<M>&);
00038 template <size_t M>
00039 friend Matrix<M> divide(const Matrix<M>&, const double);
00040
00041 public:
00042 static Matrix<N> mkZero() {
00043 Matrix<N> rtn;
00044 return rtn;
00045 }
00046
00047 static Matrix<N> mkDiag(Vector<N> diag) {
00048 Matrix<N> rtn;
00049 for (size_t i = 0; i < N; ++i) {
00050 rtn.set(i, i, diag[i]);
00051 }
00052 return rtn;
00053 }
00054
00055 static Matrix<N> mkIdentity() {
00056 Matrix<N> rtn;
00057 for (size_t i = 0; i < N; ++i) {
00058 rtn.set(i, i, 1);
00059 }
00060 return rtn;
00061 }
00062
00063
00064 public:
00065
00066 Matrix() {
00067 _matrix.loadZero();
00068 }
00069
00070 Matrix(const Matrix<N>& other) {
00071 _matrix = other._matrix;
00072 }
00073
00074 Matrix& set(const size_t i, const size_t j, const double value) {
00075 if (i < N && j < N) {
00076 _matrix(i, j) = value;
00077 } else {
00078 throw std::runtime_error("Attempted set access outside matrix bounds.");
00079 }
00080 return *this;
00081 }
00082
00083 double get(const size_t i, const size_t j) const {
00084 if (i < N && j < N) {
00085 return _matrix(i, j);
00086 } else {
00087 throw std::runtime_error("Attempted get access outside matrix bounds.");
00088 }
00089 }
00090
00091 Vector<N> getRow(const size_t row) const {
00092 Vector<N> rtn;
00093 for (size_t i = 0; i < N; ++i) {
00094 rtn.set(i, _matrix(row,i));
00095 }
00096 return rtn;
00097 }
00098
00099 Matrix<N>& setRow(const size_t row, const Vector<N>& r) {
00100 for (size_t i = 0; i < N; ++i) {
00101 _matrix(row,i) = r.get(i);
00102 }
00103 return *this;
00104 }
00105
00106 Vector<N> getColumn(const size_t col) const {
00107 const Eigen::Vector<double,N> eVec = _matrix.column(col);
00108 Vector<N> rtn;
00109 for (size_t i = 0; i < N; ++i) {
00110 rtn.set(i, _matrix(i,col));
00111 }
00112 return rtn;
00113 }
00114
00115 Matrix<N>& setColumn(const size_t col, const Vector<N>& c) {
00116 for (size_t i = 0; i < N; ++i) {
00117 _matrix(i,col) = c.get(i);
00118 }
00119 return *this;
00120 }
00121
00122 Matrix<N> transpose() const {
00123 Matrix<N> tmp = *this;
00124 tmp._matrix.replaceWithAdjoint();
00125 return tmp;
00126 }
00127
00128
00129
00130
00131
00132
00133
00134 Matrix<N> inverse() const {
00135 Matrix<N> tmp;
00136 tmp._matrix = _matrix.inverse();
00137 return tmp;
00138 }
00139
00140
00141 double det() const {
00142 return _matrix.determinant();
00143 }
00144
00145
00146 double trace() const {
00147 double tr = 0.0;
00148 for (size_t i = 0; i < N; ++i) {
00149 tr += _matrix(i,i);
00150 }
00151 return tr;
00152
00153 }
00154
00155
00156 Matrix<N> operator-() const {
00157 Matrix<N> rtn;
00158 rtn._matrix = -_matrix;
00159 return rtn;
00160 }
00161
00162
00163 size_t size() const {
00164 return N;
00165 }
00166
00167
00168 bool isZero(double tolerance=1E-5) const {
00169 for (size_t i=0; i < N; ++i) {
00170 for (size_t j=0; j < N; ++j) {
00171 if (! Rivet::isZero(_matrix(i,j), tolerance) ) return false;
00172 }
00173 }
00174 return true;
00175 }
00176
00177
00178 bool isEqual(Matrix<N> other) const {
00179 for (size_t i=0; i < N; ++i) {
00180 for (size_t j=i; j < N; ++j) {
00181 if (! Rivet::isZero(_matrix(i,j) - other._matrix(i,j)) ) return false;
00182 }
00183 }
00184 return true;
00185 }
00186
00187
00188 bool isSymm() const {
00189 return isEqual(this->transpose());
00190 }
00191
00192
00193 bool isDiag() const {
00194 for (size_t i=0; i < N; ++i) {
00195 for (size_t j=0; j < N; ++j) {
00196 if (i == j) continue;
00197 if (! Rivet::isZero(_matrix(i,j)) ) return false;
00198 }
00199 }
00200 return true;
00201 }
00202
00203 bool operator==(const Matrix<N>& a) const {
00204 return _matrix == a._matrix;
00205 }
00206
00207 bool operator!=(const Matrix<N>& a) const {
00208 return _matrix != a._matrix;
00209 }
00210
00211 bool operator<(const Matrix<N>& a) const {
00212 return _matrix < a._matrix;
00213 }
00214
00215 bool operator<=(const Matrix<N>& a) const {
00216 return _matrix <= a._matrix;
00217 }
00218
00219 bool operator>(const Matrix<N>& a) const {
00220 return _matrix > a._matrix;
00221 }
00222
00223 bool operator>=(const Matrix<N>& a) const {
00224 return _matrix >= a._matrix;
00225 }
00226
00227 Matrix<N>& operator*=(const Matrix<N>& m) {
00228 _matrix = _matrix * m._matrix;
00229 return *this;
00230 }
00231
00232 Matrix<N>& operator*=(const double a) {
00233 _matrix *= a;
00234 return *this;
00235 }
00236
00237 Matrix<N>& operator/=(const double a) {
00238 _matrix /= a;
00239 return *this;
00240 }
00241
00242 Matrix<N>& operator+=(const Matrix<N>& m) {
00243 _matrix += m._matrix;
00244 return *this;
00245 }
00246
00247 Matrix<N>& operator-=(const Matrix<N>& m) {
00248 _matrix -= m._matrix;
00249 return *this;
00250 }
00251
00252 protected:
00253 typedef Eigen::Matrix<double,N> EMatrix;
00254 EMatrix _matrix;
00255 };
00256
00257
00258
00259
00260
00261 template <size_t N>
00262 inline Matrix<N> add(const Matrix<N>& a, const Matrix<N>& b) {
00263 Matrix<N> result;
00264 result._matrix = a._matrix + b._matrix;
00265 return result;
00266 }
00267
00268 template <size_t N>
00269 inline Matrix<N> subtract(const Matrix<N>& a, const Matrix<N>& b) {
00270 return add(a, -b);
00271 }
00272
00273 template <size_t N>
00274 inline Matrix<N> operator+(const Matrix<N> a, const Matrix<N>& b) {
00275 return add(a, b);
00276 }
00277
00278 template <size_t N>
00279 inline Matrix<N> operator-(const Matrix<N> a, const Matrix<N>& b) {
00280 return subtract(a, b);
00281 }
00282
00283 template <size_t N>
00284 inline Matrix<N> multiply(const double a, const Matrix<N>& m) {
00285 Matrix<N> rtn;
00286 rtn._matrix = a * m._matrix;
00287 return rtn;
00288 }
00289
00290 template <size_t N>
00291 inline Matrix<N> multiply(const Matrix<N>& m, const double a) {
00292 return multiply(a, m);
00293 }
00294
00295 template <size_t N>
00296 inline Matrix<N> divide(const Matrix<N>& m, const double a) {
00297 return multiply(1/a, m);
00298 }
00299
00300 template <size_t N>
00301 inline Matrix<N> operator*(const double a, const Matrix<N>& m) {
00302 return multiply(a, m);
00303 }
00304
00305 template <size_t N>
00306 inline Matrix<N> operator*(const Matrix<N>& m, const double a) {
00307 return multiply(a, m);
00308 }
00309
00310 template <size_t N>
00311 inline Matrix<N> multiply(const Matrix<N>& a, const Matrix<N>& b) {
00312 Matrix<N> tmp;
00313 tmp._matrix = a._matrix * b._matrix;
00314 return tmp;
00315 }
00316
00317 template <size_t N>
00318 inline Matrix<N> operator*(const Matrix<N>& a, const Matrix<N>& b) {
00319 return multiply(a, b);
00320 }
00321
00322
00323 template <size_t N>
00324 inline Vector<N> multiply(const Matrix<N>& a, const Vector<N>& b) {
00325 Vector<N> tmp;
00326 tmp._vec = a._matrix * b._vec;
00327 return tmp;
00328 }
00329
00330 template <size_t N>
00331 inline Vector<N> operator*(const Matrix<N>& a, const Vector<N>& b) {
00332 return multiply(a, b);
00333 }
00334
00335 template <size_t N>
00336 inline Matrix<N> transpose(const Matrix<N>& m) {
00337
00338
00339
00340
00341
00342
00343
00344 return m.transpose();
00345 }
00346
00347 template <size_t N>
00348 inline Matrix<N> inverse(const Matrix<N>& m) {
00349 return m.inverse();
00350 }
00351
00352 template <size_t N>
00353 inline double det(const Matrix<N>& m) {
00354 return m.determinant();
00355 }
00356
00357 template <size_t N>
00358 inline double trace(const Matrix<N>& m) {
00359 return m.trace();
00360 }
00361
00362
00363
00364
00365
00366
00367 template <size_t N>
00368 inline string toString(const Matrix<N>& m) {
00369 ostringstream ss;
00370 ss << "[ ";
00371 for (size_t i = 0; i < m.size(); ++i) {
00372 ss << "( ";
00373 for (size_t j = 0; j < m.size(); ++j) {
00374 const double e = m.get(i, j);
00375 ss << (Rivet::isZero(e) ? 0.0 : e) << " ";
00376 }
00377 ss << ") ";
00378 }
00379 ss << "]";
00380 return ss.str();
00381 }
00382
00383
00384
00385 template <size_t N>
00386 inline ostream& operator<<(std::ostream& out, const Matrix<N>& m) {
00387 out << toString(m);
00388 return out;
00389 }
00390
00391
00392
00393
00394
00395
00396 template <size_t N>
00397 inline bool fuzzyEquals(const Matrix<N>& ma, const Matrix<N>& mb, double tolerance=1E-5) {
00398 for (size_t i = 0; i < N; ++i) {
00399 for (size_t j = 0; j < N; ++j) {
00400 const double a = ma.get(i, j);
00401 const double b = mb.get(i, j);
00402 if (!Rivet::fuzzyEquals(a, b, tolerance)) return false;
00403 }
00404 }
00405 return true;
00406 }
00407
00408
00409
00410 template <size_t N>
00411 inline bool isZero(const Matrix<N>& m, double tolerance=1E-5) {
00412 return m.isZero(tolerance);
00413 }
00414
00415
00416 }
00417
00418 #endif