System Info
- transformers version: 4.57.1
- Platform: Linux-6.8.0-1043-aws-x86_64-with-glibc2.35
- Python version: 3.11.10
- Huggingface_hub version: 0.34.3
- Safetensors version: 0.4.3
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA L40S
Who can help?
Is Gemma3ForConditionalGeneration expected to work with torch.compile? I am trying to speed up inference by following this page, and I'm getting errors. Here is the script I'm using.
I'm not sure if I'm doing anything obviously wrong. Thanks!
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Running this Python script:
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it", dtype=torch.bfloat16, device_map="auto"
)
model.generation_config.cache_implementation = "static"  # static KV cache, needed for fullgraph compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",
    padding_side="left",
)

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100, cache_implementation="static")
print(processor.tokenizer.batch_decode(outputs, skip_special_tokens=True))

I get the following error:
skipping cudagraphs due to mutated inputs (68 instances). Found from :
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1100, in forward
outputs = self.model(
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/utils/generic.py", line 918, in wrapper
output = func(self, *args, **kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 957, in forward
outputs = self.language_model(
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 570, in forward
layer_outputs = decoder_layer(
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 382, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 321, in forward
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/cache_utils.py", line 776, in update
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
File "/home/coder/work/canva/tools/build/python/third_party/.venv/lib/python3.11/site-packages/transformers/cache_utils.py", line 443, in update
self.keys.index_copy_(2, cache_position, key_states)
[1] 359729 segmentation fault (core dumped) ipython
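For context, the torch.compile call is the only line I've changed relative to plain generation. A minimal variant of that line which I would expect to avoid the CUDA-graph path entirely, under the unverified assumption that the "mutated inputs" warnings come from the static cache's in-place index_copy_ updates (which reduce-overhead tries to capture in CUDA graphs):

# Hypothetical variant (untested): compile without mode="reduce-overhead" so
# Inductor does not attempt CUDA-graph capture of the in-place cache updates.
model.forward = torch.compile(model.forward, fullgraph=True)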
Expected behavior
I expected generation to run without errors, and certainly without a segmentation fault.
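If warm-up passes are required before the compiled forward is usable, I'd still expect them not to crash. A sketch of the warm-up I would add (assumption on my part, mirroring the iteration loop in the docs' compile recipe):

# Hypothetical warm-up (assumption): the first generate() calls trigger
# compilation and graph capture; later calls should hit the compiled path.
for _ in range(2):
    _ = model.generate(**inputs, max_new_tokens=100, cache_implementation="static")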