Shenguo Wang / Flash Attention · Commit 3557e0bb

[MLP] Implement SwiGLU with torch jiterator

Authored 1 year ago by Tri Dao
Parent: 37c6e054
Changes: 2 changed files, with 41 additions and 1 deletion

  flash_attn/modules/mlp.py        +10  -1
  flash_attn/ops/activations.py    +31  -0
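The commit adds a fused SwiGLU fast path to GatedMlp: when the activation is F.silu, the gated product is computed by a single elementwise CUDA kernel that torch.cuda.jiterator compiles at runtime from a C++ code string, instead of materializing silu(gate) and multiplying in a separate op. For reference only (not part of the commit), a minimal unfused sketch of the math the kernel fuses; the name swiglu_ref is illustrative:

import torch.nn.functional as F

def swiglu_ref(x, y):
    # Unfused SwiGLU reference: silu(x) * y == x * sigmoid(x) * y.
    # The jiterator kernel introduced in this commit fuses the same expression
    # into one elementwise pass, avoiding the intermediate silu(x) tensor.
    return F.silu(x) * y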
flash_attn/modules/mlp.py (view file @ 3557e0bb)

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
...
@@ -120,6 +126,9 @@ class GatedMlp(nn.Module):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
         else:
             y, gate = y.chunk(2, dim=-1)
             y = y * self.activation(gate)
...
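As a sanity check (not part of the commit), the new branch should agree with the generic y * self.activation(gate) path it bypasses when the activation is F.silu. A minimal sketch, assuming a CUDA device and a build where the swiglu import above succeeded; the tensor names are illustrative:

import torch
import torch.nn.functional as F
from flash_attn.ops.activations import swiglu

y = torch.randn(4, 256, device="cuda")     # stand-in for the fc1 output
y_half, gate = y.chunk(2, dim=-1)          # same split as GatedMlp.forward
out_fused = swiglu(gate, y_half)           # new SwiGLU fast path
out_ref = y_half * F.silu(gate)            # generic gated path
print(torch.allclose(out_fused, out_ref, atol=1e-5))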
flash_attn/ops/activations.py (view file @ 3557e0bb)

...
@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+
+swiglu = SwiGLUFunction.apply
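The backward code string is the product rule written out: with s = sigmoid(x), the forward computes x * s * y, so dx = g * y * s * (1 + x * (1 - s)) and dy = g * x * s for an incoming gradient g, which is exactly what swiglu_bwd evaluates. A minimal check against PyTorch autograd on the unfused expression (not part of the commit; assumes a CUDA device, since the jiterator kernels are CUDA-only):

import torch
from flash_attn.ops.activations import swiglu

x = torch.randn(8, 64, device="cuda", requires_grad=True)
y = torch.randn(8, 64, device="cuda", requires_grad=True)

# Gradients through the fused jiterator-backed op.
swiglu(x, y).sum().backward()

# Gradients through the unfused reference expression.
x_ref = x.detach().clone().requires_grad_()
y_ref = y.detach().clone().requires_grad_()
(x_ref * torch.sigmoid(x_ref) * y_ref).sum().backward()

print(torch.allclose(x.grad, x_ref.grad, atol=1e-5),
      torch.allclose(y.grad, y_ref.grad, atol=1e-5))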