Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Shenguo Wang
Flash Attention
Commits
26d7d92f
Commit
26d7d92f
authored
1 year ago
by
Tri Dao
Browse files
Options
Download
Email Patches
Plain Diff
Fix splitKV combine function when local LSEs are all -inf
parent
de2949f3
main
v2.5.9
v2.5.9.post1
v2.5.8
v2.5.7
v2.5.6
v2.5.5
v2.5.4
v2.5.3
v2.5.2
v2.5.1
v2.5.1.post1
v2.5.0
v2.4.3
v2.4.3.post1
v2.4.2
v2.4.1
v2.4.0
v2.4.0.post1
v2.3.6
v2.3.5
v2.3.4
v2.3.3
v2.3.2
v2.3.1
v2.3.1.post1
v2.3.0
v2.2.5
v2.2.4
v2.2.4.post1
v2.2.3
v2.2.3.post2
v2.2.3.post1
v2.2.2
v2.2.1
v2.2.0
v2.1.2
v2.1.2.post3
v2.1.2.post2
v2.1.2.post1
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
csrc/flash_attn/src/flash_fwd_kernel.h
+3
-1
csrc/flash_attn/src/flash_fwd_kernel.h
with
3 additions
and
1 deletion
+3
-1
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
26d7d92f
...
...
@@ -1124,7 +1124,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
for
(
int
l
=
1
;
l
<
kNLsePerThread
;
++
l
)
{
lse_sum
+=
expf
(
lse_accum
(
l
)
-
lse_max
);
}
SumOp
<
float
>
sum_op
;
lse_sum
=
Allreduce
<
kRowsPerLoadTranspose
>::
run
(
lse_sum
,
sum_op
);
ElementAccum
lse_logsum
=
logf
(
lse_sum
)
+
lse_max
;
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
ElementAccum
lse_logsum
=
(
lse_sum
==
0.
f
||
lse_sum
!=
lse_sum
)
?
INFINITY
:
logf
(
lse_sum
)
+
lse_max
;
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
if
(
tidx
%
kRowsPerLoadTranspose
==
0
&&
tidx
/
kRowsPerLoadTranspose
<
kBlockM
)
{
gLSE
(
tidx
/
kRowsPerLoadTranspose
)
=
lse_logsum
;
}
// Store the scales exp(lse - lse_logsum) in shared memory.
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Projects
Groups
Snippets
Help