@@ -918,6 +918,7 @@ def forward(
918918 output_hidden_states : Optional [bool ] = None ,
919919 return_dict : Optional [bool ] = None ,
920920 cache_position : Optional [torch .LongTensor ] = None ,
921+ ** kwargs ,
921922 ) -> Union [tuple , BaseModelOutputWithPast ]:
922923 output_attentions = output_attentions if output_attentions is not None else self .config .output_attentions
923924 output_hidden_states = (
@@ -1151,7 +1152,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
11511152
11521153class Jais2ForCausalLM (Jais2PreTrainedModel , GenerationMixin ):
11531154 _tied_weights_keys = {"lm_head.weight" : "model.embed_tokens.weight" }
1154- # _tied_weights_keys = ["lm_head.weight"]
11551155
11561156 def __init__ (self , config ):
11571157 super ().__init__ (config )
@@ -1321,6 +1321,7 @@ def forward(
13211321 output_attentions : Optional [bool ] = None ,
13221322 output_hidden_states : Optional [bool ] = None ,
13231323 return_dict : Optional [bool ] = None ,
1324+ ** kwargs ,
13241325 ) -> Union [tuple , SequenceClassifierOutputWithPast ]:
13251326 r"""
13261327 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1366,7 +1367,7 @@ def forward(
13661367
13671368 loss = None
13681369 if labels is not None :
1369- loss = self .loss_function (logits = logits , labels = labels , pooled_logits = pooled_logits , config = self .config )
1370+ loss = self .loss_function (logits = logits , labels = labels , pooled_logits = pooled_logits , config = self .config , ** kwargs )
13701371
13711372 if not return_dict :
13721373 output = (pooled_logits ,) + transformer_outputs [1 :]
@@ -1516,6 +1517,7 @@ def forward(
15161517 output_attentions : Optional [bool ] = None ,
15171518 output_hidden_states : Optional [bool ] = None ,
15181519 return_dict : Optional [bool ] = None ,
1520+ ** kwargs ,
15191521 ) -> Union [tuple , TokenClassifierOutput ]:
15201522 r"""
15211523 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1542,7 +1544,7 @@ def forward(
15421544
15431545 loss = None
15441546 if labels is not None :
1545- loss = self .loss_function (logits = logits , labels = labels , config = self .config )
1547+ loss = self .loss_function (logits = logits , labels = labels , config = self .config , ** kwargs )
15461548
15471549 if not return_dict :
15481550 output = (logits ,) + outputs [2 :]
0 commit comments