Commit fd20f16a authored by Tri Dao

Support cache_seqlens being an integer

Showing with 39 additions and 16 deletions
from typing import Optional, Union
import torch
import torch.nn as nn
from einops import rearrange
# isort: off
# We need to import the CUDA kernels after importing torch
@@ -799,7 +800,7 @@ def flash_attn_with_kvcache(
v_cache,
k=None,
v=None,
cache_seqlens=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
softmax_scale=None,
causal=False,
num_splits=0,
@@ -840,7 +841,8 @@ def flash_attn_with_kvcache(
k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -858,6 +860,10 @@ def flash_attn_with_kvcache(
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
)
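The conversion above means callers that track a single decode offset no longer need to build a per-batch tensor: an integer cache_seqlens is expanded to a (batch_size,) int32 tensor on the cache's device. A minimal usage sketch, assuming flash_attn_with_kvcache is importable from the top-level flash_attn package; the shapes and the offset value 16 are purely illustrative:

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
cache_len, new_len = 128, 1
q = torch.randn(batch, new_len, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(batch, cache_len, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

# Before this commit: cache_seqlens had to be a (batch_size,) int32 tensor.
seqlens = torch.full((batch,), 16, dtype=torch.int32, device="cuda")
out_tensor = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=seqlens, causal=True
)

# After this commit: a plain int works and applies to every batch element.
out_int = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=16, causal=True
)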
@@ -3,7 +3,12 @@ import re
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
from flash_attn.models.gpt import (
GPTLMHeadModel,
remap_state_dict_hf_gpt2,
shard_state_dict_tp,
combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
@@ -130,9 +135,9 @@ def test_gpt2_optimized(model_name):
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('fused_ft_kernel', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [False])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
@@ -204,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
if fused_ft_kernel or config.use_flash_attn:
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
@@ -263,7 +268,6 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
teacher_outputs=teacher_outputs,
return_dict_in_generate=True,
output_scores=True,
@@ -277,8 +281,9 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
dtype = torch.float16
device = "cuda"
@@ -308,8 +313,17 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
logits = get_logits(
model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
)
logits_cg = get_logits(
model,
input_ids,
maxlen,
teacher_outputs=teacher_outputs,
fused_ft_kernel=fused_ft_kernel,
cg=True,
)
assert torch.equal(logits, logits_cg)
# Try increasing batch size and seqlen, then decrease them to see if it's still correct
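The hunks above only show fragments of get_logits, so here is a hedged reconstruction of what the helper plausibly looks like once fused_ft_kernel is no longer hard-coded and is instead forwarded through **kwargs; the final stacking of out.scores is an assumption, not visible in the diff:

import torch

def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    # kwargs can now carry fused_ft_kernel and/or cg, so the same helper drives
    # eager decoding, fused-FT-kernel decoding, and CUDA-graph decoding.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
        **kwargs,
    )
    # Assumed: stack the per-step scores into (batch, new_tokens, vocab) so
    # callers can compare two runs with torch.equal.
    return torch.stack(out.scores, dim=1)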
@@ -446,11 +460,14 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
print(tokenizer.batch_decode(out_og.sequences))
@pytest.mark.parametrize("n_heads_q_kv", [
(8, 8), # Regular attention
(8, 4), # GQA
(8, 2), # MQA
])
@pytest.mark.parametrize(
"n_heads_q_kv",
[
(8, 8), # Regular attention
(8, 4), # GQA
(8, 2), # MQA
],
)
def test_gpt2_shard_unshard(n_heads_q_kv):
world_size = 2
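The n_heads_q_kv cases above cover regular multi-head attention (8 query and 8 KV heads), grouped-query attention (8 and 4), and multi-query attention (8 and 2). As a reference-only sketch that is not code from this repository, the mapping between query heads and shared KV heads can be expressed by repeating each KV head across its group before ordinary attention:

import torch
from einops import repeat

def expand_kv_heads(kv, nheads_q):
    # kv: (batch, seqlen, nheads_k, headdim), with nheads_q divisible by nheads_k.
    # Each KV head serves a contiguous group of nheads_q // nheads_k query heads,
    # so repeating it group-wise recovers the per-query-head layout that regular
    # attention expects.
    nheads_k = kv.shape[2]
    assert nheads_q % nheads_k == 0
    return repeat(kv, "b s hk d -> b s (hk g) d", g=nheads_q // nheads_k)

k = torch.randn(2, 16, 4, 64)                # GQA case: 4 KV heads
k_expanded = expand_kv_heads(k, nheads_q=8)  # -> (2, 16, 8, 64) for 8 query heads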