Skip to content
ProductEvaluators.h 47.8 KiB
Newer Older
Luker's avatar
Luker committed
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.


#ifndef EIGEN_PRODUCTEVALUATORS_H
#define EIGEN_PRODUCTEVALUATORS_H

namespace Eigen {
  
namespace internal {

/** \internal
  * Evaluator of a product expression.
  * Since products require special treatments to handle all possible cases,
  * we simply defer the evaluation logic to a product_evaluator class
  * which offers more partial specialization possibilities.
  * 
  * \sa class product_evaluator
  */
template<typename Lhs, typename Rhs, int Options>
struct evaluator<Product<Lhs, Rhs, Options> >
  : public product_evaluator<Product<Lhs, Rhs, Options> >
{
  // The product expression handled by this evaluator.
  typedef Product<Lhs, Rhs, Options> XprType;
  // Everything is inherited from the matching product_evaluator.
  typedef product_evaluator<XprType> Base;

  // Passes the expression through to the product_evaluator base unchanged.
  EIGEN_DEVICE_FUNC
  explicit evaluator(const XprType& product)
    : Base(product)
  {}
};
 
// Catch "scalar * ( A * B )" and transform it to "(A*scalar) * B"
// TODO we should apply that rule only if that's really helpful
// Trait specialization matching "scalar * (A * B)", where the inner product
// is a DefaultProduct: `value` is true for this expression type.
template<typename Lhs, typename Rhs,
         typename Scalar1, typename Scalar2,
         typename Plain1>
struct evaluator_assume_aliasing<
    CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
                  const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
                  const Product<Lhs, Rhs, DefaultProduct> > >
{
  static const bool value = true;
};
// Evaluator for "scalar * (A * B)" with a DefaultProduct inner product.
// The scalar is multiplied into the left factor first, and the resulting
// "(scalar * A) * B" expression is given to the evaluator of that product.
template<typename Lhs, typename Rhs,
         typename Scalar1, typename Scalar2,
         typename Plain1>
struct evaluator<
    CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
                  const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
                  const Product<Lhs, Rhs, DefaultProduct> > >
  : public evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1,Lhs,product), Rhs, DefaultProduct> >
{
  // The original "scalar * (A * B)" expression type.
  typedef CwiseBinaryOp<internal::scalar_product_op<Scalar1,Scalar2>,
                        const CwiseNullaryOp<internal::scalar_constant_op<Scalar1>, Plain1>,
                        const Product<Lhs, Rhs, DefaultProduct> > XprType;
  // Evaluator of the rewritten "(scalar * A) * B" product.
  typedef evaluator<Product<EIGEN_SCALAR_BINARYOP_EXPR_RETURN_TYPE(Scalar1,Lhs,product), Rhs, DefaultProduct> > Base;

  // xpr.lhs() is the constant (nullary) operand whose functor stores the
  // scalar in m_other; xpr.rhs() is the inner product A * B.
  EIGEN_DEVICE_FUNC
  explicit evaluator(const XprType& xpr)
    : Base( (xpr.lhs().functor().m_other * xpr.rhs().lhs()) * xpr.rhs().rhs() )
  {}
};


// Evaluator for the diagonal of a DefaultProduct.
// The expression is rebuilt as the diagonal of a LazyProduct over the same
// two factors, and the evaluator of that expression does the work
// (presumably so that only the requested coefficients get computed).
template<typename Lhs, typename Rhs, int DiagIndex>
struct evaluator<Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> >
  : public evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> >
{
  typedef Diagonal<const Product<Lhs, Rhs, DefaultProduct>, DiagIndex> XprType;
  typedef evaluator<Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex> > Base;

  // Same factors and same diagonal index as xpr; only the product kind changes.
  EIGEN_DEVICE_FUNC
  explicit evaluator(const XprType& xpr)
    : Base(
        Diagonal<const Product<Lhs, Rhs, LazyProduct>, DiagIndex>(
          Product<Lhs, Rhs, LazyProduct>(xpr.nestedExpression().lhs(),
                                         xpr.nestedExpression().rhs()),
          xpr.index()))
  {}
};


// Helper class to perform a matrix product with the destination at hand.
// Depending on the sizes of the factors, there are different evaluation strategies
// as controlled by internal::product_type.
template< typename Lhs, typename Rhs,
          // Shape tags; by default taken from evaluator_traits of each factor.
          typename LhsShape = typename evaluator_traits<Lhs>::Shape,
          typename RhsShape = typename evaluator_traits<Rhs>::Shape,
          // Evaluation strategy tag; by default the value computed by internal::product_type.
          int ProductType = internal::product_type<Lhs,Rhs>::value>
// Forward declaration only: the definitions are not provided at this point.
struct generic_product_impl;

// Trait specialization: `value` is true for any Product tagged DefaultProduct.
// Products carrying another option tag are not matched by this specialization.
template<typename Lhs, typename Rhs>
struct evaluator_assume_aliasing<Product<Lhs, Rhs, DefaultProduct> >
{
  static const bool value = true;
};

// This is the default evaluator implementation for products:
// It creates a temporary and calls generic_product_impl
// Product evaluator that owns a PlainObject temporary (m_result) and exposes
// it through the evaluator<PlainObject> base class it derives from.
template<typename Lhs, typename Rhs, int Options, int ProductTag, typename LhsShape, typename RhsShape>
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, LhsShape, RhsShape>
  : public evaluator<typename Product<Lhs, Rhs, Options>::PlainObject>
{
  typedef Product<Lhs, Rhs, Options> XprType;
  typedef typename XprType::PlainObject PlainObject;
  typedef evaluator<PlainObject> Base;
  enum {
    // Flags of the plain-object evaluator, with EvalBeforeNestingBit added.
    Flags = Base::Flags | EvalBeforeNestingBit
  };

  // Sizes m_result from the product expression, points the base evaluator at
  // it, then has generic_product_impl::evalTo write the product into it.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  explicit product_evaluator(const XprType& xpr)
    : m_result(xpr.rows(), xpr.cols())
  {
    // The mem-initializer list names no Base initializer, and the base
    // subobject is initialized before the member m_result. The base evaluator
    // is therefore constructed a second time, in place, now from m_result.
    ::new (static_cast<Base*>(this)) Base(m_result);
    
// FIXME shall we handle nested_eval here?,
// if so, then we must take care at removing the call to nested_eval in the specializations (e.g., in permutation_matrix_product, transposition_matrix_product, etc.)
//     typedef typename internal::nested_eval<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
//     typedef typename internal::nested_eval<Rhs,Lhs::RowsAtCompileTime>::type RhsNested;
//     typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
//     typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
//     
//     const LhsNested lhs(xpr.lhs());
//     const RhsNested rhs(xpr.rhs());
//   
//     generic_product_impl<LhsNestedCleaned, RhsNestedCleaned>::evalTo(m_result, lhs, rhs);

    // The shapes and product tag of this specialization are passed explicitly
    // rather than relying on generic_product_impl's default template arguments.
    generic_product_impl<Lhs, Rhs, LhsShape, RhsShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
  }
  
protected:  
  // Temporary holding the evaluated product; the base evaluator is built from it.
  PlainObject m_result;
};

// The following three shortcuts are enabled only if the scalar types match exactly.
// TODO: we could enable them for different scalar types when the product is not vectorized.

// Dense = Product
// "dense = product" shortcut, selected for DefaultProduct and AliasFreeProduct
// when the assignment functor is assign_op<Scalar,Scalar>.
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<Scalar,Scalar>, Dense2Dense,
  typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
  typedef Product<Lhs,Rhs,Options> SrcXprType;

  // Resizes dst to the dimensions of the product if they differ, then lets
  // generic_product_impl write the product straight into dst.
  static EIGEN_STRONG_INLINE
  void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
  {
    typedef generic_product_impl<Lhs, Rhs> ProductImpl;

    const Index productRows = src.rows();
    const Index productCols = src.cols();
    const bool sizeMatches = (dst.rows()==productRows) && (dst.cols()==productCols);
    if(!sizeMatches)
      dst.resize(productRows, productCols);

    // FIXME shall we handle nested_eval here?
    ProductImpl::evalTo(dst, src.lhs(), src.rhs());
  }
};

// Dense += Product
// "dense += product" shortcut, selected for DefaultProduct and AliasFreeProduct
// when the assignment functor is add_assign_op<Scalar,Scalar>.
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<Scalar,Scalar>, Dense2Dense,
  typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
  typedef Product<Lhs,Rhs,Options> SrcXprType;

  // Accumulates the product into dst through generic_product_impl::addTo.
  // dst is never resized here; its dimensions are only asserted.
  static EIGEN_STRONG_INLINE
  void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar,Scalar> &)
  {
    typedef generic_product_impl<Lhs, Rhs> ProductImpl;

    eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());

    // FIXME shall we handle nested_eval here?
    ProductImpl::addTo(dst, src.lhs(), src.rhs());
  }
};

// Dense -= Product
// "dense -= product" shortcut, selected for DefaultProduct and AliasFreeProduct
// when the assignment functor is sub_assign_op<Scalar,Scalar>.
template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::sub_assign_op<Scalar,Scalar>, Dense2Dense,
  typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
{
  typedef Product<Lhs,Rhs,Options> SrcXprType;

  // Subtracts the product from dst through generic_product_impl::subTo.
  // dst is never resized here; its dimensions are only asserted.
  static EIGEN_STRONG_INLINE
  void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar,Scalar> &)
  {
    typedef generic_product_impl<Lhs, Rhs> ProductImpl;

    eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());

    // FIXME shall we handle nested_eval here?
    ProductImpl::subTo(dst, src.lhs(), src.rhs());
  }
};


// Dense ?= scalar * Product
// TODO we should apply that rule only if that's really helpful
// for instance, this is not good for inner products
// Assignment of "scalar * (A * B)" (inner product tagged DefaultProduct) to a
// dense destination, for any assignment functor. The scalar is moved onto the
// left factor and "(scalar * A) * B" is assigned through
// call_assignment_no_alias instead.
template< typename DstXprType, typename Lhs, typename Rhs, typename AssignFunc,
          typename Scalar, typename ScalarBis, typename Plain>
struct Assignment<DstXprType,
                  CwiseBinaryOp<internal::scalar_product_op<ScalarBis,Scalar>,
                                const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>,Plain>,
                                const Product<Lhs,Rhs,DefaultProduct> >,
                  AssignFunc, Dense2Dense>
{
  typedef CwiseBinaryOp<internal::scalar_product_op<ScalarBis,Scalar>,
                        const CwiseNullaryOp<internal::scalar_constant_op<ScalarBis>,Plain>,
                        const Product<Lhs,Rhs,DefaultProduct> > SrcXprType;

  // src.lhs() is the constant operand whose functor stores the scalar in
  // m_other; src.rhs() is the inner product A * B.
  static EIGEN_STRONG_INLINE
  void run(DstXprType &dst, const SrcXprType &src, const AssignFunc& func)
  {
    call_assignment_no_alias(
      dst,
      (src.lhs().functor().m_other * src.rhs().lhs()) * src.rhs().rhs(),
      func);
  }
};

//----------------------------------------
Loading full blame...