Shenguo Wang / Flash Attention · Commits
Commit fd20f16a, authored 1 year ago by Tri Dao
Support cache_seqlens being integer
Parent: 913922ca
Showing 2 changed files with 39 additions and 16 deletions (+39, -16):

flash_attn/flash_attn_interface.py   +9  -3
tests/models/test_gpt.py             +30 -13
flash_attn/flash_attn_interface.py
from typing import Optional, Union

import torch
import torch.nn as nn
from einops import rearrange

# isort: off
# We need to import the CUDA kernels after importing torch
...
@@ -799,7 +800,7 @@ def flash_attn_with_kvcache(
    v_cache,
    k=None,
    v=None,
-    cache_seqlens=None,
+    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    softmax_scale=None,
    causal=False,
    num_splits=0,
...
@@ -840,7 +841,8 @@ def flash_attn_with_kvcache(
        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
            k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
-        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+            KV cache.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
...
@@ -858,6 +860,10 @@ def flash_attn_with_kvcache(
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
+        )
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
    )
...
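For readers landing on this commit, a minimal usage sketch of what the change enables: cache_seqlens can now be a plain Python int (the same current cache length for every batch element) instead of a (batch_size,) int32 tensor; the int is broadcast internally with torch.full as shown in the hunk above. The sketch is not part of the commit; the shapes, the cache length of 16, and the equality check are illustrative assumptions, and running it needs a CUDA device with the flash-attn extension built.

# Illustrative sketch (not from the commit): single-step decode against a KV cache whose
# current length is 16 for every sequence in the batch.
import torch

from flash_attn.flash_attn_interface import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 64, 128
device, dtype = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
k_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
v_new = torch.randn_like(k_new)

# Previously required form: a (batch_size,) int32 tensor of per-sequence cache lengths.
# The caches are cloned because they are updated in place when k/v are passed.
seqlens_tensor = torch.full((batch,), 16, dtype=torch.int32, device=device)
out_tensor = flash_attn_with_kvcache(
    q, k_cache.clone(), v_cache.clone(), k=k_new, v=v_new,
    cache_seqlens=seqlens_tensor, causal=True,
)

# Form allowed after this commit: a plain int, broadcast to (batch_size,) internally.
out_int = flash_attn_with_kvcache(
    q, k_cache.clone(), v_cache.clone(), k=k_new, v=v_new,
    cache_seqlens=16, causal=True,
)

# The kernel sees identical inputs in both calls, so the outputs should match.
assert torch.equal(out_tensor, out_int)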
tests/models/test_gpt.py
...
@@ -3,7 +3,12 @@ import re
import pytest
import torch
from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
+from flash_attn.models.gpt import (
+    GPTLMHeadModel,
+    remap_state_dict_hf_gpt2,
+    shard_state_dict_tp,
+    combine_state_dicts_tp,
+)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
...
@@ -130,9 +135,9 @@ def test_gpt2_optimized(model_name):
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('fused_ft_kernel', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
...
@@ -204,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if fused_ft_kernel or config.use_flash_attn:
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
...
@@ -263,7 +268,6 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
-        fused_ft_kernel=True,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
...
@@ -277,8 +281,9 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
    """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
    dtype = torch.float16
    device = "cuda"
...
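A quick note on this hunk, not part of the diff: stacked @pytest.mark.parametrize decorators make pytest take the cartesian product of their values, so parametrizing fused_ft_kernel doubles the number of generated test_gpt2_generation_cg cases (the value presumably reaches model.generate through get_logits's **kwargs, now that it is no longer hardcoded there). A small sketch, counting only the three decorators visible above:

# Illustration only: pytest collects one test case per element of the cartesian product
# of the stacked parametrize decorators shown in the hunk above.
import itertools

model_names = ["gpt2"]
fused_ft_kernel_opts = [False, True]          # newly parametrized in this commit
rotary_opts = [None, "interleaved", "block"]

cases = list(itertools.product(model_names, fused_ft_kernel_opts, rotary_opts))
assert len(cases) == 6  # 1 model x 2 kernel settings x 3 rotary variants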
@@ -308,8 +313,17 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    logits = get_logits(
+        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
+    )
+    logits_cg = get_logits(
+        model,
+        input_ids,
+        maxlen,
+        teacher_outputs=teacher_outputs,
+        fused_ft_kernel=fused_ft_kernel,
+        cg=True,
+    )
    assert torch.equal(logits, logits_cg)

    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
...
@@ -446,11 +460,14 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
    print(tokenizer.batch_decode(out_og.sequences))


-@pytest.mark.parametrize("n_heads_q_kv", [
-    (8, 8),  # Regular attention
-    (8, 4),  # GQA
-    (8, 2),  # MQA
-])
+@pytest.mark.parametrize(
+    "n_heads_q_kv",
+    [
+        (8, 8),  # Regular attention
+        (8, 4),  # GQA
+        (8, 2),  # MQA
+    ],
+)
def test_gpt2_shard_unshard(n_heads_q_kv):
    world_size = 2
...
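The (nheads_q, nheads_kv) pairs in the last hunk cover regular multi-head attention (8, 8), grouped-query attention (8, 4), and the multi-query-style case (8, 2). As a rough illustration of what the pairing means, not taken from the repo and assuming the common convention that consecutive query heads share one KV head:

# Illustration only: map a query head to the KV head it attends with when nheads_kv
# divides nheads_q, as in the (8, 8) / (8, 4) / (8, 2) cases parametrized above.
def kv_head_for_query_head(q_head: int, nheads_q: int, nheads_kv: int) -> int:
    assert nheads_q % nheads_kv == 0
    group_size = nheads_q // nheads_kv  # query heads served by each KV head
    return q_head // group_size

# (8, 4): pairs of query heads share a KV head; (8, 8) reduces to the identity mapping.
assert [kv_head_for_query_head(h, 8, 4) for h in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
assert [kv_head_for_query_head(h, 8, 8) for h in range(8)] == list(range(8))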