Skip to content
...@@ -25,13 +25,13 @@ ...@@ -25,13 +25,13 @@
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************** ********************************************************************************
* Content : Eigen bindings to Intel(R) MKL * Content : Eigen bindings to BLAS F77
* Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV. * Selfadjoint matrix-vector product functionality based on ?SYMV/HEMV.
******************************************************************************** ********************************************************************************
*/ */
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #define EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
namespace Eigen { namespace Eigen {
...@@ -47,31 +47,31 @@ template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool Conju ...@@ -47,31 +47,31 @@ template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool Conju
struct selfadjoint_matrix_vector_product_symv : struct selfadjoint_matrix_vector_product_symv :
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {}; selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn> {};
#define EIGEN_MKL_SYMV_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \ struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \
static void run( \ static void run( \
Index size, const Scalar* lhs, Index lhsStride, \ Index size, const Scalar* lhs, Index lhsStride, \
const Scalar* _rhs, Index rhsIncr, Scalar* res, Scalar alpha) { \ const Scalar* _rhs, Scalar* res, Scalar alpha) { \
enum {\ enum {\
IsColMajor = StorageOrder==ColMajor \ IsColMajor = StorageOrder==ColMajor \
}; \ }; \
if (IsColMajor == ConjugateLhs) {\ if (IsColMajor == ConjugateLhs) {\
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn>::run( \ selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn>::run( \
size, lhs, lhsStride, _rhs, rhsIncr, res, alpha); \ size, lhs, lhsStride, _rhs, res, alpha); \
} else {\ } else {\
selfadjoint_matrix_vector_product_symv<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs>::run( \ selfadjoint_matrix_vector_product_symv<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs>::run( \
size, lhs, lhsStride, _rhs, rhsIncr, res, alpha); \ size, lhs, lhsStride, _rhs, res, alpha); \
}\ }\
} \ } \
}; \ }; \
EIGEN_MKL_SYMV_SPECIALIZE(double) EIGEN_BLAS_SYMV_SPECIALIZE(double)
EIGEN_MKL_SYMV_SPECIALIZE(float) EIGEN_BLAS_SYMV_SPECIALIZE(float)
EIGEN_MKL_SYMV_SPECIALIZE(dcomplex) EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex)
EIGEN_MKL_SYMV_SPECIALIZE(scomplex) EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
#define EIGEN_MKL_SYMV_SPECIALIZATION(EIGTYPE,MKLTYPE,MKLFUNC) \ #define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \ template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \ struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
{ \ { \
...@@ -79,36 +79,33 @@ typedef Matrix<EIGTYPE,Dynamic,1,ColMajor> SYMVVector;\ ...@@ -79,36 +79,33 @@ typedef Matrix<EIGTYPE,Dynamic,1,ColMajor> SYMVVector;\
\ \
static void run( \ static void run( \
Index size, const EIGTYPE* lhs, Index lhsStride, \ Index size, const EIGTYPE* lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* res, EIGTYPE alpha) \ const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
{ \ { \
enum {\ enum {\
IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \ IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
IsLower = UpLo == Lower ? 1 : 0 \ IsLower = UpLo == Lower ? 1 : 0 \
}; \ }; \
MKL_INT n=size, lda=lhsStride, incx=rhsIncr, incy=1; \ BlasIndex n=convert_index<BlasIndex>(size), lda=convert_index<BlasIndex>(lhsStride), incx=1, incy=1; \
MKLTYPE alpha_, beta_; \ EIGTYPE beta(1); \
const EIGTYPE *x_ptr, myone(1); \ const EIGTYPE *x_ptr; \
char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \ char uplo=(IsRowMajor) ? (IsLower ? 'U' : 'L') : (IsLower ? 'L' : 'U'); \
assign_scalar_eig2mkl(alpha_, alpha); \
assign_scalar_eig2mkl(beta_, myone); \
SYMVVector x_tmp; \ SYMVVector x_tmp; \
if (ConjugateRhs) { \ if (ConjugateRhs) { \
Map<const SYMVVector, 0, InnerStride<> > map_x(_rhs,size,1,InnerStride<>(incx)); \ Map<const SYMVVector, 0 > map_x(_rhs,size,1); \
x_tmp=map_x.conjugate(); \ x_tmp=map_x.conjugate(); \
x_ptr=x_tmp.data(); \ x_ptr=x_tmp.data(); \
incx=1; \
} else x_ptr=_rhs; \ } else x_ptr=_rhs; \
MKLFUNC(&uplo, &n, &alpha_, (const MKLTYPE*)lhs, &lda, (const MKLTYPE*)x_ptr, &incx, &beta_, (MKLTYPE*)res, &incy); \ BLASFUNC(&uplo, &n, &numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, &numext::real_ref(beta), (BLASTYPE*)res, &incy); \
}\ }\
}; };
EIGEN_MKL_SYMV_SPECIALIZATION(double, double, dsymv) EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(float, float, ssymv) EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_MKL_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv) EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_MKL_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv) EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_MKL_H #endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
...@@ -53,7 +53,6 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true> ...@@ -53,7 +53,6 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
{ {
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef internal::blas_traits<OtherType> OtherBlasTraits; typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
...@@ -86,7 +85,6 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> ...@@ -86,7 +85,6 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
{ {
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef internal::blas_traits<OtherType> OtherBlasTraits; typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
...@@ -94,15 +92,27 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> ...@@ -94,15 +92,27 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 }; enum {
IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
OtherIsRowMajor = _ActualOtherType::Flags&RowMajorBit ? 1 : 0
};
Index size = mat.cols();
Index depth = actualOther.cols();
typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,Scalar,Scalar,
MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualOtherType::MaxColsAtCompileTime> BlockingType;
BlockingType blocking(size, size, depth, 1, false);
internal::general_matrix_matrix_triangular_product<Index, internal::general_matrix_matrix_triangular_product<Index,
Scalar, _ActualOtherType::Flags&RowMajorBit ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
Scalar, _ActualOtherType::Flags&RowMajorBit ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> IsRowMajor ? RowMajor : ColMajor, UpLo>
::run(mat.cols(), actualOther.cols(), ::run(size, depth,
&actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
mat.data(), mat.outerStride(), actualAlpha); mat.data(), mat.outerStride(), actualAlpha, blocking);
} }
}; };
......
...@@ -79,11 +79,11 @@ SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo> ...@@ -79,11 +79,11 @@ SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
if (IsRowMajor) if (IsRowMajor)
actualAlpha = numext::conj(actualAlpha); actualAlpha = numext::conj(actualAlpha);
internal::selfadjoint_rank2_update_selector<Scalar, Index, typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ UBlasTraits::NeedToConjugate,_ActualUType>::type>::type UType;
typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ UBlasTraits::NeedToConjugate,_ActualUType>::type>::type, typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ VBlasTraits::NeedToConjugate,_ActualVType>::type>::type VType;
typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ VBlasTraits::NeedToConjugate,_ActualVType>::type>::type, internal::selfadjoint_rank2_update_selector<Scalar, Index, UType, VType,
(IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)> (IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)>
::run(_expression().const_cast_derived().data(),_expression().outerStride(),actualU,actualV,actualAlpha); ::run(_expression().const_cast_derived().data(),_expression().outerStride(),UType(actualU),VType(actualV),actualAlpha);
return *this; return *this;
} }
......