Shenguo Wang / Flash Attention · Commits

Commit 37c6e054, authored 1 year ago by Tri Dao
Implement flash_attn_with_kvcache
parent 4976650f
Showing 10 changed files with 663 additions and 108 deletions:

csrc/flash_attn/flash_api.cpp                     +192   -3
csrc/flash_attn/src/block_info.h                    +7   -2
csrc/flash_attn/src/flash.h                        +16   -0
csrc/flash_attn/src/flash_fwd_kernel.h            +150  -61
csrc/flash_attn/src/flash_fwd_launch_template.h    +41  -31
csrc/flash_attn/src/utils.h                        +69   -2
flash_attn/__init__.py                              +1   -0
flash_attn/flash_attn_interface.py                 +72   -0
flash_attn/utils/generation.py                     +25   -9
tests/test_flash_attn.py                           +90   -0
csrc/flash_attn/flash_api.cpp
@@ -102,6 +102,7 @@ void set_params_fprop(Flash_fwd_params &params,
    TORCH_CHECK(p_dropout < 1.f);

    params.is_causal = is_causal;
+   params.is_seqlens_k_cumulative = true;
}

void set_params_dgrad(Flash_bwd_params &params,
@@ -175,10 +176,10 @@ void set_params_dgrad(Flash_bwd_params &params,
    params.dsoftmax_sum = dsoftmax_sum_d;
}

-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
-           if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
+           if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
@@ -350,7 +351,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = 1;
    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-       params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
+       params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
        if (params.num_splits > 1) {
            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -990,10 +991,198 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
    return { dq, dk, dv, softmax_d };
}

std::vector<at::Tensor>
mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
                c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                c10::optional<at::Tensor> &out_,     // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                const bool is_causal,
                int num_splits) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be postive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);

    at::Tensor q_padded, kcache_padded, vcache_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        kcache_padded = kcache;
        vcache_padded = vcache;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, kcache_padded, vcache_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*p_ptr=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     is_causal);

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
        k = k_.value();
        v = v_.value();
        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
        if (head_size_og % 8 != 0) {
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
            k_padded = k;
            v_padded = v;
        }
        params.knew_ptr = k_padded.data_ptr();
        params.vnew_ptr = v_padded.data_ptr();
        // All stride are in elements, not bytes.
        params.knew_batch_stride = k_padded.stride(0);
        params.vnew_batch_stride = v_padded.stride(0);
        params.knew_row_stride = k_padded.stride(-3);
        params.vnew_row_stride = v_padded.stride(-3);
        params.knew_head_stride = k_padded.stride(-2);
        params.vnew_head_stride = v_padded.stride(-2);
    }

    if (seqlens_k_.has_value()) {
        auto seqlens_k = seqlens_k_.value();
        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
        CHECK_SHAPE(seqlens_k, batch_size);
        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
    }
    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = num_splits;
    if (num_splits < 1) {
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
    }
    if (params.num_splits > 1) {
        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_ptr = out_accum.data_ptr();
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only split kernel supports appending to KV cache
    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());

    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
        if (k_.has_value()) {
            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
            // but we don't expect to get this case in practice. This is just so that the code works for that case.
            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
        }
    }

    return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+   m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
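As a reading aid (not part of the commit): the semantics of the new kvcache path can be written in a few lines of plain PyTorch, mirroring the reference construction used in tests/test_flash_attn.py further down. The sketch below assumes the same number of K/V heads as Q heads (no MQA/GQA), causal=False, and ignores the head-size padding handled above; kvcache_attention_ref is a name invented here.

import math
import torch

def kvcache_attention_ref(q, k_cache, v_cache, k_new, v_new, cache_seqlens):
    # q: (b, sq, h, d); k_cache/v_cache: (b, sk, h, d); k_new/v_new: (b, sq, h, d) or None
    # cache_seqlens: (b,) int32 -- number of tokens already in the cache for each sequence.
    b, sq, h, d = q.shape
    sk = k_cache.shape[1]
    arange = torch.arange(sk, device=q.device).unsqueeze(0)   # (1, sk)
    lens = cache_seqlens.unsqueeze(1).long()                  # (b, 1)
    if k_new is not None:
        # In-place append at offset cache_seqlens, like the kernel's Append_KV path.
        update_mask = (lens <= arange) & (arange < lens + sq)
        k_cache[update_mask] = k_new.reshape(-1, h, d)
        v_cache[update_mask] = v_new.reshape(-1, h, d)
    # Each query attends to the first cache_seqlens (+ sq if new keys were appended) positions.
    valid = arange < lens + (sq if k_new is not None else 0)  # (b, sk)
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k_cache.float()) / math.sqrt(d)
    scores = scores.masked_fill(~valid[:, None, None, :], float("-inf"))
    attn = torch.softmax(scores, dim=-1).nan_to_num(0.0)      # fully masked rows -> zero output
    return torch.einsum("bhqk,bkhd->bqhd", attn, v_cache.float()).to(q.dtype)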
csrc/flash_attn/src/block_info.h
@@ -14,9 +14,12 @@ struct BlockInfo {
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-       , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+       , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-       , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+       // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+       // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+       , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+       , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
        {
        }
@@ -33,6 +36,8 @@ struct BlockInfo {
    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
+   // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+   const int seqlen_k_cache;
    const int actual_seqlen_k;
};
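The added comments distinguish two ways the kernel now interprets cu_seqlens_k; a small Python sketch (illustrative only, function name invented here) of how seqlen_k_cache is derived in the Varlen case:

def seqlen_k_cache(cu_seqlens_k, bidb, is_seqlens_k_cumulative, default_seqlen_k):
    # Mirrors the BlockInfo initializer above.
    if cu_seqlens_k is None:
        return default_seqlen_k                       # fixed seqlen_k from the params
    if is_seqlens_k_cumulative:
        # cu_seqlens_k holds cumulative offsets: length = difference of neighbors.
        return cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]
    # Otherwise cu_seqlens_k stores per-sequence lengths directly (the kv-cache path,
    # where mha_fwd_kvcache passes cache_seqlens through this field).
    return cu_seqlens_k[bidb]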
csrc/flash_attn/src/flash.h
@@ -80,6 +80,18 @@ struct Flash_fwd_params : public Qkv_params {
    int * __restrict__ blockmask;

+   // The K_new and V_new matrices.
+   void * __restrict__ knew_ptr;
+   void * __restrict__ vnew_ptr;
+
+   // The stride between rows of the Q, K and V matrices.
+   index_t knew_batch_stride;
+   index_t vnew_batch_stride;
+   index_t knew_row_stride;
+   index_t vnew_row_stride;
+   index_t knew_head_stride;
+   index_t vnew_head_stride;
+
    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
@@ -99,6 +111,10 @@ struct Flash_fwd_params : public Qkv_params {
    bool is_bf16;
    bool is_causal;
+   // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+   // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+   bool is_seqlens_k_cumulative;
+   int num_splits;  // For split-KV version
};
csrc/flash_attn/src/flash_fwd_kernel.h (+150 -61; diff collapsed, not shown)
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -15,9 +15,9 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}

-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-   flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
+   flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}

template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
@@ -63,45 +63,55 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-   dim3 grid(num_m_block, params.num_splits, params.b * params.h);
+   dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-               auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
-               // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-               if (smem_size >= 48 * 1024) {
-                   C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-               }
-               kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-               C10_CUDA_KERNEL_LAUNCH_CHECK();
+               BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                   BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                       // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                       // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+                       auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
+                       // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                       // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                       if (smem_size >= 48 * 1024) {
+                           C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                       }
+                       kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                       C10_CUDA_KERNEL_LAUNCH_CHECK();
+                   });
+               });
            });
        });
    });
-   dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
-   BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-       if (params.num_splits <= 2) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       } else if (params.num_splits <= 4) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       } else if (params.num_splits <= 8) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       } else if (params.num_splits <= 16) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       } else if (params.num_splits <= 32) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       } else if (params.num_splits <= 64) {
-           flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       // } else if (params.num_splits <= 128) {
-       //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-       }
-       C10_CUDA_KERNEL_LAUNCH_CHECK();
-   });
+   if (params.num_splits > 1) {
+       dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+       BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+           if (params.num_splits <= 2) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 4) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 8) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 16) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 32) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 64) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           } else if (params.num_splits <= 128) {
+               flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+           }
+           C10_CUDA_KERNEL_LAUNCH_CHECK();
+       });
+   }
}

template<typename T, int Headdim>
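The combine kernels launched at the end of run_flash_splitkv_fwd merge the per-split partial results (the softmax_lse_accum / out_accum buffers allocated in flash_api.cpp). Numerically this is the usual log-sum-exp recombination; below is a hedged PyTorch sketch of that reduction, not the kernel itself, assuming each split's partial output is already normalized by its own softmax denominator:

import torch

def combine_splits(out_accum, lse_accum):
    # out_accum: (num_splits, b, h, sq, d) partial outputs, float32
    # lse_accum: (num_splits, b, h, sq) partial log-sum-exp values, float32
    lse = torch.logsumexp(lse_accum, dim=0)               # (b, h, sq) combined log-sum-exp
    weights = torch.exp(lse_accum - lse.unsqueeze(0))     # per-split rescaling factors
    out = (weights.unsqueeze(-1) * out_accum).sum(dim=0)  # (b, h, sq, d)
    return out, lse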
csrc/flash_attn/src/utils.h
@@ -291,7 +291,7 @@ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                           Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                           Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
@@ -355,4 +355,71 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const

////////////////////////////////////////////////////////////////////////////////////////////////////

-}  // namespace flash
+template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0, Tensor<Engine0, Layout0> const &S1,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int row_idx_switch=0) {
+    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D));  // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D));  // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D));  // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S0); ++m) {
+        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S0); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
\ No newline at end of file
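copy_2_sources is what lets a K/V tile be gathered partly from the cache and partly from the freshly supplied K/V: per the row selection above, rows whose global index is below row_idx_switch are read from S0 and the rest from S1, with the same out-of-bounds predicates as flash::copy. A tiny Python illustration of just that row-selection rule (not CUTE code; names invented here):

import torch

def copy_2_sources_rows(s0, s1, row_idx, row_idx_switch):
    # s0, s1: (rows, cols) tiles from two sources (e.g. cached K vs. newly appended K).
    # row_idx: (rows,) global row indices of this tile; rows below row_idx_switch read s0.
    take_s0 = row_idx < row_idx_switch
    return torch.where(take_s0.unsqueeze(-1), s0, s1)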
flash_attn/__init__.py
@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import (
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
+   flash_attn_with_kvcache,
)
flash_attn/flash_attn_interface.py
@@ -5,6 +5,7 @@ from einops import rearrange

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda

# isort: on
@@ -790,3 +791,74 @@ def flash_attn_varlen_func(
        causal,
        return_attn_probs,
    )


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    cache_seqlens=None,
    softmax_scale=None,
    causal=False,
    num_splits=0,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Does not support backward pass.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
            k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
    )
    return out
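For orientation (not part of this commit), a minimal decode loop using the new API could look as follows; the sizes and the random q/k/v standing in for the model's projections are placeholders:

import torch
from flash_attn import flash_attn_with_kvcache

# Hypothetical sizes, for illustration only.
batch_size, max_seqlen, nheads, nheads_k, headdim = 2, 1024, 16, 16, 64
device, dtype = "cuda", torch.float16

# Pre-allocate the KV cache to the maximum sequence length.
k_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, max_seqlen, nheads_k, headdim, device=device, dtype=dtype)
# Number of tokens currently in the cache for each sequence (all start empty here).
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

for step in range(16):
    # In a real model these come from the attention projections of the new token.
    q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    v = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
    # Appends k/v into k_cache/v_cache at offset cache_seqlens (in place) and
    # attends over the cached tokens plus the newly appended one.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True)
    cache_seqlens += 1  # the cache grew by one token per sequence

Because the call updates k_cache/v_cache in place, only cache_seqlens has to be advanced between steps.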
flash_attn/utils/generation.py
@@ -348,8 +348,14 @@ def decode_speculative(
    )

    def sample_tokens(
-       input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True,
-       last_token_logits=False
+       input_ids,
+       model,
+       inference_params,
+       sample_fn,
+       num_tokens=1,
+       cg=False,
+       decoding=True,
+       last_token_logits=False,
    ):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
@@ -374,12 +380,18 @@ def decode_speculative(
        sequences = []
        if decoding:
            assert seqlen == 1
-           position_ids = torch.full(
-               (batch_size, 1),
-               inference_params.sequence_len_offset,
-               dtype=torch.long,
-               device=input_ids.device,
-           )
+           position_ids = repeat(
+               torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+               + inference_params.sequence_len_offset,
+               "s -> b s",
+               b=batch_size,
+           )
+           # position_ids = torch.full(
+           #     (batch_size, 1),
+           #     inference_params.sequence_len_offset,
+           #     dtype=torch.long,
+           #     device=input_ids.device,
+           # )
        else:
            position_ids = None
        logits = logits_postprocess_fn(
@@ -399,7 +411,11 @@ def decode_speculative(
        )
        logits = logits_postprocess_fn(
            logits_forward_fn(
-               model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
+               model,
+               rearrange(next_token, "b -> b 1"),
+               position_ids,
+               inference_params,
+               cg=cg,
            )
        )
        inference_params.sequence_len_offset += 1
@@ -420,7 +436,7 @@ def decode_speculative(
            sample_fn=sample_fn,
            last_token_logits=True,
            inference_params=inference_params_draft,
-           cg=cg
+           cg=cg,
        )
        if debug:
tests/test_flash_attn.py
@@ -11,6 +11,7 @@ from flash_attn import (
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
+   flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 6
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    if new_kv:
        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
    cache_seqlens = torch.randint(
        0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size,), dtype=torch.int32, device=device
    )
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    # k_cache[:, 64:] = -1
    k_cache_ref = k_cache.clone()
    v_cache_ref = v_cache.clone()
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q
        )
        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
    out_ref, _ = attention_ref(
        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, _ = attention_ref(
        q,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
    if new_kv:
        assert torch.equal(k_cache, k_cache_ref)
        assert torch.equal(v_cache, v_cache_ref)


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])