Commit 9394793

fixes ci checks
1 parent ff143bc commit 9394793

2 files changed: 7 additions, 4 deletions

src/transformers/models/jais2/configuration_jais2.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,6 @@
 from ...configuration_utils import PretrainedConfig
+
+
 # from ...modeling_rope_utils import rope_config_validation, standardize_rope_params


@@ -171,7 +173,6 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
-
         # Set up rope_parameters from rope_scaling
         self.rope_parameters = rope_scaling

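Note (not part of the diff): below is a minimal sketch of what the configuration change means in practice, assuming Jais2Config keeps a constructor similar to other transformers decoder configs. The attribute names come from the hunk above; everything else (the values, the shape of the rope_scaling dict) is illustrative. Whatever dict is passed as rope_scaling is now stored verbatim on config.rope_parameters, with rope_config_validation still disabled.

# Sketch only; assumes the jais2 model is registered in the local transformers checkout.
from transformers.models.jais2.configuration_jais2 import Jais2Config

config = Jais2Config(
    hidden_size=1024,                                      # illustrative values
    num_attention_heads=16,
    rope_scaling={"rope_type": "linear", "factor": 2.0},   # hypothetical rope_scaling dict
)
print(config.head_dim)         # 64, i.e. hidden_size // num_attention_heads per the hunk above
print(config.rope_parameters)  # the rope_scaling dict, stored as-is
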

src/transformers/models/jais2/modeling_jais2.py

Lines changed: 5 additions & 3 deletions
@@ -918,6 +918,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (

@@ -1151,7 +1152,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

 class Jais2ForCausalLM(Jais2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
-    # _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)

@@ -1321,6 +1321,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1366,7 +1367,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config, **kwargs)

         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]

@@ -1516,6 +1517,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1542,7 +1544,7 @@

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, config=self.config)
+            loss = self.loss_function(logits=logits, labels=labels, config=self.config, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[2:]
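
Note (not part of the diff): the recurring change in this file is that forward() now accepts **kwargs and threads them into self.loss_function, so extra keyword arguments supplied by generic callers (for example the Trainer or common CI tests passing something like num_items_in_batch) no longer fail the call and can reach the loss computation. The standalone sketch below illustrates that pattern with toy functions; the names and the loss itself are assumptions, not the actual Jais2 implementation.

# Sketch only: toy stand-ins showing why surplus kwargs are threaded through to the loss.
import torch
import torch.nn.functional as F

def toy_loss_function(logits, labels, config=None, num_items_in_batch=None, **kwargs):
    # Token-level cross-entropy summed over the batch, optionally normalized by a
    # caller-supplied token count (the kind of kwarg a trainer might forward).
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum")
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

def toy_forward(logits, labels=None, **kwargs):
    # Mirrors the diff: anything extra the caller passes is handed to the loss unchanged.
    return toy_loss_function(logits, labels, **kwargs) if labels is not None else None

logits = torch.randn(2, 5, 32)          # (batch, sequence, vocab)
labels = torch.randint(0, 32, (2, 5))
print(toy_forward(logits, labels, num_items_in_batch=10))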
