Shenguo Wang / Flash Attention / Commits / bb9beb36

Commit bb9beb36, authored 1 year ago by Tri Dao

Remove some unused headers

parent 08c295c0
Showing 7 changed files with 6 additions and 59 deletions:

csrc/flash_attn/flash_api.cpp                    +3   -1
csrc/flash_attn/src/flash.h                      +1   -2
csrc/flash_attn/src/flash_bwd_kernel.h           +0   -3
csrc/flash_attn/src/flash_fwd_kernel.h           +0   -49
csrc/flash_attn/src/flash_fwd_launch_template.h  +0   -1
csrc/flash_attn/src/softmax.h                    +1   -2
setup.py                                         +1   -1
csrc/flash_attn/flash_api.cpp

@@ -2,7 +2,9 @@
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/

-#include <torch/extension.h>
+// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#include <torch/python.h>
+#include <torch/nn/functional.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
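The swap above keeps only what the bindings file actually needs: torch/python.h for the pybind11 module machinery and torch/nn/functional.h for the torch::nn::functional ops the API layer calls. Below is a minimal, self-contained sketch (not taken from this commit; the function and module contents are hypothetical) of a PyTorch C++ extension built against just these two headers instead of the catch-all torch/extension.h:

// extension_sketch.cpp -- hedged example, assuming a standard torch extension build
// where TORCH_EXTENSION_NAME is defined on the compiler command line by setup.py.
#include <torch/python.h>          // pybind11 bindings + at::Tensor
#include <torch/nn/functional.h>   // torch::nn::functional (e.g. pad)

namespace F = torch::nn::functional;

// Hypothetical op: scale a tensor, then zero-pad its last dimension by one element.
at::Tensor scale_and_pad(const at::Tensor &x, double scale) {
    return F::pad(x * scale, F::PadFuncOptions({0, 1}));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("scale_and_pad", &scale_and_pad, "Scale a tensor, then pad its last dim");
}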
csrc/flash_attn/src/flash.h

@@ -13,8 +13,7 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif

-#include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
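The comment added above records why ATen/cuda/CUDAGraphsUtils.cuh is kept: it provides at::cuda::philox::unpack, which turns the host-side at::PhiloxCudaState into the (seed, offset) pair a kernel needs to seed its counter-based RNG for dropout. A hedged sketch of that usage follows (the kernel is illustrative, not flash-attn code; std::get in device code relies on --expt-relaxed-constexpr, one of the nvcc flags listed in setup.py below):

// rng_unpack_sketch.cu -- minimal illustration, not part of this commit.
#include <ATen/cuda/CUDAGraphsUtils.cuh>   // for at::cuda::philox::unpack

__global__ void dropout_seed_demo(at::PhiloxCudaState philox_args) {
    // Works the same whether philox_args was captured inside a CUDA graph or not.
    auto seeds = at::cuda::philox::unpack(philox_args);
    const uint64_t seed   = std::get<0>(seeds);
    const uint64_t offset = std::get<1>(seeds);
    // seed/offset would now initialize a per-thread Philox generator, e.g. the
    // hand-rolled one in philox.cuh or curand_init(seed, tid, offset, &state).
    (void)seed; (void)offset;
}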
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -5,18 +5,15 @@
#pragma once

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -4,20 +4,16 @@
#pragma once

#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {

@@ -25,49 +21,6 @@ using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int MMA_M, class... Args, class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>>>{},
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int MMA_M, class... Args, class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                  TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>>>{},
                       // TODO: Shouldn't this be size<1>?
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                         Tensor2 &acc_o, float softmax_scale_log2) {

@@ -256,7 +209,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

@@ -558,7 +510,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Partition sO to match the accumulator partitioning
    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -76,7 +76,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
            // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
            // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
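The launch code above uses nested BOOL_SWITCH macros to turn runtime conditions (params.num_splits > 1, params.knew_ptr != nullptr) into the constexpr template parameters Split and Append_KV of the kernel. A hedged sketch of that dispatch pattern follows, paraphrasing the idea behind the repo's static_switch.h rather than copying it; the macro and function names below are illustrative:

// static_switch_sketch.h -- illustrative only; see static_switch.h in this repo
// for the real macro. A runtime bool selects one of two branches, and inside
// each branch the flag is a constexpr usable as a template argument.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
    [&] {                                           \
        if (COND) {                                 \
            constexpr bool CONST_NAME = true;       \
            return __VA_ARGS__();                   \
        } else {                                    \
            constexpr bool CONST_NAME = false;      \
            return __VA_ARGS__();                   \
        }                                           \
    }()

// Hypothetical launcher templated on the two flags.
template <bool Split, bool Append_KV>
void launch_splitkv_impl() { /* select &kernel<..., Split, Append_KV> here */ }

// Usage mirroring the diff: two nested switches instantiate all four
// <Split, Append_KV> combinations at compile time; only one runs per call.
inline void launch_splitkv(bool more_than_one_split, bool has_new_kv) {
    BOOL_SWITCH_SKETCH(more_than_one_split, Split, [&] {
        BOOL_SWITCH_SKETCH(has_new_kv, Append_KV, [&] {
            launch_splitkv_impl<Split, Append_KV>();
        });
    });
}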
csrc/flash_attn/src/softmax.h

@@ -8,8 +8,7 @@
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"
setup.py

@@ -189,7 +189,7 @@ if not SKIP_CUDA_BUILD:
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                    "--ptxas-options=-v",
+                    # "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
                     "-lineinfo",
                 ]