2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 3494 files
13 changes: 3 additions & 10 deletions setup.py
@@ -61,7 +61,7 @@
"timeout-decorator",
"torch",
"torchvision",
"transformers~=4.52.4",
"transformers~=4.56.0",
]
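
For reference, `~=4.56.0` is a PEP 440 compatible-release specifier, so the pin tracks 4.56.x patch releases but excludes 4.57. A quick, illustrative check using the `packaging` library that setup.py already lists among its deps:

```python
# Illustration only: what the "~=4.56.0" compatible-release pin accepts (PEP 440).
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=4.56.0")  # equivalent to ">=4.56.0, ==4.56.*"
print("4.56.2" in spec)  # True  -> patch releases stay compatible
print("4.57.0" in spec)  # False -> minor bumps are excluded by this pin
```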


@@ -71,12 +71,7 @@
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {
b: a
for a, b in (
re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps
)
}
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
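
The reflowed comprehension keeps the previous logic: the regex splits each requirement string into the full specifier and the bare package name, and the dict maps name to specifier. A minimal sketch with a hypothetical two-entry `_deps` list (not the real dependency list):

```python
# Minimal sketch of what the comprehension produces, with illustrative _deps entries.
import re

_deps = ["torch", "transformers~=4.56.0"]  # hypothetical subset for illustration
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
print(deps)  # {'torch': 'torch', 'transformers': 'transformers~=4.56.0'}
```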


def deps_list(*pkgs):
@@ -114,9 +109,7 @@ def deps_list(*pkgs):
"torchvision",
)

extras["quality"] = deps_list(
"black", "datasets", "isort", "flake8", "GitPython"
)
extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython")

extras["docs"] = deps_list(
"docutils",
184 changes: 95 additions & 89 deletions src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
@@ -28,6 +28,8 @@
XLMRobertaSelfAttention,
XLMRobertaSelfOutput,
)
from transformers.cache_utils import Cache, EncoderDecoderCache
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
@@ -39,61 +41,66 @@


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta
class XLMRobertaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, XLMRobertaSelfAttention):
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
attention_mask = prefix_attention_mask(attention_mask) # type: ignore

mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
cache_position: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor]:
batch_size, seq_length, _ = hidden_states.shape
query_layer = self.query(hidden_states)
query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
1, 2
)
is_cross_attention = encoder_hidden_states is not None
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
is_updated = past_key_values.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_layer from cache
curr_past_key_value = past_key_values.cross_attention_cache
else:
curr_past_key_value = past_key_values.self_attention_cache
else:
curr_past_key_value = past_key_values

if is_cross_attention and past_key_value is not None:
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
key_layer = curr_past_key_value.layers[self.layer_idx].keys
value_layer = curr_past_key_value.layers[self.layer_idx].values
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = self.key(current_states)
key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
1, 2
)
value_layer = self.value(current_states)
value_layer = value_layer.view(
batch_size, -1, self.num_attention_heads, self.attention_head_size
).transpose(1, 2)

if past_key_values is not None:
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_layer, value_layer = curr_past_key_value.update(
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_values.is_updated[self.layer_idx] = True

query_layer = self.transpose_for_scores(mixed_query_layer)
# >>> START AH Changes <<<
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
(attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
# >>> END AH Changes <<<

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# TODO - what to do with this?
# >>> START AH Changes <<<
key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask
@@ -106,7 +113,7 @@ def forward(

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
if past_key_values is not None:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
@@ -148,23 +155,20 @@ def forward(
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return context_layer, attention_probs
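
The rewritten forward targets the `transformers.cache_utils` interface instead of tuple-based `past_key_value`. The sketch below is an illustration only (shapes and `layer_idx` are made up, and it is not taken from this PR's tests); it shows how a caller-side `EncoderDecoderCache` maps onto the sub-cache selection and `update()` call used above:

```python
# Hedged sketch of the cache plumbing the new forward expects (transformers >= 4.56).
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Hypothetical shapes: batch=2, heads=12, seq=5, head_dim=64.
keys = torch.randn(2, 12, 5, 64)
values = torch.randn(2, 12, 5, 64)

# What the layer does internally for self-attention at layer_idx=0:
self_cache = past_key_values.self_attention_cache
cached_k, cached_v = self_cache.update(keys, values, layer_idx=0, cache_kwargs={"cache_position": None})
print(cached_k.shape)  # torch.Size([2, 12, 5, 64]) on the first update

# Cross-attention keys/values are computed once and re-used on later calls,
# gated by the is_updated flag that the diff sets after the first write.
print(past_key_values.is_updated.get(0, False))  # False until the cross-attention cache is written
```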


class XLMRobertaSdpaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, XLMRobertaSdpaSelfAttention):
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
# >>> START AH Changes <<<
attention_mask = prefix_attention_mask(attention_mask, [2, 3]) # type: ignore
@@ -184,46 +188,67 @@ def forward(
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
past_key_values,
output_attentions,
cache_position,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))
query_layer = (
self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
)

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
is_updated = past_key_values.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
curr_past_key_value = past_key_values.cross_attention_cache
else:
curr_past_key_value = past_key_values.self_attention_cache
else:
curr_past_key_value = past_key_values

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = curr_past_key_value.layers[self.layer_idx].keys
value_layer = curr_past_key_value.layers[self.layer_idx].values
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
key_layer = (
self.key(current_states)
.view(bsz, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(current_states)
.view(bsz, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)

if past_key_values is not None:
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_layer, value_layer = curr_past_key_value.update(
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_values.is_updated[self.layer_idx] = True

# >>> START AH Changes <<<
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
(attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
# >>> END AH Changes <<<

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1

# >>> START AH Changes <<<
key_layer, value_layer, attention_mask = self.prefix_tuning(
@@ -233,22 +258,6 @@ def forward(
bsz = query_layer.size(0)
# >>> END AH Changes <<<

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
@@ -261,10 +270,7 @@ def forward(
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return attn_output, None
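
Both forwards are wrapped in `@deprecate_kwarg`, which keeps the old `past_key_value` keyword working (with a deprecation warning) until transformers 4.58. A hedged sketch of that behavior, using a hypothetical function named `report`:

```python
# Hedged sketch of what @deprecate_kwarg does for callers that still pass the old name.
# `report` is a hypothetical function used only for illustration.
from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def report(past_key_values=None):
    return past_key_values

print(report(past_key_value="cached"))   # "cached" -- old keyword remapped, FutureWarning emitted
print(report(past_key_values="cached"))  # "cached" -- new keyword, no warning
```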


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta