getfem-commits
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Getfem-commits] [getfem-commits] branch master updated: Fix BLAS interface not supported by MKL and avoid nested macros


From: Konstantinos Poulios
Subject: [Getfem-commits] [getfem-commits] branch master updated: Fix BLAS interface not supported by MKL and avoid nested macros
Date: Mon, 25 Mar 2024 11:55:22 -0400

This is an automated email from the git hooks/post-receive script.

logari81 pushed a commit to branch master
in repository getfem.

The following commit(s) were added to refs/heads/master by this push:
     new d8be9712 Fix BLAS interface not supported by MKL and avoid nested macros
d8be9712 is described below

commit d8be97120c1ba18729ac9c00c2a24a6b65775e3a
Author: Konstantinos Poulios <logari81@gmail.com>
AuthorDate: Mon Mar 25 16:55:10 2024 +0100

    Fix BLAS interface not supported by MKL and avoid nested macros
---
 src/gmm/gmm_blas_interface.h | 141 ++++++++++++++++++++-----------------------
 1 file changed, 64 insertions(+), 77 deletions(-)

diff --git a/src/gmm/gmm_blas_interface.h b/src/gmm/gmm_blas_interface.h
index 07861ec5..44f99ae3 100644
--- a/src/gmm/gmm_blas_interface.h
+++ b/src/gmm/gmm_blas_interface.h
@@ -196,115 +196,102 @@ namespace gmm {
   nrm2_interface(dznrm2_, BLAS_Z)
 
   /* ********************************************************************* */
-  /* vect_sp(x, y).                                                        */
+  /* vect_sp(x,y) = vect_hp(x,y) for real vectors                          */
   /* ********************************************************************* */
 
-# define dot_interface(blas_name, base_type)                               \
-  inline base_type vect_sp(const std::vector<base_type> &x,                \
-                           const std::vector<base_type> &y) {              \
-    GMMLAPACK_TRACE("dot_interface");                                      \
+# define dot_interface(funcname, msg, blas_name, base_type)                \
+  inline base_type funcname(const std::vector<base_type> &x,               \
+                            const std::vector<base_type> &y) {             \
+    GMMLAPACK_TRACE(msg);                                                  \
     BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
     return blas_name(&n, &x[0], &inc, &y[0], &inc);                        \
   }                                                                        \
-  inline base_type vect_sp                                                 \
+  inline base_type funcname                                                \
    (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_,   \
     const std::vector<base_type> &y) {                                     \
-    GMMLAPACK_TRACE("dot_interface");                                      \
+    GMMLAPACK_TRACE(msg);                                                  \
     const std::vector<base_type> &x = *(linalg_origin(x_));                \
     base_type a(x_.r);                                                     \
     BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
-    return a* blas_name(&n, &x[0], &inc, &y[0], &inc);                     \
+    return a * blas_name(&n, &x[0], &inc, &y[0], &inc);                    \
   }                                                                        \
-  inline base_type vect_sp                                                 \
+  inline base_type funcname                                                \
     (const std::vector<base_type> &x,                                      \
      const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
-    GMMLAPACK_TRACE("dot_interface");                                      \
+    GMMLAPACK_TRACE(msg);                                                  \
     const std::vector<base_type> &y = *(linalg_origin(y_));                \
     base_type b(y_.r);                                                     \
     BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
-    return b* blas_name(&n, &x[0], &inc, &y[0], &inc);                     \
+    return b * blas_name(&n, &x[0], &inc, &y[0], &inc);                    \
   }                                                                        \
-  inline base_type vect_sp                                                 \
+  inline base_type funcname                                                \
     (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_,  \
      const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
-    GMMLAPACK_TRACE("dot_interface");                                      \
+    GMMLAPACK_TRACE(msg);                                                  \
     const std::vector<base_type> &x = *(linalg_origin(x_));                \
-    base_type a(x_.r);                                                     \
     const std::vector<base_type> &y = *(linalg_origin(y_));                \
-    base_type b(y_.r);                                                     \
+    base_type a(x_.r), b(y_.r);                                            \
     BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
-    return a* b* blas_name(&n, &x[0], &inc, &y[0], &inc);                  \
+    return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc);                  \
   }
 
-  dot_interface(sdot_,  BLAS_S)
-  dot_interface(ddot_,  BLAS_D)
-  dot_interface(cdotu_, BLAS_C)
-  dot_interface(zdotu_, BLAS_Z)
+  dot_interface(vect_sp, "dot_interface", sdot_,  BLAS_S)
+  dot_interface(vect_sp, "dot_interface", ddot_,  BLAS_D)
+  dot_interface(vect_hp, "dotc_interface", sdot_,  BLAS_S)
+  dot_interface(vect_hp, "dotc_interface", ddot_,  BLAS_D)
 
   /* ********************************************************************* */
-  /* vect_hp(x, y).                                                        */
+  /* vect_sp(x,y) and vect_hp(x,y) for complex vectors                     */
+  /* vect_hp(x, y) = x.conj(y) (different order than in BLAS)              */
+  /* switching x,y before passed to BLAS is important only for vect_hp     */
   /* ********************************************************************* */
 
-# define dotc_interface(param1, trans1, mult1, param2, trans2, mult2,      \
-                        blas_name, base_type)                              \
-  inline base_type vect_hp(param1(base_type), param2(base_type)) {         \
-    GMMLAPACK_TRACE("dotc_interface");                                     \
-    trans1(base_type); trans2(base_type);                                  \
+# define dot_interface_cplx(funcname, msg, blas_name, base_type, bdef)     \
+  inline base_type funcname(const std::vector<base_type> &x,               \
+                            const std::vector<base_type> &y) {             \
+    GMMLAPACK_TRACE(msg);                                                  \
+    base_type res;                                                         \
     BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
-    return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc);            \
+    blas_name(&res, &n, &y[0], &inc, &x[0], &inc);                         \
+    return res;                                                            \
+  }                                                                        \
+  inline base_type funcname                                                \
+   (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_,   \
+    const std::vector<base_type> &y) {                                     \
+    GMMLAPACK_TRACE(msg);                                                  \
+    const std::vector<base_type> &x = *(linalg_origin(x_));                \
+    base_type res, a(x_.r);                                                \
+    BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
+    blas_name(&res, &n, &y[0], &inc, &x[0], &inc);                         \
+    return a*res;                                                          \
+  }                                                                        \
+  inline base_type funcname                                                \
+    (const std::vector<base_type> &x,                                      \
+     const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
+    GMMLAPACK_TRACE(msg);                                                  \
+    const std::vector<base_type> &y = *(linalg_origin(y_));                \
+    base_type res, b(bdef);                                                \
+    BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
+    blas_name(&res, &n, &y[0], &inc, &x[0], &inc);                         \
+    return b*res;                                                          \
+  }                                                                        \
+  inline base_type funcname                                                \
+    (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_,  \
+     const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
+    GMMLAPACK_TRACE(msg);                                                  \
+    const std::vector<base_type> &x = *(linalg_origin(x_));                \
+    const std::vector<base_type> &y = *(linalg_origin(y_));                \
+    base_type res, a(x_.r), b(bdef);                                       \
+    BLAS_INT inc(1), n(BLAS_INT(vect_size(y)));                            \
+    blas_name(&res, &n, &y[0], &inc, &x[0], &inc);                         \
+    return a*b*res;                                                        \
   }
 
-# define dotc_p1(base_type) const std::vector<base_type> &x
-# define dotc_trans1(base_type)
-# define dotc_p1_s(base_type)                                              \
-    const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
-# define dotc_trans1_s(base_type)                                          \
-         const std::vector<base_type> &x = *(linalg_origin(x_));           \
-         base_type a(x_.r)
+  dot_interface_cplx(vect_sp, "dot_interface", cdotu_, BLAS_C, y_.r)
+  dot_interface_cplx(vect_sp, "dot_interface", zdotu_, BLAS_Z, y_.r)
+  dot_interface_cplx(vect_hp, "dotc_interface", cdotc_, BLAS_C, gmm::conj(y_.r))
+  dot_interface_cplx(vect_hp, "dotc_interface", zdotc_, BLAS_Z, gmm::conj(y_.r))
 
-# define dotc_p2(base_type) const std::vector<base_type> &y
-# define dotc_trans2(base_type)
-# define dotc_p2_s(base_type)                                              \
-    const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_
-# define dotc_trans2_s(base_type)                                          \
-         const std::vector<base_type> &y = *(linalg_origin(y_));           \
-         base_type b(gmm::conj(y_.r))
-
-  dotc_interface(dotc_p1, dotc_trans1, (BLAS_S),
-                 dotc_p2, dotc_trans2, (BLAS_S), sdot_,  BLAS_S)
-  dotc_interface(dotc_p1, dotc_trans1, (BLAS_D),
-                 dotc_p2, dotc_trans2, (BLAS_D), ddot_,  BLAS_D)
-  dotc_interface(dotc_p2, dotc_trans2, (BLAS_C),
-                 dotc_p1, dotc_trans1, (BLAS_C), cdotc_, BLAS_C)
-  dotc_interface(dotc_p2, dotc_trans2, (BLAS_Z),
-                 dotc_p1, dotc_trans1, (BLAS_Z), zdotc_, BLAS_Z)
-
-  dotc_interface(dotc_p1_s, dotc_trans1_s, a*,
-                 dotc_p2,   dotc_trans2,   (BLAS_S), sdot_,  BLAS_S)
-  dotc_interface(dotc_p1_s, dotc_trans1_s, a*,
-                 dotc_p2,   dotc_trans2,   (BLAS_D), ddot_,  BLAS_D)
-  dotc_interface(dotc_p2,   dotc_trans2,   (BLAS_C),
-                 dotc_p1_s, dotc_trans1_s, a*,       cdotc_, BLAS_C)
-  dotc_interface(dotc_p2,   dotc_trans2,   (BLAS_Z),
-                 dotc_p1_s, dotc_trans1_s, a*,       zdotc_, BLAS_Z)
-
-  dotc_interface(dotc_p1,   dotc_trans1,   (BLAS_S),
-                 dotc_p2_s, dotc_trans2_s, b*,       sdot_,  BLAS_S)
-  dotc_interface(dotc_p1,   dotc_trans1,   (BLAS_D),
-                 dotc_p2_s, dotc_trans2_s, b*,       ddot_,  BLAS_D)
-  dotc_interface(dotc_p2_s, dotc_trans2_s, b*,
-                 dotc_p1,   dotc_trans1,   (BLAS_C), cdotc_, BLAS_C)
-  dotc_interface(dotc_p2_s, dotc_trans2_s, b*,
-                 dotc_p1,   dotc_trans1,   (BLAS_Z), zdotc_, BLAS_Z)
-
-  dotc_interface(dotc_p1_s, dotc_trans1_s, a*,
-                 dotc_p2_s, dotc_trans2_s, b*, sdot_,  BLAS_S)
-  dotc_interface(dotc_p1_s, dotc_trans1_s, a*,
-                 dotc_p2_s, dotc_trans2_s, b*, ddot_,  BLAS_D)
-  dotc_interface(dotc_p2_s, dotc_trans2_s, b*,
-                 dotc_p1_s, dotc_trans1_s, a*, cdotc_, BLAS_C)
-  dotc_interface(dotc_p2_s, dotc_trans2_s, b*,
-                 dotc_p1_s, dotc_trans1_s, a*, zdotc_, BLAS_Z)
 
   /* ********************************************************************* */
   /* add(x, y).                                                            */



reply via email to

[Prev in Thread] Current Thread [Next in Thread]