Commit f90ec0e

Committed Mar 13, 2025
Add --use-flash-attention flag.
This is useful on AMD systems, as FlashAttention builds are still 10% faster than PyTorch cross-attention.
1 parent 35504e2 commit f90ec0e
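With this change, FlashAttention can be opted into at launch time. A minimal usage sketch, assuming the usual ComfyUI main.py entry point and that the flash-attn package is installed (the same pip command the error message added below prints):

pip install flash-attn
python main.py --use-flash-attention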

3 files changed (+64, -0 lines)

comfy/cli_args.py (+1)

@@ -106,6 +106,7 @@ class LatentPreviewMethod(enum.Enum):
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
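For reference, the new switch follows the same pattern as the existing attention flags: a store_true boolean that surfaces on the parsed args object as args.use_flash_attention. A minimal standalone sketch of that pattern (the mutually exclusive grouping is an assumption about how attn_group is defined earlier in cli_args.py; it is not shown in this diff):

import argparse

parser = argparse.ArgumentParser()
# Assumption: attn_group is a mutually exclusive group, so only one
# attention backend flag can be selected per run.
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")
attn_group.add_argument("--use-flash-attention", action="store_true")

args = parser.parse_args(["--use-flash-attention"])
print(args.use_flash_attention)           # True
print(args.use_pytorch_cross_attention)   # False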

comfy/ldm/modules/attention.py (+60)

@@ -24,6 +24,13 @@
         logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
         exit(-1)
 
+if model_management.flash_attention_enabled():
+    try:
+        from flash_attn import flash_attn_func
+    except ModuleNotFoundError:
+        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+        exit(-1)
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -496,6 +503,56 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
     return out
 
 
+@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                       dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+
+@flash_attn_wrapper.register_fake
+def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+    # Output shape is the same as q
+    return q.new_empty(q.shape)
+
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        assert mask is None
+        out = flash_attn_wrapper(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+            dropout_p=0.0,
+            causal=False,
+        ).transpose(1, 2)
+    except Exception as e:
+        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    if not skip_output_reshape:
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    return out
+
+
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
@@ -504,6 +561,9 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+    logging.info("Using Flash Attention")
+    optimized_attention = attention_flash
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch

comfy/model_management.py (+3)

@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def flash_attention_enabled():
+    return args.use_flash_attention
+
 def xformers_enabled():
     global directml_enabled
     global cpu_state
