00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00024
00025
00026 #ifndef _LA_TRAITS_INCL
00027 #define _LA_TRAITS_INCL
00028
00029
00030
00031
00032 #include <stdio.h>
00033 #include <string.h>
00034 #include <iostream>
00035 #include <limits>
00036 #include <complex>
00037
00038
00039
00040 #define complex std::complex
00041
00042 #include "laerror.h"
00043
00044 #include "cuda_la.h"
00045
00046 #ifdef NONCBLAS
00047 #include "noncblas.h"
00048 #else
00049 extern "C" {
00050 #include "cblas.h"
00051 }
00052 #endif
00053
00054 #ifdef NONCLAPACK
00055 #include "noncblas.h"
00056 #else
00057 extern "C" {
00058 #include "clapack.h"
00059 }
00060 #endif
00061
00062 namespace LA {
00063
// Flag that is only declared in this header (extern); no definition is
// visible here.
extern bool _LA_count_check;

// Forward declarations of the container class templates referred to by the
// traits in this header. Only the names are introduced at this point.
template<class C> class NRVec;
template<class C> class NRMat;
template<class C> class NRMat_from1;
template<class C> class NRSMat;
template<class C> class NRSMat_from1;
template<class C> class SparseMat;
template<class C> class SparseSMat;
00074
00075
00076 typedef class {} Dummy_type;
00077 typedef class {} Dummy_type2;
00078
00079
00080
00081 template<typename C>
00082 struct LA_traits_complex
00083 {
00084 typedef Dummy_type Component_type;
00085 typedef Dummy_type NRVec_Noncomplex_type;
00086 typedef Dummy_type NRMat_Noncomplex_type;
00087 typedef Dummy_type2 NRSMat_Noncomplex_type;
00088 };
00089
00090 #define SPECIALIZE_COMPLEX(T) \
00091 template<> \
00092 struct LA_traits_complex<complex<T> > \
00093 { \
00094 typedef T Component_type; \
00095 typedef NRVec<T> NRVec_Noncomplex_type; \
00096 typedef NRMat<T> NRMat_Noncomplex_type; \
00097 typedef NRSMat<T> NRSMat_Noncomplex_type; \
00098 };
00099
00100
00101 SPECIALIZE_COMPLEX(double)
00102 SPECIALIZE_COMPLEX(complex<double>)
00103 SPECIALIZE_COMPLEX(float)
00104 SPECIALIZE_COMPLEX(complex<float>)
00105 SPECIALIZE_COMPLEX(char)
00106 SPECIALIZE_COMPLEX(unsigned char)
00107 SPECIALIZE_COMPLEX(short)
00108 SPECIALIZE_COMPLEX(unsigned short)
00109 SPECIALIZE_COMPLEX(int)
00110 SPECIALIZE_COMPLEX(unsigned int)
00111 SPECIALIZE_COMPLEX(long)
00112 SPECIALIZE_COMPLEX(unsigned long)
00113 SPECIALIZE_COMPLEX(long long)
00114 SPECIALIZE_COMPLEX(unsigned long long)
00115
00116
00117
// Comparison policy selected by the integer parameter `type`.
// Only the values 0 and 1 are specialised; any other value leaves the
// template undefined.
template<typename T, typename I, int type>
struct LA_sort_traits;

// type == 0: compare(object,i,j) is object.bigger(i,j).
template<typename T, typename I>
struct LA_sort_traits<T, I, 0>
{
	static inline bool compare(T object, I i, I j)
	{
		const bool result = object.bigger(i, j);
		return result;
	}
};

// type == 1: compare(object,i,j) is object.smaller(i,j).
template<typename T, typename I>
struct LA_sort_traits<T, I, 1>
{
	static inline bool compare(T object, I i, I j)
	{
		const bool result = object.smaller(i, j);
		return result;
	}
};
00131
00132
00133
// Names the type IOtype associated with an element type C.
// By default IOtype is C itself.
template<typename C>
struct LA_traits_io
{
	typedef C IOtype;
};

// char is mapped to int.
template<>
struct LA_traits_io<char>
{
	typedef int IOtype;
};

// unsigned char is mapped to unsigned int.
template<>
struct LA_traits_io<unsigned char>
{
	typedef unsigned int IOtype;
};
00151
00152
00153
00154
00155
// Tag types used to select between scalar and non-scalar traits.
class scalar_false {};
class scalar_true {};

// isscalar<C>::scalar_type is scalar_false unless C is one of the types
// registered through SCALAR below.
template<typename C>
class isscalar
{
public:
	typedef scalar_false scalar_type;
};

// Registers TYPE, complex<TYPE> and complex<complex<TYPE> > as scalars.
#define SCALAR(TYPE) \
template<> \
class isscalar< TYPE > \
{ \
public: \
	typedef scalar_true scalar_type; \
}; \
template<> \
class isscalar< complex< TYPE > > \
{ \
public: \
	typedef scalar_true scalar_type; \
}; \
template<> \
class isscalar< complex< complex< TYPE > > > \
{ \
public: \
	typedef scalar_true scalar_type; \
};

SCALAR(char)
SCALAR(short)
SCALAR(int)
SCALAR(long)
SCALAR(long long)
SCALAR(unsigned char)
SCALAR(unsigned short)
SCALAR(unsigned int)
SCALAR(unsigned long)
SCALAR(unsigned long long)
SCALAR(float)
SCALAR(double)
SCALAR(void *)

#undef SCALAR
00189
00190
00191
00192 template<typename C, typename Scalar> struct LA_traits_aux
00193 {
00194 typedef Dummy_type normtype;
00195 };
00196
00197
00198
00203
00204
00205
00206 template<typename C>
00207 struct LA_traits_aux<complex<C>, scalar_true> {
00208 typedef complex<C> elementtype;
00209 typedef complex<C> producttype;
00210 typedef C normtype;
00211 typedef C realtype;
00212 typedef complex<C> complextype;
00213 static inline C sqrabs(const complex<C> x) { return x.real()*x.real()+x.imag()*x.imag();}
00214 static inline bool gencmp(const complex<C> *x, const complex<C> *y, int n) {return memcmp(x,y,n*sizeof(complex<C>));}
00215 static bool bigger(const complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
00216 static bool smaller(const complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
00217 static inline normtype norm (const complex<C> &x) {return std::abs(x);}
00218 static inline void axpy (complex<C> &s, const complex<C> &x, const complex<C> &c) {s+=x*c;}
00219 static inline void get(int fd, complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(complex<C>)!=read(fd,&x,sizeof(complex<C>))) laerror("read error");}
00220 static inline void put(int fd, const complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(complex<C>)!=write(fd,&x,sizeof(complex<C>))) laerror("write error");}
00221 static void multiget(unsigned int n,int fd, complex<C> *x, bool dimensions=0){if((ssize_t)(n*sizeof(complex<C>))!=read(fd,x,n*sizeof(complex<C>))) laerror("read error");}
00222 static void multiput(unsigned int n, int fd, const complex<C> *x, bool dimensions=0) {if((ssize_t)(n*sizeof(complex<C>))!=write(fd,x,n*sizeof(complex<C>))) laerror("write error");}
00223 static void copy(complex<C> *dest, complex<C> *src, unsigned int n) {memcpy(dest,src,n*sizeof(complex<C>));}
00224 static void clear(complex<C> *dest, unsigned int n) {memset(dest,0,n*sizeof(complex<C>));}
00225 static void copyonwrite(complex<C> &x) {};
00226 static void clearme(complex<C> &x) {x=0;};
00227 static inline complex<C> conjugate(const complex<C> &x) {return complex<C>(x.real(),-x.imag());};
00228 static inline C realpart(const complex<C> &x) {return x.real();}
00229 static inline C imagpart(const complex<C> &x) {return x.imag();}
00230 };
00231
00232
00233 template<typename C>
00234 struct LA_traits_aux<C, scalar_true> {
00235 typedef C elementtype;
00236 typedef C producttype;
00237 typedef C normtype;
00238 typedef C realtype;
00239 typedef complex<C> complextype;
00240 static inline C sqrabs(const C x) { return x*x;}
00241 static inline bool gencmp(const C *x, const C *y, int n) {return memcmp(x,y,n*sizeof(C));}
00242 static inline bool bigger(const C &x, const C &y) {return x>y;}
00243 static inline bool smaller(const C &x, const C &y) {return x<y;}
00244 static inline normtype norm (const C &x) {return std::abs(x);}
00245 static inline void axpy (C &s, const C &x, const C &c) {s+=x*c;}
00246 static inline void put(int fd, const C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=write(fd,&x,sizeof(C))) laerror("write error");}
00247 static inline void get(int fd, C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=read(fd,&x,sizeof(C))) laerror("read error");}
00248 static void multiput(unsigned int n,int fd, const C *x, bool dimensions=0){if((ssize_t)(n*sizeof(C))!=write(fd,x,n*sizeof(C))) laerror("write error");}
00249 static void multiget(unsigned int n, int fd, C *x, bool dimensions=0) {if((ssize_t)(n*sizeof(C))!=read(fd,x,n*sizeof(C))) laerror("read error");}
00250 static void copy(C *dest, C *src, unsigned int n) {memcpy(dest,src,n*sizeof(C));}
00251 static void clear(C *dest, unsigned int n) {memset(dest,0,n*sizeof(C));}
00252 static void copyonwrite(C &x) {};
00253 static void clearme(complex<C> &x) {x=0;};
00254 static inline C conjugate(const C &x) {return x;};
00255 static inline C realpart(const C &x) {return x;}
00256 static inline C imagpart(const C &x) {return 0;}
00257 };
00258
00259
00260
00261
00262 template<typename C>
00263 struct LA_traits;
00264
00265 #define generate_traits(X) \
00266 template<typename C> \
00267 struct LA_traits_aux<X<C>, scalar_false> { \
00268 typedef C elementtype; \
00269 typedef X<C> producttype; \
00270 typedef typename LA_traits<C>::normtype normtype; \
00271 typedef X<typename LA_traits<C>::realtype> realtype; \
00272 typedef X<typename LA_traits<C>::complextype> complextype; \
00273 static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
00274 static inline bool bigger(const C &x, const C &y) {return x>y;} \
00275 static inline bool smaller(const C &x, const C &y) {return x<y;} \
00276 static inline normtype norm (const X<C> &x) {return x.norm();} \
00277 static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
00278 static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \
00279 static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \
00280 static void multiput(unsigned int n,int fd, const X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);} \
00281 static void multiget(unsigned int n,int fd, X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);} \
00282 static void copy(C *dest, C *src, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i]=src[i];} \
00283 static void clear(C *dest, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i].clear();}\
00284 static void copyonwrite(X<C> &x) {x.copyonwrite();}\
00285 static void clearme(X<C> &x) {x.clear();}\
00286 };
00287
00288
00289
00290 generate_traits(NRMat)
00291 generate_traits(NRMat_from1)
00292 generate_traits(NRVec)
00293 generate_traits(SparseMat)
00294 generate_traits(SparseSMat)
00295
00296 #undef generate_traits
00297
00298
00299 #define generate_traits_smat(X) \
00300 template<typename C> \
00301 struct LA_traits_aux<X<C>, scalar_false> { \
00302 typedef C elementtype; \
00303 typedef NRMat<C> producttype; \
00304 typedef typename LA_traits<C>::normtype normtype; \
00305 typedef X<typename LA_traits<C>::realtype> realtype; \
00306 typedef X<typename LA_traits<C>::complextype> complextype; \
00307 static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
00308 static inline bool bigger(const C &x, const C &y) {return x>y;} \
00309 static inline bool smaller(const C &x, const C &y) {return x<y;} \
00310 static inline normtype norm (const X<C> &x) {return x.norm();} \
00311 static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
00312 static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions);} \
00313 static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);} \
00314 static void multiput(unsigned int n,int fd, const X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);} \
00315 static void multiget(unsigned int n,int fd, X<C> *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);} \
00316 static void copy(C *dest, C *src, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i]=src[i];} \
00317 static void clear(C *dest, unsigned int n) {for(unsigned int i=0; i<n; ++i) dest[i].clear();} \
00318 static void copyonwrite(X<C> &x) {x.copyonwrite();} \
00319 static void clearme(X<C> &x) {x.clear();} \
00320 };
00321
00322 generate_traits_smat(NRSMat)
00323 generate_traits_smat(NRSMat_from1)
00324
00325
00326
00327 template<typename C>
00328 struct LA_traits : LA_traits_aux<C, typename isscalar<C>::scalar_type> {};
00329
00330 }
00331
00332 #endif