Flash Attention · commit 9e5e8bc9
authored 1 year ago by Tri Dao
Change causal mask to be aligned to bottom-right instead of top-left
parent: e07aa036
Showing 11 changed files
README.md                                          +26   -0
benchmarks/benchmark_causal.py                     +10   -14
csrc/flash_attn/src/block_info.h                   +2    -2
csrc/flash_attn/src/flash_bwd_kernel.h             +19   -52
csrc/flash_attn/src/flash_fwd_kernel.h             +69   -22
csrc/flash_attn/src/flash_fwd_launch_template.h    +6    -8
csrc/flash_attn/src/softmax.h                      +22   -22
flash_attn/__init__.py                             +1    -1
flash_attn/flash_attn_interface.py                 +48   -0
tests/test_flash_attn.py                           +369  -82
training/Dockerfile                                +1    -1

with 573 additions and 204 deletions
README.md
@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
## Changes in v2.1 (compared to v2.0)
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
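
For reference, the new alignment can be reproduced with a few lines of PyTorch (a sketch for illustration only; `bottom_right_causal_mask` is not part of the library):

```python
import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Query row i keeps key column j iff j <= i + seqlen_k - seqlen_q,
    # i.e. the mask is anchored at the bottom-right corner (1 = keep, 0 = masked out).
    row = torch.arange(seqlen_q).unsqueeze(-1)
    col = torch.arange(seqlen_k)
    return (col <= row + seqlen_k - seqlen_q).int()

print(bottom_right_causal_mask(2, 5))  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
print(bottom_right_causal_mask(5, 2))  # rows 0-2 all zero, then [1, 0], [1, 1]
```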
## Performance
benchmarks/benchmark_causal.py
@@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# from triton.ops.flash_attention import attention as attention_triton
-try:
-    from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
-    from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
-except ImportError:
-    fav2_qkvpacked_func = None
-    fav2_kvpacked_func = None
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
try:
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
@@ -80,8 +75,8 @@ def attention_megatron(qkv):
torch.manual_seed(0)
repeats = 30
-batch_size = 2
-seqlen = 8192
+batch_size = 8
+seqlen = 2048
nheads = 12
headdim = 128
# nheads = 24
@@ -90,8 +85,8 @@ headdim = 128
# seqlen = 512
# nheads = 8
# headdim = 128
-dropout_p = 0.1
-causal = False
+dropout_p = 0.0
+causal = True
dtype = torch.float16
device = 'cuda'
@@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device)
-# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
# if fav2_qkvpacked_func is not None:
#     benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
#     pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
# for dropout_p in [0.1, 0.0]:
#     for causal in [False, True]:
#         print(f"### {dropout_p = }, {causal = } ###")
#         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
@@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f} ms, bwd time: {ideal_a100_time * 2.5:.3f} ms")
+exit(0)
def time_fwd_bwd(func, *args, **kwargs):
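
As a quick sanity check of the ideal-time formula above (a back-of-the-envelope sketch, not part of the benchmark script): with the settings used here, the forward pass is about 2.1e11 FLOPs, so at the A100's 312 TFLOPS fp16 peak the ideal times come out around 0.66 ms forward and 1.65 ms backward.

```python
# Hypothetical re-derivation of the printed ideal A100 timings for this config.
batch_size, seqlen, nheads, headdim = 8, 2048, 12, 128
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim   # ~2.06e11 FLOPs
ideal_a100_time = flops / 312 / 1e9                       # milliseconds at 312 TFLOPS
print(f"{ideal_a100_time:.3f} ms fwd, {ideal_a100_time * 2.5:.3f} ms bwd")  # ~0.661 / ~1.652
```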
csrc/flash_attn/src/block_info.h
@@ -32,8 +32,8 @@ struct BlockInfo {
    const int sum_s_q;
    const int sum_s_k;
-    const uint32_t actual_seqlen_q;
-    const uint32_t actual_seqlen_k;
+    const int actual_seqlen_q;
+    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

    int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
-    m_block_min = m_block_min < 0 ? 0 : m_block_min;
-
-    // We might need to exit early and write 0 to dK and dV.
-    // Otherwise we get wrong result for the case where we don't enter the for loop.
-    // And we might read OOB elements from gQ and gdO.
-    // TODO: what if we're not parallelizing, do we need to compute dot_do_o?
-    if (Is_causal && m_block < m_block_min) {
-        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
-            + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
-        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
-            + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
-        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dk_row_stride, _1{}));
-        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dv_row_stride, _1{}));
-        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
-        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
-        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
-        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-        clear(tdKrdK);
-        clear(tdVrdV);
-        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
-        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
-        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
-        #pragma unroll
-        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
-        // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN);
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN);
-        return;
-    }
+    int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
+    // We're guaranteed that m_block_min <= m_block:
+    // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
+    // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
+    // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
+    // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
+    // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
+    // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.

    if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ
        tQsQ.data() = tQsQ.data() + size(sQ);
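
The comment block above argues that with the bottom-right alignment the new `m_block_min` never exceeds `m_block_max - 1`, so the loop always runs at least once and the old zero-writing early exit can go away. A brute-force Python sketch of that claim (block sizes are illustrative, not the kernel's actual traits):

```python
from math import ceil

def check(kBlockM=64, kBlockN=128, max_len=384):
    for seqlen_q in range(1, max_len + 1):
        for seqlen_k in range(1, max_len + 1):
            m_block_max = ceil(seqlen_q / kBlockM)
            # Every key block the kernel is launched for satisfies n_block * kBlockN < seqlen_k.
            for n_block in range(ceil(seqlen_k / kBlockN)):
                m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)
                assert m_block_min <= m_block_max - 1

check()
print("m_block_min <= m_block_max - 1 for all tested shapes")
```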
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        // Using uint32_t row makes it 10us slower on d=128, not sure why.
        const int row = get<0>(taccScS_row(mi));
        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
    }
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
    // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
    // But we still want to mask out elements beyond actual_seqlen_k.
-    if (m_block * kBlockM < (n_block + 1) * kBlockN
+    if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
        || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
        flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                 binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                 m_block * kBlockM + get<0>(taccScS_row(0)),
+                                 binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                 binfo.actual_seqlen_q,
                                 // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                 AtomLayoutMS * 16);
    }
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    // Compute the exponential value.
    flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
    if (Is_dropout) {
-        uint32_t warp_id = tidx / 32;
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+        int warp_id = tidx / 32;
+        int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
        // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
        static_assert(MMA_N_SdP % 2 == 0);
-        uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+        int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
        Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
        flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        // Using uint32_t row makes it 10us slower on d=128, not sure why.
        const int row = get<0>(taccScS_row(mi));
        lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
    }
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
    // the corresponding values of K would be 0, so the result would still be correct.
    if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
        flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                 binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                 m_block * kBlockM + get<0>(taccScS_row(0)),
+                                 binfo.actual_seqlen_k,
+                                 m_block * kBlockM + get<0>(taccScS_row(0)),
                                 // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+                                 binfo.actual_seqlen_q,
                                 AtomLayoutMS * 16);
    }
    // Compute the exponential value.
    flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
    if (Is_dropout) {
-        uint32_t warp_id = tidx / 32;
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+        int warp_id = tidx / 32;
+        int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
        // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
        static_assert(MMA_N_SdP % 2 == 0);
-        uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+        int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
        Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
        flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
@@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // The thread index.
    const int tidx = threadIdx.x;
-    // The global block index.
-    const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
@@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
+        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
+        // We exit early and write 0 to gO and gLSE.
+        // Otherwise we might read OOB elements from gK and gV.
+        if (n_block_max <= 0) {
+            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+            // exit early and no one saves the rng state.
+            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+                auto seeds = at::cuda::philox::unpack(params.philox_args);
+                params.rng_state[0] = std::get<0>(seeds);
+                params.rng_state[1] = std::get<1>(seeds);
+            }
+            const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+                + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+            const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+            Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                    make_stride(params.o_row_stride, _1{}));
+            Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                      Shape<Int<kBlockM>>{}, Stride<_1>{});
+            typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+            Tensor tOrO = make_tensor<Element>(shape(tOgO));
+            clear(tOrO);
+            // Construct identity layout for sO
+            Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+            // Repeat the partitioning with identity layouts
+            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+            if (!Is_even_K) {
+                #pragma unroll
+                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+            }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+            );
+            #pragma unroll
+            for (int m = 0; m < size<1>(tOgO); ++m) {
+                const int row = get<0>(tOcO(0, m, 0));
+                if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+            }
+            return;
+        }
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
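
Two consequences of the bottom-right alignment are visible in this hunk: `n_block_max` is clamped using the shifted column limit, and when the clamp drives it to zero the block writes zero output (plus an infinite LSE for in-range rows) and returns. A hedged Python sketch of the block-level bound (block sizes are illustrative; the kernel uses `cute::ceil_div` on ints rather than `math.ceil`, but the `<= 0` test comes out the same):

```python
from math import ceil

def n_block_max(m_block, seqlen_q, seqlen_k, kBlockM=128, kBlockN=64, causal=True):
    # How many key blocks this query block has to visit.
    n_max = ceil(seqlen_k / kBlockN)
    if causal:
        # Row i may see key columns j <= i + seqlen_k - seqlen_q; bound the block's last row
        # by (m_block + 1) * kBlockM - 1 and round the resulting column count up to blocks.
        n_max = min(n_max, ceil(((m_block + 1) * kBlockM + seqlen_k - seqlen_q) / kBlockN))
    return n_max

# seqlen_q = 256, seqlen_k = 2 (the shape mentioned in the bwd-kernel comment):
# the first query block (rows 0..127) sees no keys at all, so n_block_max <= 0 and the
# kernel takes the early-exit path above; the second block visits the single key block.
print(n_block_max(0, 256, 2), n_block_max(1, 256, 2))   # -1 (<= 0) and 1
```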
@@ -275,8 +317,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // Copy rmem to smem
@@ -298,8 +340,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();

    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();
@@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
    // Save seed and offset for backward.
-    if (block_id == 0 && tidx == 0) {
+    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
        params.rng_state[0] = seed;
        params.rng_state[1] = std::get<1>(seeds);
    }
@@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.
-    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+    constexpr int n_masking_steps = !Is_causal
+        ? 1
+        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
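
The comment's example can be checked with a small sketch (illustrative block sizes): when `kBlockM == kBlockN`, an uneven `seqlen_k` bumps the causal masking iterations from 1 to 2.

```python
from math import ceil

def n_masking_steps(is_causal, is_even_mn, kBlockM=128, kBlockN=128):
    # Mirrors the constexpr above: non-causal always needs exactly one masking iteration.
    if not is_causal:
        return 1
    return ceil(kBlockM / kBlockN) + (0 if is_even_mn else 1)

print(n_masking_steps(True, True), n_masking_steps(True, False))   # 1 2
```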
@@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
    } else {
        // Clear the smem tiles to account for predicated off loads
-        flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
        );
    }

@@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
    // can produce Inf / NaN.
    if (!Is_causal) {
-        if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+        if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
    } else {
        // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
        // Tensor taccScS = thr_mma.partition_C(caccS);    // (MMA,MMA_M,MMA_N)
@@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Idk why it's get<1> and not get<0> of the stride.
        // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
        // I can't get the stride from idx_row
-        flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+        flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                 // m_block * kBlockM + get<0>(idx_row(0)),
                                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                 binfo.actual_seqlen_q,
                                 kNWarps * 16);
                                 // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
                                 // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);

@@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            cute::copy(tOrP, tOrP_copy);

@@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            cute::copy(tOrP, tOrP_copy);
@@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.

@@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params &params) {
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -10,9 +10,9 @@
#include "flash.h"
#include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
-    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
-    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                // Will only return softmax if dropout, to reduce compilation time.
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
                if (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
csrc/flash_attn/src/softmax.h
@@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}

template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
-                                  const uint32_t col_idx_offset_ = 0) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-        const uint32_t col_idx_base = col_idx_offset + nj * 8;
+        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = col_idx_base + j;
+            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
@@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
}

template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
-                                         const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+    const int row_idx_offset = row_idx_offset_;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
+                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
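
One property of the `col_idx_limit` formula worth noting: when `max_seqlen_q == max_seqlen_k` it reduces to the ordinary causal limit `row_idx + 1`, so square attention matrices are masked exactly as before this change. A small Python sketch (illustrative, not library code) checking that:

```python
def col_idx_limit(row_idx, seqlen_q, seqlen_k):
    # Same expression as in apply_mask_causal above.
    return min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q)

for s in (1, 7, 128, 2048):
    assert all(col_idx_limit(i, s, s) == min(s, i + 1) for i in range(s))
print("square case unchanged by bottom-right alignment")
```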
@@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx(
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                     unsigned long long seed, unsigned long long offset,
-                                    uint32_t block_row_start, uint32_t block_col_start, uint32_t block_row_stride) {
+                                    int block_row_start, int block_col_start, int block_row_stride) {
    // tensor has shape (8, MMA_M, MMA_N / 2)
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
flash_attn/__init__.py

-__version__ = "2.0.9"
+__version__ = "2.1.0"

from flash_attn.flash_attn_interface import (
    flash_attn_func,
flash_attn/flash_attn_interface.py
@@ -528,6 +528,18 @@ def flash_attn_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
...
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
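A hedged usage sketch of the documented behaviour (requires a CUDA GPU with flash-attn installed; shapes are arbitrary): with `causal=True` and `seqlen_q = 5`, `seqlen_k = 2`, the first three mask rows are all zero, so the corresponding output rows are zero.

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, headdim = 2, 4, 64
q = torch.randn(batch, 5, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, 2, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, 2, nheads, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)      # bottom-right aligned causal mask
print(out[:, :3].abs().max().item())             # expected: 0.0 (fully masked rows)
print(out[:, 3:].abs().max().item())             # expected: > 0
```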
@@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
@@ -703,6 +739,18 @@ def flash_attn_varlen_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
tests/test_flash_attn.py (diff collapsed, not shown)
training/Dockerfile
@@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.0.9 \
+    && cd flash-attention && git checkout v2.1.0 \
    && cd csrc/fused_softmax && pip install . && cd ../../ \
    && cd csrc/rotary && pip install . && cd ../../ \
    && cd csrc/xentropy && pip install . && cd ../../ \