[Top][All Lists]
[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]
[Toon-members] TooN internal/operators.hh test/mmult_test.cc
From: Tom Drummond
Subject: [Toon-members] TooN internal/operators.hh test/mmult_test.cc
Date: Mon, 09 Mar 2009 11:55:26 +0000
CVSROOT: /cvsroot/toon
Module name: TooN
Changes by: Tom Drummond <twd20> 09/03/09 11:55:26
Modified files:
internal : operators.hh
test : mmult_test.cc
Log message:
Added Matrix * Vector and
Vector * Matrix (kept distinct to handle non-commuting data types),
and cleaned up some of the template usage for Matrix * Matrix.
CVSWeb URLs:
http://cvs.savannah.gnu.org/viewcvs/TooN/internal/operators.hh?cvsroot=toon&r1=1.10&r2=1.11
http://cvs.savannah.gnu.org/viewcvs/TooN/test/mmult_test.cc?cvsroot=toon&r1=1.1&r2=1.2
Patches:
Index: internal/operators.hh
===================================================================
RCS file: /cvsroot/toon/TooN/internal/operators.hh,v
retrieving revision 1.10
retrieving revision 1.11
diff -u -b -r1.10 -r1.11
--- internal/operators.hh 6 Mar 2009 12:42:35 -0000 1.10
+++ internal/operators.hh 9 Mar 2009 11:55:25 -0000 1.11
@@ -74,9 +74,9 @@
};
//FIXME what about BLAS?
- template<typename Precision> struct MatrixMultiply
+ struct MatrixMultiply
{
- template<int R, int C, typename B, int R1, int C1, typename P1,
typename B1, int R2, int C2, typename P2, typename B2>
+ template<int R, int C, typename Precision, typename B, int R1,
int C1, typename P1, typename B1, int R2, int C2, typename P2, typename B2>
static void eval(Matrix<R, C, Precision, B>& res, const
Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
{
for(int i=0; i < res.num_rows(); ++i)
@@ -85,6 +85,30 @@
}
};
+ struct MatrixVectorMultiply // computes res = m * v, one output element per row of m
+ {
+ template<int Sout, typename Pout, typename Bout, int R, int C,
int Size, typename P1, typename P2, typename B1, typename B2>
+ static void eval(Vector<Sout, Pout, Bout>& res, const Matrix<R,
C, P1, B1>& m, const Vector<Size, P2, B2>& v)
+ {
+ for(int i=0; i < res.size(); ++i){ // res is presumably sized m.num_rows() by the caller (see operator* below)
+ res[i] = m[i] * v; // m[i] is row i; the row is on the left, v on the right
+ }
+ }
+ };
+
+ // computes res = v * m; distinct from MatrixVectorMultiply to cater for non-commuting precision types (v stays on the left)
+ struct VectorMatrixMultiply
+ {
+ template<int Sout, typename Pout, typename Bout, int R, int C, int Size, typename P1, typename P2, typename B1, typename B2>
+ static void eval(Vector<Sout, Pout, Bout>& res, const Vector<Size, P2, B2>& v, const Matrix<R, C, P1, B1>& m)
+ {
+ for(int i=0; i < res.size(); ++i){
+ res[i] = v * m.T()[i]; // element i is v times COLUMN i of m (row i of m.T()); m[i] would be row i, which is wrong for v * m
+ }
+ }
+ };
+
+
//Mini operators for passing to Pairwise, etc
struct Add{ template<class A, class B, class C> static A op(const
B& b, const C& c){return b+c;} };
struct Subtract{ template<class A, class B, class C> static A op(const
B& b, const C& c){return b-c;} };
@@ -164,13 +188,31 @@
// Matrix multiplication Matrix * Matrix
template<int R1, int C1, int R2, int C2, typename P1, typename P2, typename B1, typename B2>
-Matrix<Internal::Sizer<R1,R1>::size, Internal::Sizer<C2,C2>::size, typename Internal::MultiplyType<P1, P2>::type> operator*(const Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
+Matrix<R1, C2, typename Internal::MultiplyType<P1, P2>::type> operator*(const Matrix<R1, C1, P1, B1>& m1, const Matrix<R2, C2, P2, B2>& m2)
{
typedef typename Internal::MultiplyType<P1, P2>::type restype;
-SizeMismatch<R1, C2>:: test(m1.num_rows(),m2.num_cols());
+// Only the inner dimensions must agree (cols of m1 == rows of m2); R1 and C2 are independent, so e.g. 2x3 * 3x4 is valid.
SizeMismatch<C1, R2>:: test(m1.num_cols(),m2.num_rows());
- return Matrix<Internal::Sizer<R1,R1>::size, Internal::Sizer<C2,C2>::size,restype>(m1, m2, Operator<Internal::MatrixMultiply<restype> >(), m1.num_rows(), m2.num_cols());
+ return Matrix<Internal::Sizer<R1,R1>::size, Internal::Sizer<C2,C2>::size,restype>(m1, m2, Operator<Internal::MatrixMultiply>(), m1.num_rows(), m2.num_cols());
+}
+
+// Matrix * Vector: result has m.num_rows() elements; requires m.num_cols() == v.size()
+
+template<int R, int C, int Size, typename P1, typename P2, typename B1,
typename B2>
+Vector<R, typename Internal::MultiplyType<P1,P2>::type> operator*(const
Matrix<R, C, P1, B1>& m, const Vector<Size, P2, B2>& v)
+{
+ SizeMismatch<C,Size>::test(m.num_cols(), v.size());
+ return Vector<R, typename Internal::MultiplyType<P1,P2>::type> (m, v,
Operator<Internal::MatrixVectorMultiply>(), m.num_rows() );
+}
+
+// Vector * Matrix: result has m.num_cols() elements; requires v.size() == m.num_rows(). Kept separate from Matrix * Vector so v remains the left operand for non-commuting types
+
+template<int Size, int R, int C, typename P1, typename P2, typename B1,
typename B2>
+Vector<C, typename Internal::MultiplyType<P1,P2>::type> operator*(const
Vector<Size, P1, B1>& v, const Matrix<R, C, P2, B2>& m)
+{
+ SizeMismatch<R,Size>::test(m.num_rows(), v.size());
+ return Vector<C, typename Internal::MultiplyType<P1,P2>::type> (v, m,
Operator<Internal::VectorMatrixMultiply>(), m.num_cols() );
}
Index: test/mmult_test.cc
===================================================================
RCS file: /cvsroot/toon/TooN/test/mmult_test.cc,v
retrieving revision 1.1
retrieving revision 1.2
diff -u -b -r1.1 -r1.2
--- test/mmult_test.cc 27 Feb 2009 09:45:47 -0000 1.1
+++ test/mmult_test.cc 9 Mar 2009 11:55:26 -0000 1.2
@@ -24,11 +24,20 @@
m4[1] = makeVector(8, 9);
m4[2] = makeVector(10, 11);
+ Vector<V(a,3)> v(3);
+ v = makeVector(6,8,10);
+
cout << m3<<endl;
cout << m4<<endl;
cout << m3*m4;
cout << "\n should be: \n 28 31\n 100 112\n";
+
+ cout << endl << v << endl;
+ cout << endl << m3*v << endl;
+
+ cout << "\n should be: \n 28 100\n" << endl;
+
}
int main()
[Prev in Thread] | Current Thread | [Next in Thread]
- [Toon-members] TooN internal/operators.hh test/mmult_test.cc, Tom Drummond <=