Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)
namespace Eigen {
template<typename Scalar, typename Index, typename LhsMapper,
typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
const Index m_size, const Index n_size, const Index k_size) {
const Index m_block_idx = blockIdx.x;
const Index n_block_idx = blockIdx.y;
const Index base_m = 64 * m_block_idx;
const Index base_n = 64 * n_block_idx;
// declare and initialize 64 registers for output 8x8 block
// prefetch registers
Scalar lhs_pf0;
Scalar lhs_pf1;
Scalar lhs_pf2;
Scalar lhs_pf3;
Scalar lhs_pf4;
Scalar lhs_pf5;
Scalar lhs_pf6;
Scalar lhs_pf7;
Scalar rhs_pf0;
Scalar rhs_pf1;
Scalar rhs_pf2;
Scalar rhs_pf3;
Scalar rhs_pf4;
Scalar rhs_pf5;
Scalar rhs_pf6;
Scalar rhs_pf7;
// shared memory is formatted
// (contract idx in block, nocontract idx in block, block idx)
// where block idx is column major. This transposition limits the number of
// bank conflicts when reading the LHS. The core idea is that since the contracting
// index is shared by both sides, then the contracting index should be in threadIdx.x.
// On the LHS, we pad each row inside of each block with an extra element. This makes
// each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
// on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
// On the RHS we just add 8 padding elements to the end of each block. This gives no bank
// conflicts on writes and also none on reads.
// storage indices
const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
// in the loading code, the following variables are important:
// threadIdx.x: the vertical position in an 8x8 block
// threadIdx.y: the vertical index of the 8x8 block in the grid
// threadIdx.z: the horizontal position in an 8x8 block
// k: the horizontal index of the 8x8 block in the grid
//
// The k parameter is implicit (it was the loop counter for a loop that went
// from 0 to <8), but now that loop is unrolled in the code below.
const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
const Index lhs_vert = base_m + load_idx_vert;
#define prefetchIntoRegisters(base_k) \
{ \
lhs_pf0 = conv(0); \
lhs_pf1 = conv(0); \
lhs_pf2 = conv(0); \
lhs_pf3 = conv(0); \
lhs_pf4 = conv(0); \
lhs_pf5 = conv(0); \
lhs_pf6 = conv(0); \
lhs_pf7 = conv(0); \
\
rhs_pf0 = conv(0); \
rhs_pf1 = conv(0); \
rhs_pf2 = conv(0); \
rhs_pf3 = conv(0); \
rhs_pf4 = conv(0); \
rhs_pf5 = conv(0); \
rhs_pf6 = conv(0); \
rhs_pf7 = conv(0); \
\
if (!needs_edge_check || lhs_vert < m_size) { \
const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
\
if (!needs_edge_check || lhs_horiz_7 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
} else if (lhs_horiz_6 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
} else if (lhs_horiz_5 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
} else if (lhs_horiz_4 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
} else if (lhs_horiz_3 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
} else if (lhs_horiz_2 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
} else if (lhs_horiz_1 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
} else if (lhs_horiz_0 < k_size) { \
lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
} \
} \
\
const Index rhs_vert = base_k + load_idx_vert; \
if (!needs_edge_check || rhs_vert < k_size) { \
const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
\
if (rhs_horiz_7 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
} else if (rhs_horiz_6 < n_size) { \
rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
Loading full blame...