src/transformers/integrations/flash_attention.py (8 changes: 5 additions & 3 deletions)
@@ -58,16 +58,18 @@ def flash_attention_forward(
     else:
         target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

-    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
+    # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is present
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = module.is_causal

     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
         query_length=seq_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
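The net effect of this hunk is a precedence change: an is_causal value supplied by the caller through kwargs now wins, and module.is_causal is only used as the fallback when nothing was passed. A minimal sketch of that precedence logic in isolation is below; the helper name resolve_is_causal and the DummyAttention class are illustrative stand-ins, not part of the transformers codebase.

# Minimal sketch of the new precedence: an explicit `is_causal` kwarg wins,
# otherwise fall back to the attribute stored on the attention module.
def resolve_is_causal(module, **kwargs):
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal
    return is_causal


class DummyAttention:
    # Stands in for an attention module that fixes `is_causal` at init time.
    is_causal = True


# A caller override takes precedence over the module default.
assert resolve_is_causal(DummyAttention(), is_causal=False) is False
# With no override, the module attribute is used, matching the old behaviour.
assert resolve_is_causal(DummyAttention()) is True

This keeps the old behaviour for call sites that never pass is_causal, while letting a caller disable (or enable) causal masking on a per-call basis instead of being locked to the value baked into the module.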