Flash Attention · Commit 656daef4
Authored 1 year ago by Tri Dao

    Use Cute's local_tile to get gQ, gK, gV

Parent: 9eb3d099

Showing 1 changed file, csrc/flash_attn/src/flash_fwd_kernel.h, with 51 additions and 54 deletions.
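CuTe's local_tile takes a tensor view, a static tile shape, and a tile coordinate, and returns the sub-tensor for that tile; passing `_` for a coordinate mode keeps that mode, so the result carries the remaining block index as an extra dimension. The sketch below mirrors the mQ/gQ pattern this commit introduces; the helper name and parameters are hypothetical and only illustrate the call, they are not part of the diff.

#include <cute/tensor.hpp>
using namespace cute;

// Illustrative sketch (hypothetical helper, not from this commit): build a full
// global-memory view of Q with shape (seqlen_q, n_heads, head_dim), then use
// local_tile to pull out the (kBlockM, kHeadDim) tile for one head and one row block.
template <int kBlockM, int kHeadDim, typename Element>
__device__ void q_tile_sketch(Element *q_ptr, int seqlen_q, int n_heads,
                              int row_stride, int head_stride,
                              int bidh, int m_block) {
    // Whole-Q view: element (row, head, d) lives at q_ptr + row*row_stride + head*head_stride + d.
    Tensor mQ = make_tensor(make_gmem_ptr(q_ptr),
                            make_shape(seqlen_q, n_heads, Int<kHeadDim>{}),
                            make_stride(row_stride, head_stride, _1{}));
    // Fix the head with bidh, then take the m_block-th (kBlockM, kHeadDim) tile of the rows.
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));   // (kBlockM, kHeadDim)
    (void)gQ;  // gQ(i, j) == mQ(m_block * kBlockM + i, bidh, j)
}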
@@ -68,14 +68,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
     // Otherwise we might read OOB elements from gK and gV.
     if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
-        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                make_stride(params.o_row_stride, _1{}));
-        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr)
+                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                                make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+        Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+        Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)),
+                                  make_shape(params.b, params.h, params.seqlen_q),
+                                  make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+        Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

         typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -108,25 +110,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

-    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
-        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
-    // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
-        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
-        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
         + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
-    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.k_row_stride, _1{}));
-    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.v_row_stride, _1{}));
+    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr)
+                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr)
+                                          + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
+    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr)
+                                          + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
     Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                             Shape<Int<kBlockM>, Int<kBlockN>>{},
                             make_stride(params.seqlen_k_rounded, _1{}));
@@ -145,9 +149,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

     typename Kernel_traits::TiledMma tiled_mma;
@@ -240,7 +244,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     int n_block = n_block_max - 1;

     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
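Because gK and gV are now tiled with make_coord(_, 0), they keep the K/V block index as a trailing mode (the nblocksN noted in the partition comments above), so the loop selects a block by coordinate, e.g. tKgK(_, _, _, n_block), instead of bumping the raw pointer by kBlockN * row_stride each iteration. Below is a minimal self-contained sketch of that indexing, with a hypothetical helper name and without the shared-memory copy machinery; it is illustrative only, not part of the diff.

#include <cute/tensor.hpp>
using namespace cute;

// Illustrative sketch (hypothetical helper, not from this commit): keep the block
// index as a free mode with `_` in the local_tile coordinate, then walk the blocks
// in reverse by indexing that mode instead of doing pointer arithmetic.
template <int kBlockN, int kHeadDim, typename Element>
__device__ void walk_k_blocks_sketch(Element *k_ptr, int seqlen_k, int row_stride) {
    Tensor mK = make_tensor(make_gmem_ptr(k_ptr),
                            make_shape(seqlen_k, Int<kHeadDim>{}),
                            make_stride(row_stride, _1{}));
    Tensor gK = local_tile(mK, Shape<Int<kBlockN>, Int<kHeadDim>>{},
                           make_coord(_, 0));             // (kBlockN, kHeadDim, nblocksN)
    for (int n_block = size<2>(gK) - 1; n_block >= 0; --n_block) {
        Tensor gK_cur = gK(_, _, n_block);                // one (kBlockN, kHeadDim) block
        // ... issue the gmem -> smem copy for gK_cur, as flash::copy does above ...
        (void)gK_cur;
    }
}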
@@ -281,12 +285,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Advance gV
         if (masking_step > 0) {
-            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
             flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+                gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
@@ -304,9 +307,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -354,9 +355,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
-        // Advance gV
-        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
@@ -367,9 +366,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -421,14 +418,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

-    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.o_row_stride, _1{}));
-    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                              Shape<Int<kBlockM>>{}, Stride<_1>{});
+    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr)
+                                          + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)),
+                              make_shape(params.b, params.h, params.seqlen_q),
+                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -555,8 +554,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

-    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
-        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
     const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
     const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
@@ -571,9 +568,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
         : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;

-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
+    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.k_row_stride, _1{}));
@@ -1033,8 +1032,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
-    // __syncthreads();
-    // if (cute::thread0()) { print(tOgOaccum); }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////