Skip to content

Commit 01823d7

Browse files
authored
[CI] Fix copies (#42487)
copies
1 parent e51e75e commit 01823d7

File tree

1 file changed: +3 additions, -0 deletions

src/transformers/models/nanochat/modeling_nanochat.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
30 30   from ...activations import ACT2FN
31 31   from ...cache_utils import Cache, DynamicCache
32 32   from ...generation import GenerationMixin
   33 + from ...integrations import use_kernel_func_from_hub
33 34   from ...masking_utils import create_causal_mask
34 35   from ...modeling_layers import GradientCheckpointingLayer
35 36   from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -121,6 +122,7 @@ def forward(self, x, position_ids):
121 122           return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
122 123
123 124
    125 + @use_kernel_func_from_hub("rotary_pos_emb")
124 126   def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
125 127       """Applies Rotary Position Embedding to the query and key tensors.
126 128
@@ -218,6 +220,7 @@ def __init__(self, config: NanoChatConfig, layer_idx: int):
218 220           self.o_proj = nn.Linear(
219 221               config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
220 222           )
    223 +         self.rotary_fn = apply_rotary_pos_emb
221 224
222 225           self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
223 226           self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

0 commit comments

Comments
 (0)