System Info
I'm using peft.LoraConfig() to fine-tune ModernBertForTokenClassification. According to the traceback of the exception I'm getting, the forward() method of ModernBertForTokenClassification is missing **kwargs:
│ /home/gsgs2tk/ADA_ModelingFramework/.venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:222 in forward │
│ │
│ 219 │ │ return self.active_adapter │
│ 220 │ │
│ 221 │ def forward(self, *args: Any, **kwargs: Any): │
│ ❱ 222 │ │ return self.model.forward(*args, **kwargs) │
│ 223 │ │
│ 224 │ def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: st │
│ 225 │ │ r""" │
│ │
│ ╭────────────────────────────────────────── locals ───────────────────────────────────────────╮ │
│ │ args = () │ │
│ │ kwargs = { │ │
│ │ │ 'input_ids': tensor([[ 102, 1170, 853, ..., 0, 0, 0], │ │
│ │ │ │ [ 102, 16567, 1774, ..., 0, 0, 0]], device='cuda:0'), │ │
│ │ │ 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1], │ │
│ │ │ │ [1, 1, 1, ..., 1, 1, 1]], device='cuda:0'), │ │
│ │ │ 'inputs_embeds': None, │ │
│ │ │ 'labels': tensor([[-100, 0, 0, ..., -100, -100, -100], │ │
│ │ │ │ [-100, 0, 0, ..., -100, -100, -100]], device='cuda:0'), │ │
│ │ │ 'output_attentions': None, │ │
│ │ │ 'output_hidden_states': None, │ │
│ │ │ 'return_dict': True, │ │
│ │ │ 'use_cache': False │ │
│ │ } │ │
│ │ self = LoraModel( │ │
│ │ (model): ModernBertForTokenClassification( │ │
│ │ │ (model): ModernBertModel( │ │
│ │ │ (embeddings): ModernBertEmbeddings( │ │
│ │ │ │ (tok_embeddings): Embedding(31103, 768, padding_idx=0) │ │
│ │ │ │ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ │ (drop): Dropout(p=0.0, inplace=False) │ │
│ │ │ ) │ │
│ │ │ (layers): ModuleList( │ │
│ │ │ │ (0): ModernBertEncoderLayer( │ │
│ │ │ │ (attn_norm): Identity() │ │
│ │ │ │ (attn): ModernBertAttention( │ │
│ │ │ │ │ (Wqkv): lora.Linear( │ │
│ │ │ │ │ (base_layer): Linear(in_features=768, out_features=2304, bias=False) │ │
│ │ │ │ │ (lora_dropout): ModuleDict( │ │
│ │ │ │ │ │ (default): Dropout(p=0.1, inplace=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_A): ModuleDict( │ │
│ │ │ │ │ │ (default): Linear(in_features=768, out_features=8, bias=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_B): ModuleDict( │ │
│ │ │ │ │ │ (default): Linear(in_features=8, out_features=2304, bias=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_embedding_A): ParameterDict() │ │
│ │ │ │ │ (lora_embedding_B): ParameterDict() │ │
│ │ │ │ │ (lora_magnitude_vector): ModuleDict() │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (rotary_emb): ModernBertRotaryEmbedding() │ │
│ │ │ │ │ (Wo): Linear(in_features=768, out_features=768, bias=False) │ │
│ │ │ │ │ (out_drop): Identity() │ │
│ │ │ │ ) │ │
│ │ │ │ (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ │ (mlp): ModernBertMLP( │ │
│ │ │ │ │ (Wi): Linear(in_features=768, out_features=2304, bias=False) │ │
│ │ │ │ │ (act): GELUActivation() │ │
│ │ │ │ │ (drop): Dropout(p=0.0, inplace=False) │ │
│ │ │ │ │ (Wo): Linear(in_features=1152, out_features=768, bias=False) │ │
│ │ │ │ ) │ │
│ │ │ │ ) │ │
│ │ │ │ (1-21): 21 x ModernBertEncoderLayer( │ │
│ │ │ │ (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ │ (attn): ModernBertAttention( │ │
│ │ │ │ │ (Wqkv): lora.Linear( │ │
│ │ │ │ │ (base_layer): Linear(in_features=768, out_features=2304, bias=False) │ │
│ │ │ │ │ (lora_dropout): ModuleDict( │ │
│ │ │ │ │ │ (default): Dropout(p=0.1, inplace=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_A): ModuleDict( │ │
│ │ │ │ │ │ (default): Linear(in_features=768, out_features=8, bias=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_B): ModuleDict( │ │
│ │ │ │ │ │ (default): Linear(in_features=8, out_features=2304, bias=False) │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (lora_embedding_A): ParameterDict() │ │
│ │ │ │ │ (lora_embedding_B): ParameterDict() │ │
│ │ │ │ │ (lora_magnitude_vector): ModuleDict() │ │
│ │ │ │ │ ) │ │
│ │ │ │ │ (rotary_emb): ModernBertRotaryEmbedding() │ │
│ │ │ │ │ (Wo): Linear(in_features=768, out_features=768, bias=False) │ │
│ │ │ │ │ (out_drop): Identity() │ │
│ │ │ │ ) │ │
│ │ │ │ (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ │ (mlp): ModernBertMLP( │ │
│ │ │ │ │ (Wi): Linear(in_features=768, out_features=2304, bias=False) │ │
│ │ │ │ │ (act): GELUActivation() │ │
│ │ │ │ │ (drop): Dropout(p=0.0, inplace=False) │ │
│ │ │ │ │ (Wo): Linear(in_features=1152, out_features=768, bias=False) │ │
│ │ │ │ ) │ │
│ │ │ │ ) │ │
│ │ │ ) │ │
│ │ │ (final_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ ) │ │
│ │ │ (head): ModernBertPredictionHead( │ │
│ │ │ (dense): Linear(in_features=768, out_features=768, bias=False) │ │
│ │ │ (act): GELUActivation() │ │
│ │ │ (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) │ │
│ │ │ ) │ │
│ │ │ (drop): Dropout(p=0.0, inplace=False) │ │
│ │ │ (classifier): ModulesToSaveWrapper( │ │
│ │ │ (original_module): Linear(in_features=768, out_features=13, bias=True) │ │
│ │ │ (modules_to_save): ModuleDict( │ │
│ │ │ │ (default): Linear(in_features=768, out_features=13, bias=True) │ │
│ │ │ ) │ │
│ │ │ ) │ │
│ │ ) │ │
│ │ ) │ │
│ ╰─────────────────────────────────────────────────────────────────────────────────────────────╯ │
╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: ModernBertForTokenClassification.forward() got an unexpected keyword argument 'use_cache'
Fix:
Just add **kwargs to the method signature:
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        ...  # Rest of the code unchanged
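Until the signature is changed upstream, a workaround I can think of (my own sketch, not part of the official API) is to subclass the model and drop keyword arguments its forward() does not accept before delegating:

# Workaround sketch: discard kwargs that ModernBertForTokenClassification.forward()
# does not accept (such as use_cache) before calling the parent implementation.
from transformers import ModernBertForTokenClassification

class PatchedModernBertForTokenClassification(ModernBertForTokenClassification):
    def forward(self, *args, **kwargs):
        kwargs.pop("use_cache", None)  # silently ignore the unsupported argument
        return super().forward(*args, **kwargs)

Wrapping an instance of this subclass with get_peft_model() instead of the stock class should avoid the TypeError, at the cost of ignoring use_cache entirely.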
Output of my transformers env:
- `transformers` version: 4.57.1
- Platform: Linux-6.8.0-1039-aws-x86_64-with-glibc2.35
- Python version: 3.11.0rc1
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: 1.11.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.9.0+cu128 (CUDA)
- Tensorflow version (GPU?): 2.20.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA A10G
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I followed this blog post, https://www.mohammedsbaihi.com/blog/modernbert.html, but applied it to ModernBertForTokenClassification; a rough sketch of the setup is below.
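Roughly, the setup looks like this (the checkpoint name, label count, and LoRA hyperparameters are illustrative placeholders, not copied verbatim from my script):

# Illustrative reproduction sketch; "answerdotai/ModernBERT-base" and the LoRA
# hyperparameters are placeholders rather than the exact values from my script.
import torch
from transformers import AutoTokenizer, ModernBertForTokenClassification
from peft import LoraConfig, get_peft_model

model = ModernBertForTokenClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels=13
)
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["Wqkv"],
    modules_to_save=["classifier"],
)
model = get_peft_model(model, lora_config)

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
batch = tokenizer(["Example sentence for NER"], return_tensors="pt")
batch["labels"] = torch.zeros_like(batch["input_ids"])

# The PEFT wrapper forwards every keyword argument (including use_cache=False, which the
# training loop injects) straight to ModernBertForTokenClassification.forward(), so this
# raises: TypeError: ... got an unexpected keyword argument 'use_cache'
outputs = model(**batch, use_cache=False)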
Expected behavior
Code should run without the aforementioned exception.