Skip to content
TensorContractionCuda.h 60.6 KiB
Newer Older
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                       const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // Shared memory is laid out as
  // (contract idx in block, nocontract idx in block, block idx),
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, it should be the one held in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit: it was the loop counter for a loop that went
  // from 0 to < 8, but that loop has been unrolled in the code below.

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

#define prefetchIntoRegisters(base_k)                           \
  {                                                             \
    lhs_pf0 = conv(0);                                          \
    lhs_pf1 = conv(0);                                          \
    lhs_pf2 = conv(0);                                          \
    lhs_pf3 = conv(0);                                          \
    lhs_pf4 = conv(0);                                          \
    lhs_pf5 = conv(0);                                          \
    lhs_pf6 = conv(0);                                          \
    lhs_pf7 = conv(0);                                          \
                                                                \
    rhs_pf0 = conv(0);                                          \
    rhs_pf1 = conv(0);                                          \
    rhs_pf2 = conv(0);                                          \
    rhs_pf3 = conv(0);                                          \
    rhs_pf4 = conv(0);                                          \
    rhs_pf5 = conv(0);                                          \
    rhs_pf6 = conv(0);                                          \
    rhs_pf7 = conv(0);                                          \
                                                                \
    if (!needs_edge_check || lhs_vert < m_size) {               \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
                                                                \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
      } else if (lhs_horiz_6 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
      } else if (lhs_horiz_5 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
      } else if (lhs_horiz_4 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
      } else if (lhs_horiz_3 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
      } else if (lhs_horiz_2 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
      } else if (lhs_horiz_1 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
      } else if (lhs_horiz_0 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
      }                                                         \
    }                                                           \
                                                                \
    const Index rhs_vert = base_k + load_idx_vert;              \
    if (!needs_edge_check || rhs_vert < k_size) {               \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
                                                                \
      if (rhs_horiz_7 < n_size) {                               \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
      } else if (rhs_horiz_6 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
Loading full blame...