00001 #ifndef RIVET_MATH_MATRIXN
00002 #define RIVET_MATH_MATRIXN
00003
00004 #include "Rivet/Math/MathHeader.hh"
00005 #include "Rivet/Math/MathUtils.hh"
00006 #include "Rivet/Math/Vectors.hh"
00007
00008 #include "Rivet/Math/eigen/matrix.h"
00009
00010 namespace Rivet {
00011
00012
00013 template <size_t N>
00014 class Matrix;
00015 typedef Matrix<4> Matrix4;
00016
00017 template <size_t N>
00018 Matrix<N> multiply(const Matrix<N>& a, const Matrix<N>& b);
00019 template <size_t N>
00020 Matrix<N> divide(const Matrix<N>&, const double);
00021 template <size_t N>
00022 Matrix<N> operator*(const Matrix<N>& a, const Matrix<N>& b);
00023
00024
00025
00026
00027
00028
00029 template <size_t N>
00030 class Matrix {
00031 template <size_t M>
00032 friend Matrix<M> add(const Matrix<M>&, const Matrix<M>&);
00033 template <size_t M>
00034 friend Matrix<M> multiply(const double, const Matrix<M>&);
00035 template <size_t M>
00036 friend Matrix<M> multiply(const Matrix<M>&, const Matrix<M>&);
00037 template <size_t M>
00038 friend Vector<M> multiply(const Matrix<M>&, const Vector<M>&);
00039 template <size_t M>
00040 friend Matrix<M> divide(const Matrix<M>&, const double);
00041
00042 public:
00043 static Matrix<N> mkZero() {
00044 Matrix<N> rtn;
00045 return rtn;
00046 }
00047
00048 static Matrix<N> mkDiag(Vector<N> diag) {
00049 Matrix<N> rtn;
00050 for (size_t i = 0; i < N; ++i) {
00051 rtn.set(i, i, diag[i]);
00052 }
00053 return rtn;
00054 }
00055
00056 static Matrix<N> mkIdentity() {
00057 Matrix<N> rtn;
00058 for (size_t i = 0; i < N; ++i) {
00059 rtn.set(i, i, 1);
00060 }
00061 return rtn;
00062 }
00063
00064
00065 public:
00066
00067 Matrix() {
00068 _matrix.loadZero();
00069 }
00070
00071 Matrix(const Matrix<N>& other) {
00072 _matrix = other._matrix;
00073 }
00074
00075 Matrix& set(const size_t i, const size_t j, const double value) {
00076 if (i < N && j < N) {
00077 _matrix(i, j) = value;
00078 } else {
00079 throw std::runtime_error("Attempted set access outside matrix bounds.");
00080 }
00081 return *this;
00082 }
00083
00084 double get(const size_t i, const size_t j) const {
00085 if (i < N && j < N) {
00086 return _matrix(i, j);
00087 } else {
00088 throw std::runtime_error("Attempted get access outside matrix bounds.");
00089 }
00090 }
00091
00092 Vector<N> getRow(const size_t row) const {
00093 Vector<N> rtn;
00094 for (size_t i = 0; i < N; ++i) {
00095 rtn.set(i, _matrix(row,i));
00096 }
00097 return rtn;
00098 }
00099
00100 Matrix<N>& setRow(const size_t row, const Vector<N>& r) {
00101 for (size_t i = 0; i < N; ++i) {
00102 _matrix(row,i) = r.get(i);
00103 }
00104 return *this;
00105 }
00106
00107 Vector<N> getColumn(const size_t col) const {
00108 const Eigen::Vector<double,N> eVec = _matrix.column(col);
00109 Vector<N> rtn;
00110 for (size_t i = 0; i < N; ++i) {
00111 rtn.set(i, _matrix(i,col));
00112 }
00113 return rtn;
00114 }
00115
00116 Matrix<N>& setColumn(const size_t col, const Vector<N>& c) {
00117 for (size_t i = 0; i < N; ++i) {
00118 _matrix(i,col) = c.get(i);
00119 }
00120 return *this;
00121 }
00122
00123 Matrix<N> transpose() const {
00124 Matrix<N> tmp = *this;
00125 tmp._matrix.replaceWithAdjoint();
00126 return tmp;
00127 }
00128
00129
00130
00131
00132
00133
00134
00135 Matrix<N> inverse() const {
00136 Matrix<N> tmp;
00137 tmp._matrix = _matrix.inverse();
00138 return tmp;
00139 }
00140
00141
00142 double det() const {
00143 return _matrix.determinant();
00144 }
00145
00146
00147 double trace() const {
00148 double tr = 0.0;
00149 for (size_t i = 0; i < N; ++i) {
00150 tr += _matrix(i,i);
00151 }
00152 return tr;
00153
00154 }
00155
00156
00157 Matrix<N> operator-() const {
00158 Matrix<N> rtn;
00159 rtn._matrix = -_matrix;
00160 return rtn;
00161 }
00162
00163
00164 size_t size() const {
00165 return N;
00166 }
00167
00168
00169 bool isZero(double tolerance=1E-5) const {
00170 for (size_t i=0; i < N; ++i) {
00171 for (size_t j=0; j < N; ++j) {
00172 if (! Rivet::isZero(_matrix(i,j), tolerance) ) return false;
00173 }
00174 }
00175 return true;
00176 }
00177
00178
00179 bool isEqual(Matrix<N> other) const {
00180 for (size_t i=0; i < N; ++i) {
00181 for (size_t j=i; j < N; ++j) {
00182 if (! Rivet::isZero(_matrix(i,j) - other._matrix(i,j)) ) return false;
00183 }
00184 }
00185 return true;
00186 }
00187
00188
00189 bool isSymm() const {
00190 return isEqual(this->transpose());
00191 }
00192
00193
00194 bool isDiag() const {
00195 for (size_t i=0; i < N; ++i) {
00196 for (size_t j=0; j < N; ++j) {
00197 if (i == j) continue;
00198 if (! Rivet::isZero(_matrix(i,j)) ) return false;
00199 }
00200 }
00201 return true;
00202 }
00203
00204 bool operator==(const Matrix<N>& a) const {
00205 return _matrix == a._matrix;
00206 }
00207
00208 bool operator!=(const Matrix<N>& a) const {
00209 return _matrix != a._matrix;
00210 }
00211
00212 bool operator<(const Matrix<N>& a) const {
00213 return _matrix < a._matrix;
00214 }
00215
00216 bool operator<=(const Matrix<N>& a) const {
00217 return _matrix <= a._matrix;
00218 }
00219
00220 bool operator>(const Matrix<N>& a) const {
00221 return _matrix > a._matrix;
00222 }
00223
00224 bool operator>=(const Matrix<N>& a) const {
00225 return _matrix >= a._matrix;
00226 }
00227
00228 Matrix<N>& operator*=(const Matrix<N>& m) {
00229 _matrix = _matrix * m._matrix;
00230 return *this;
00231 }
00232
00233 Matrix<N>& operator*=(const double a) {
00234 _matrix *= a;
00235 return *this;
00236 }
00237
00238 Matrix<N>& operator/=(const double a) {
00239 _matrix /= a;
00240 return *this;
00241 }
00242
00243 Matrix<N>& operator+=(const Matrix<N>& m) {
00244 _matrix += m._matrix;
00245 return *this;
00246 }
00247
00248 Matrix<N>& operator-=(const Matrix<N>& m) {
00249 _matrix -= m._matrix;
00250 return *this;
00251 }
00252
00253 protected:
00254 typedef Eigen::Matrix<double,N> EMatrix;
00255 EMatrix _matrix;
00256 };
00257
00258
00259
00260
00261
00262 template <size_t N>
00263 inline Matrix<N> add(const Matrix<N>& a, const Matrix<N>& b) {
00264 Matrix<N> result;
00265 result._matrix = a._matrix + b._matrix;
00266 return result;
00267 }
00268
00269 template <size_t N>
00270 inline Matrix<N> subtract(const Matrix<N>& a, const Matrix<N>& b) {
00271 return add(a, -b);
00272 }
00273
00274 template <size_t N>
00275 inline Matrix<N> operator+(const Matrix<N> a, const Matrix<N>& b) {
00276 return add(a, b);
00277 }
00278
00279 template <size_t N>
00280 inline Matrix<N> operator-(const Matrix<N> a, const Matrix<N>& b) {
00281 return subtract(a, b);
00282 }
00283
00284 template <size_t N>
00285 inline Matrix<N> multiply(const double a, const Matrix<N>& m) {
00286 Matrix<N> rtn;
00287 rtn._matrix = a * m._matrix;
00288 return rtn;
00289 }
00290
00291 template <size_t N>
00292 inline Matrix<N> multiply(const Matrix<N>& m, const double a) {
00293 return multiply(a, m);
00294 }
00295
00296 template <size_t N>
00297 inline Matrix<N> divide(const Matrix<N>& m, const double a) {
00298 return multiply(1/a, m);
00299 }
00300
00301 template <size_t N>
00302 inline Matrix<N> operator*(const double a, const Matrix<N>& m) {
00303 return multiply(a, m);
00304 }
00305
00306 template <size_t N>
00307 inline Matrix<N> operator*(const Matrix<N>& m, const double a) {
00308 return multiply(a, m);
00309 }
00310
00311 template <size_t N>
00312 inline Matrix<N> multiply(const Matrix<N>& a, const Matrix<N>& b) {
00313 Matrix<N> tmp;
00314 tmp._matrix = a._matrix * b._matrix;
00315 return tmp;
00316 }
00317
00318 template <size_t N>
00319 inline Matrix<N> operator*(const Matrix<N>& a, const Matrix<N>& b) {
00320 return multiply(a, b);
00321 }
00322
00323
00324 template <size_t N>
00325 inline Vector<N> multiply(const Matrix<N>& a, const Vector<N>& b) {
00326 Vector<N> tmp;
00327 tmp._vec = a._matrix * b._vec;
00328 return tmp;
00329 }
00330
00331 template <size_t N>
00332 inline Vector<N> operator*(const Matrix<N>& a, const Vector<N>& b) {
00333 return multiply(a, b);
00334 }
00335
00336 template <size_t N>
00337 inline Matrix<N> transpose(const Matrix<N>& m) {
00338
00339
00340
00341
00342
00343
00344
00345 return m.transpose();
00346 }
00347
00348 template <size_t N>
00349 inline Matrix<N> inverse(const Matrix<N>& m) {
00350 return m.inverse();
00351 }
00352
00353 template <size_t N>
00354 inline double det(const Matrix<N>& m) {
00355 return m.determinant();
00356 }
00357
00358 template <size_t N>
00359 inline double trace(const Matrix<N>& m) {
00360 return m.trace();
00361 }
00362
00363
00364
00365
00366
00367
00368 template <size_t N>
00369 inline string toString(const Matrix<N>& m) {
00370 ostringstream ss;
00371 ss << "[ ";
00372 for (size_t i = 0; i < m.size(); ++i) {
00373 ss << "( ";
00374 for (size_t j = 0; j < m.size(); ++j) {
00375 const double e = m.get(i, j);
00376 ss << (Rivet::isZero(e) ? 0.0 : e) << " ";
00377 }
00378 ss << ") ";
00379 }
00380 ss << "]";
00381 return ss.str();
00382 }
00383
00384
00385
00386 template <size_t N>
00387 inline ostream& operator<<(std::ostream& out, const Matrix<N>& m) {
00388 out << toString(m);
00389 return out;
00390 }
00391
00392
00393
00394
00395
00396
00397 template <size_t N>
00398 inline bool fuzzyEquals(const Matrix<N>& ma, const Matrix<N>& mb, double tolerance=1E-5) {
00399 for (size_t i = 0; i < N; ++i) {
00400 for (size_t j = 0; j < N; ++j) {
00401 const double a = ma.get(i, j);
00402 const double b = mb.get(i, j);
00403 if (!Rivet::fuzzyEquals(a, b, tolerance)) return false;
00404 }
00405 }
00406 return true;
00407 }
00408
00409
00410
00411 template <size_t N>
00412 inline bool isZero(const Matrix<N>& m, double tolerance=1E-5) {
00413 return m.isZero(tolerance);
00414 }
00415
00416
00417 }
00418
00419 #endif