00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00024 
00025 
00026 #ifndef _LA_TRAITS_INCL
00027 #define _LA_TRAITS_INCL
00028 
00029 
00030 
00031 
00032 #include <stdio.h>
00033 #include <string.h>
00034 #include <iostream>
00035 #include <limits>
00036 #include <complex>
00037 
00038 
00039 
00040 #define complex std::complex
00041 
00042 #include "laerror.h"
00043 
00044 #include "cuda_la.h"
00045 
00046 #ifdef NONCBLAS
00047 #include "noncblas.h"
00048 #else
00049 extern "C" {
00050 #include "cblas.h"
00051 }
00052 #endif
00053 
00054 #ifdef NONCLAPACK
00055 #include "noncblas.h"
00056 #else
00057 extern "C" {
00058 #include "clapack.h"
00059 }
00060 #endif
00061 
00062 namespace LA {
00063 
// Global flag that is only declared here; its definition lives in another
// translation unit.
// NOTE(review): presumably enables reference-count checking -- confirm at the definition.
extern bool _LA_count_check;


// Forward declarations of the container class templates that the traits
// in this header refer to; their definitions are not needed here.
template<class T> class NRVec;
template<class T> class NRMat;
template<class T> class NRMat_from1;
template<class T> class NRSMat;
template<class T> class NRSMat_from1;
template<class T> class SparseMat;
template<class T> class SparseSMat;
00074 
00075 
// Two distinct empty placeholder types. They are used as the "no such type"
// answer in traits whose real answer only exists for some element types.
typedef class {} Dummy_type;
typedef class {} Dummy_type2;



// Primary template of LA_traits_complex: the fallback for every type that has
// no explicit specialization. All member typedefs map to placeholder types;
// the symmetric-matrix entry uses the second placeholder so that it is a
// different type from the NRMat entry.
template<typename T>
struct LA_traits_complex
        {
        typedef Dummy_type  Component_type;
        typedef Dummy_type  NRVec_Noncomplex_type;
        typedef Dummy_type  NRMat_Noncomplex_type;
        typedef Dummy_type2 NRSMat_Noncomplex_type;
        };
00089 
00090 #define SPECIALIZE_COMPLEX(T) \
00091 template<> \
00092 struct LA_traits_complex<complex<T> > \
00093         { \
00094         typedef T Component_type; \
00095         typedef NRVec<T> NRVec_Noncomplex_type; \
00096         typedef NRMat<T> NRMat_Noncomplex_type; \
00097         typedef NRSMat<T> NRSMat_Noncomplex_type; \
00098         };
00099 
00100 
00101 SPECIALIZE_COMPLEX(double)
00102 SPECIALIZE_COMPLEX(complex<double>)
00103 SPECIALIZE_COMPLEX(float)
00104 SPECIALIZE_COMPLEX(complex<float>)
00105 SPECIALIZE_COMPLEX(char)
00106 SPECIALIZE_COMPLEX(unsigned char)
00107 SPECIALIZE_COMPLEX(short)
00108 SPECIALIZE_COMPLEX(unsigned short)
00109 SPECIALIZE_COMPLEX(int)
00110 SPECIALIZE_COMPLEX(unsigned int)
00111 SPECIALIZE_COMPLEX(long)
00112 SPECIALIZE_COMPLEX(unsigned long)
00113 SPECIALIZE_COMPLEX(long long)
00114 SPECIALIZE_COMPLEX(unsigned long long)
00115 
00116 
00117 
// Compile-time selection of the ordering predicate used when sorting.
// Only the two specializations below (type 0 and type 1) are defined.
template<typename T, typename I, int type>
struct LA_sort_traits;

// type == 0: the predicate is object.bigger(i,j).
template<typename Sortable, typename Index>
struct LA_sort_traits<Sortable, Index, 0>
        {
        static inline bool compare(Sortable object, Index i, Index j)
                {
                return object.bigger(i, j);
                }
        };

// type == 1: the predicate is object.smaller(i,j).
template<typename Sortable, typename Index>
struct LA_sort_traits<Sortable, Index, 1>
        {
        static inline bool compare(Sortable object, Index i, Index j)
                {
                return object.smaller(i, j);
                }
        };
00131 
00132 
00133 
// Maps an element type to the type named IOtype.
// Default: IOtype is the element type itself.
// NOTE(review): presumably IOtype is the type used for formatted stream I/O,
// so that char elements are handled as numbers -- confirm at the use sites.
template<typename T>
struct LA_traits_io
        {
        typedef T IOtype;
        };

// char maps to int
template<>
struct LA_traits_io<char>
        {
        typedef int IOtype;
        };

// unsigned char maps to unsigned int
template<>
struct LA_traits_io<unsigned char>
        {
        typedef unsigned int IOtype;
        };
00151 
00152 
00153 
00154 
00155 
// Tag types used to dispatch between scalar and non-scalar element traits.
class scalar_false {};
class scalar_true {};


// isscalar<C>::scalar_type is scalar_false unless C is one of the types
// registered through SCALAR() below.
template<typename C>
class isscalar
        {
        public:
        typedef scalar_false scalar_type;
        };


// SCALAR(TYPE) registers TYPE, complex<TYPE> and complex<complex<TYPE> >
// as scalar types by specializing isscalar with scalar_true.
#define SCALAR(TYPE) \
template<> \
class isscalar< TYPE > \
        { \
        public: \
        typedef scalar_true scalar_type; \
        }; \
template<> \
class isscalar< complex< TYPE > > \
        { \
        public: \
        typedef scalar_true scalar_type; \
        }; \
template<> \
class isscalar< complex< complex< TYPE > > > \
        { \
        public: \
        typedef scalar_true scalar_type; \
        };


// signed integral types
SCALAR(char)
SCALAR(short)
SCALAR(int)
SCALAR(long)
SCALAR(long long)

// unsigned integral types
SCALAR(unsigned char)
SCALAR(unsigned short)
SCALAR(unsigned int)
SCALAR(unsigned long)
SCALAR(unsigned long long)

// floating-point types
SCALAR(float)
SCALAR(double)

// untyped pointers are registered as scalars too
SCALAR(void *)

#undef SCALAR
00189 
00190 
00191 
00192 template<typename C, typename Scalar> struct LA_traits_aux
00193         {
00194         typedef Dummy_type normtype;
00195         };
00196 
00197 
00198 
00203 
00204 
00205 
00206 template<typename C>
00207 struct LA_traits_aux<complex<C>, scalar_true> {
00208 typedef complex<C> elementtype;
00209 typedef complex<C> producttype;
00210 typedef C normtype;
00211 typedef C realtype;
00212 typedef complex<C> complextype;
00213 static inline C sqrabs(const complex<C> x) { return x.real()*x.real()+x.imag()*x.imag();}
00214 static inline bool gencmp(const complex<C> *x, const complex<C> *y, int n) {return memcmp(x,y,n*sizeof(complex<C>));}
00215 static bool bigger(const  complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
00216 static bool smaller(const  complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
00217 static inline normtype norm (const  complex<C> &x) {return std::abs(x);}
00218 static inline void axpy (complex<C> &s, const complex<C> &x, const complex<C> &c) {s+=x*c;}
00219 static inline void get(int fd, complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(complex<C>)!=read(fd,&x,sizeof(complex<C>))) laerror("read error");}
00220 static inline void put(int fd, const complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(complex<C>)!=write(fd,&x,sizeof(complex<C>))) laerror("write error");}
00221 static void multiget(unsigned int n,int fd, complex<C> *x, bool dimensions=0){if((ssize_t)(n*sizeof(complex<C>))!=read(fd,x,n*sizeof(complex<C>))) laerror("read error");}
00222 static void multiput(unsigned int n, int fd, const complex<C> *x, bool dimensions=0) {if((ssize_t)(n*sizeof(complex<C>))!=write(fd,x,n*sizeof(complex<C>))) laerror("write error");}
00223 static void copy(complex<C> *dest, complex<C> *src, unsigned int n) {memcpy(dest,src,n*sizeof(complex<C>));}
00224 static void clear(complex<C> *dest, unsigned int n) {memset(dest,0,n*sizeof(complex<C>));}
00225 static void copyonwrite(complex<C> &x) {};
00226 static void clearme(complex<C> &x) {x=0;};
00227 static inline complex<C> conjugate(const complex<C> &x) {return complex<C>(x.real(),-x.imag());};
00228 static inline C realpart(const complex<C> &x) {return x.real();}
00229 static inline C imagpart(const complex<C> &x) {return x.imag();}
00230 };
00231 
00232 
00233 template<typename C>
00234 struct LA_traits_aux<C, scalar_true> {
00235 typedef C elementtype;
00236 typedef C producttype;
00237 typedef C normtype;
00238 typedef C realtype;
00239 typedef complex<C> complextype;
00240 static inline C sqrabs(const C x) { return x*x;}
00241 static inline bool gencmp(const C *x, const C *y, int n) {return memcmp(x,y,n*sizeof(C));}
00242 static inline bool bigger(const  C &x, const C &y) {return x>y;}
00243 static inline bool smaller(const  C &x, const C &y) {return x<y;}
00244 static inline normtype norm (const  C &x) {return std::abs(x);}
00245 static inline void axpy (C &s, const C &x, const C &c) {s+=x*c;}
00246 static inline void put(int fd, const C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=write(fd,&x,sizeof(C))) laerror("write error");}
00247 static inline void get(int fd, C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=read(fd,&x,sizeof(C))) laerror("read error");}
00248 static void multiput(unsigned int n,int fd, const C *x, bool dimensions=0){if((ssize_t)(n*sizeof(C))!=write(fd,x,n*sizeof(C))) laerror("write error");}
00249 static void multiget(unsigned int n, int fd, C *x, bool dimensions=0) {if((ssize_t)(n*sizeof(C))!=read(fd,x,n*sizeof(C))) laerror("read error");}
00250 static void copy(C *dest, C *src, unsigned int n) {memcpy(dest,src,n*sizeof(C));}
00251 static void clear(C *dest, unsigned int n) {memset(dest,0,n*sizeof(C));}
00252 static void copyonwrite(C &x) {};
00253 static void clearme(complex<C> &x) {x=0;};
00254 static inline C conjugate(const C &x) {return x;};
00255 static inline C realpart(const C &x) {return x;}
00256 static inline C imagpart(const C &x) {return 0;}
00257 };
00258 
00259 
00260 
00261 
00262 template<typename C>
00263 struct LA_traits; 
00264 
00265 #define generate_traits(X) \
00266 template<typename C> \
00267 struct LA_traits_aux<X<C>, scalar_false> { \
00268 typedef C elementtype; \
00269 typedef X<C> producttype; \
00270 typedef typename LA_traits<C>::normtype normtype; \
00271 typedef X<typename LA_traits<C>::realtype> realtype; \
00272 typedef X<typename LA_traits<C>::complextype> complextype; \
00273 static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
00274 static inline bool bigger(const  C &x, const C &y) {return x>y;} \
00275 static inline bool smaller(const  C &x, const C &y) {return x<y;} \
00276 static inline normtype norm (const X<C> &x) {return x.norm();} \
00277 static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
00278 static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \
00279 static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \
00280 static void multiput(unsigned int n,int fd, const X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);} \
00281 static void multiget(unsigned int n,int fd, X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);} \
00282 static void copy(C *dest, C *src, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i]=src[i];} \
00283 static void clear(C *dest, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i].clear();}\
00284 static void copyonwrite(X<C> &x) {x.copyonwrite();}\
00285 static void clearme(X<C> &x) {x.clear();}\
00286 };
00287 
00288 
00289 
00290 generate_traits(NRMat)
00291 generate_traits(NRMat_from1)
00292 generate_traits(NRVec)
00293 generate_traits(SparseMat)
00294 generate_traits(SparseSMat) 
00295 
00296 #undef generate_traits
00297 
00298 
00299 #define generate_traits_smat(X) \
00300 template<typename C>  \
00301 struct LA_traits_aux<X<C>, scalar_false> {  \
00302 typedef C elementtype;  \
00303 typedef NRMat<C> producttype;  \
00304 typedef typename LA_traits<C>::normtype normtype;  \
00305 typedef X<typename LA_traits<C>::realtype> realtype; \
00306 typedef X<typename LA_traits<C>::complextype> complextype; \
00307 static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
00308 static inline bool bigger(const  C &x, const C &y) {return x>y;} \
00309 static inline bool smaller(const  C &x, const C &y) {return x<y;} \
00310 static inline normtype norm (const X<C> &x) {return x.norm();}  \
00311 static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);}  \
00312 static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions);}  \
00313 static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);}  \
00314 static void multiput(unsigned int n,int fd, const X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);}  \
00315 static void multiget(unsigned int n,int fd, X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);}  \
00316 static void copy(C *dest, C *src, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i]=src[i];}  \
00317 static void clear(C *dest, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i].clear();} \
00318 static void copyonwrite(X<C> &x) {x.copyonwrite();} \
00319 static void clearme(X<C> &x) {x.clear();} \
00320 };
00321 
00322 generate_traits_smat(NRSMat)
00323 generate_traits_smat(NRSMat_from1)
00324 
00325 
00326 
00327 template<typename C>
00328 struct LA_traits : LA_traits_aux<C, typename isscalar<C>::scalar_type> {};
00329 
00330 }
00331 
00332 #endif