The forward() Method in ModernBertForTokenClassification is missing **kwargs #41911

@SinaDBMS

Description

System Info

I'm using peft.LoraConfig() to fine-tune ModernBertForTokenClassification. According to the traceback I'm getting, the forward() method of ModernBertForTokenClassification does not accept **kwargs:

/home/gsgs2tk/ADA_ModelingFramework/.venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py:222 in forward

   219 │   │   return self.active_adapter
   220 │
   221 │   def forward(self, *args: Any, **kwargs: Any):
❱  222 │   │   return self.model.forward(*args, **kwargs)
   223 │
   224 │   def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: st
   225 │   │   r"""

╭────────────────────────────────────────── locals ───────────────────────────────────────────╮
│   args = ()                                                                                 │
│ kwargs = {                                                                                  │
│          │   'input_ids': tensor([[  102,  1170,   853,  ...,     0,     0,     0],         │
│          │   │   [  102, 16567,  1774,  ...,     0,     0,     0]], device='cuda:0'),       │
│          │   'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],                            │
│          │   │   [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'),                               │
│          │   'inputs_embeds': None,                                                         │
│          │   'labels': tensor([[-100,    0,    0,  ..., -100, -100, -100],                  │
│          │   │   [-100,    0,    0,  ..., -100, -100, -100]], device='cuda:0'),             │
│          │   'output_attentions': None,                                                     │
│          │   'output_hidden_states': None,                                                  │
│          │   'return_dict': True,                                                           │
│          │   'use_cache': False                                                             │
│          }                                                                                  │
│   self = LoraModel(                                                                         │
│            (model): ModernBertForTokenClassification(                                       │
│          │   (model): ModernBertModel(                                                      │
│          │     (embeddings): ModernBertEmbeddings(                                          │
│          │   │   (tok_embeddings): Embedding(31103, 768, padding_idx=0)                     │
│          │   │   (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)              │
│          │   │   (drop): Dropout(p=0.0, inplace=False)                                      │
│          │     )                                                                            │
│          │     (layers): ModuleList(                                                        │
│          │   │   (0): ModernBertEncoderLayer(                                               │
│          │   │     (attn_norm): Identity()                                                  │
│          │   │     (attn): ModernBertAttention(                                             │
│          │   │   │   (Wqkv): lora.Linear(                                                   │
│          │   │   │     (base_layer): Linear(in_features=768, out_features=2304, bias=False) │
│          │   │   │     (lora_dropout): ModuleDict(                                          │
│          │   │   │   │   (default): Dropout(p=0.1, inplace=False)                           │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_A): ModuleDict(                                                │
│          │   │   │   │   (default): Linear(in_features=768, out_features=8, bias=False)     │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_B): ModuleDict(                                                │
│          │   │   │   │   (default): Linear(in_features=8, out_features=2304, bias=False)    │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_embedding_A): ParameterDict()                                  │
│          │   │   │     (lora_embedding_B): ParameterDict()                                  │
│          │   │   │     (lora_magnitude_vector): ModuleDict()                                │
│          │   │   │   )                                                                      │
│          │   │   │   (rotary_emb): ModernBertRotaryEmbedding()                              │
│          │   │   │   (Wo): Linear(in_features=768, out_features=768, bias=False)            │
│          │   │   │   (out_drop): Identity()                                                 │
│          │   │     )                                                                        │
│          │   │     (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)        │
│          │   │     (mlp): ModernBertMLP(                                                    │
│          │   │   │   (Wi): Linear(in_features=768, out_features=2304, bias=False)           │
│          │   │   │   (act): GELUActivation()                                                │
│          │   │   │   (drop): Dropout(p=0.0, inplace=False)                                  │
│          │   │   │   (Wo): Linear(in_features=1152, out_features=768, bias=False)           │
│          │   │     )                                                                        │
│          │   │   )                                                                          │
│          │   │   (1-21): 21 x ModernBertEncoderLayer(                                       │
│          │   │     (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)       │
│          │   │     (attn): ModernBertAttention(                                             │
│          │   │   │   (Wqkv): lora.Linear(                                                   │
│          │   │   │     (base_layer): Linear(in_features=768, out_features=2304, bias=False) │
│          │   │   │     (lora_dropout): ModuleDict(                                          │
│          │   │   │   │   (default): Dropout(p=0.1, inplace=False)                           │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_A): ModuleDict(                                                │
│          │   │   │   │   (default): Linear(in_features=768, out_features=8, bias=False)     │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_B): ModuleDict(                                                │
│          │   │   │   │   (default): Linear(in_features=8, out_features=2304, bias=False)    │
│          │   │   │     )                                                                    │
│          │   │   │     (lora_embedding_A): ParameterDict()                                  │
│          │   │   │     (lora_embedding_B): ParameterDict()                                  │
│          │   │   │     (lora_magnitude_vector): ModuleDict()                                │
│          │   │   │   )                                                                      │
│          │   │   │   (rotary_emb): ModernBertRotaryEmbedding()                              │
│          │   │   │   (Wo): Linear(in_features=768, out_features=768, bias=False)            │
│          │   │   │   (out_drop): Identity()                                                 │
│          │   │     )                                                                        │
│          │   │     (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)        │
│          │   │     (mlp): ModernBertMLP(                                                    │
│          │   │   │   (Wi): Linear(in_features=768, out_features=2304, bias=False)           │
│          │   │   │   (act): GELUActivation()                                                │
│          │   │   │   (drop): Dropout(p=0.0, inplace=False)                                  │
│          │   │   │   (Wo): Linear(in_features=1152, out_features=768, bias=False)           │
│          │   │     )                                                                        │
│          │   │   )                                                                          │
│          │     )                                                                            │
│          │     (final_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)          │
│          │   )                                                                              │
│          │   (head): ModernBertPredictionHead(                                              │
│          │     (dense): Linear(in_features=768, out_features=768, bias=False)               │
│          │     (act): GELUActivation()                                                      │
│          │     (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)                │
│          │   )                                                                              │
│          │   (drop): Dropout(p=0.0, inplace=False)                                          │
│          │   (classifier): ModulesToSaveWrapper(                                            │
│          │     (original_module): Linear(in_features=768, out_features=13, bias=True)       │
│          │     (modules_to_save): ModuleDict(                                               │
│          │   │   (default): Linear(in_features=768, out_features=13, bias=True)             │
│          │     )                                                                            │
│          │   )                                                                              │
│            )                                                                                │
│          )                                                                                  │
╰─────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: ModernBertForTokenClassification.forward() got an unexpected keyword argument 'use_cache'
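The mechanism behind this TypeError can be reproduced without transformers or peft. In this minimal sketch, `StrictModel` and `PassThroughWrapper` are hypothetical stand-ins, not real library classes: the wrapper forwards every keyword argument unchanged, like the peft `forward` frame at tuners_utils.py:222 above, and the model's `forward()` has a fixed parameter list with no `**kwargs`, which is how ModernBertForTokenClassification.forward() behaves in transformers 4.57.1 according to this report.

```python
from typing import Any, Optional


class StrictModel:
    """Hypothetical stand-in for a forward() with a fixed parameter list (no **kwargs)."""

    def forward(
        self,
        input_ids: Optional[list] = None,
        attention_mask: Optional[list] = None,
        labels: Optional[list] = None,
        return_dict: Optional[bool] = None,
    ) -> dict:
        return {"input_ids": input_ids, "labels": labels}


class PassThroughWrapper:
    """Hypothetical stand-in for the peft forward() frame shown in the traceback."""

    def __init__(self, model: Any) -> None:
        self.model = model

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Every keyword argument is forwarded unchanged, including use_cache.
        return self.model.forward(*args, **kwargs)


def try_forward(model: Any, **batch: Any) -> str:
    """Return "ok" if the wrapped call succeeds, otherwise the TypeError message."""
    try:
        PassThroughWrapper(model).forward(**batch)
        return "ok"
    except TypeError as exc:
        return str(exc)


batch = {"input_ids": [102, 1170], "labels": [-100, 0], "return_dict": True}
print(try_forward(StrictModel(), **batch))                   # ok
print(try_forward(StrictModel(), **batch, use_cache=False))  # ... unexpected keyword argument 'use_cache'
```

On Python 3.10 and later the second message reads `StrictModel.forward() got an unexpected keyword argument 'use_cache'`, the same shape as the error reported above.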

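Suggested fix:

Add **kwargs to the signature of ModernBertForTokenClassification.forward(), so that extra keyword arguments such as use_cache are accepted instead of raising a TypeError: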

@auto_docstring(
    custom_intro="""
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # new: absorbs extra keyword arguments such as use_cache
    ):
        # Rest of the code...
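A sketch of why the proposed **kwargs change resolves the error, again using hypothetical stand-ins rather than the real transformers classes: `TolerantModel.forward()` collects the extra `use_cache` argument in `kwargs` and ignores it. For call sites you control, the hypothetical helper `drop_unsupported_kwargs` shows a caller-side alternative (a swapped-in technique, not part of transformers or peft): it uses `inspect.signature` to drop keywords that the target `forward()` does not declare.

```python
import inspect
from typing import Any, Callable, Dict, Optional


class TolerantModel:
    """Hypothetical stand-in for the patched forward() that accepts **kwargs."""

    def forward(
        self,
        input_ids: Optional[list] = None,
        labels: Optional[list] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Unrecognised keywords such as use_cache end up here and are ignored.
        return {"ignored": sorted(kwargs)}


class StrictModel:
    """Hypothetical stand-in for the unpatched forward() without **kwargs."""

    def forward(
        self,
        input_ids: Optional[list] = None,
        labels: Optional[list] = None,
        return_dict: Optional[bool] = None,
    ) -> Dict[str, Any]:
        return {"ignored": []}


def drop_unsupported_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only the keywords that fn's signature accepts."""
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)  # fn already takes **kwargs, so nothing needs dropping
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


batch = {"input_ids": [102, 1170], "labels": [-100, 0], "return_dict": True, "use_cache": False}

print(TolerantModel().forward(**batch))  # {'ignored': ['use_cache']}

strict = StrictModel()
print(strict.forward(**drop_unsupported_kwargs(strict.forward, batch)))  # {'ignored': []}
```

Changing the model signature, as proposed, is the more robust option here: the peft `forward()` in the traceback passes on whatever keyword arguments it receives, so filtering only helps where the call site is under your own control.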

Output of `transformers env`:

- `transformers` version: 4.57.1
- Platform: Linux-6.8.0-1039-aws-x86_64-with-glibc2.35
- Python version: 3.11.0rc1
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: 1.11.0
- Accelerate config:    not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.9.0+cu128 (CUDA)
- Tensorflow version (GPU?): 2.20.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA A10G

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Follow the steps in this blog post https://www.mohammedsbaihi.com/blog/modernbert.html, but with ModernBertForTokenClassification.

Expected behavior

The code should run without raising the TypeError above.
