
[Bug]: InternVL-8B quantize failed with error: AttributeError: 'dict' object has no attribute 'last_hidden_state' #2056

@BigFaceBoy

Description

⚙️ Your current environment

The output of `python collect_env.py`:
Operating System: `Linux-5.4.0-144-generic-x86_64-with-glibc2.35`
Python Version: `3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]`
llm-compressor Version: `0.8.1`
compressed-tensors Version: `0.12.2`
transformers Version: `4.57.1`
torch Version: `2.8.0`
CUDA Devices: `['NVIDIA L40']`
AMD Devices: `None`

🐛 Describe the bug

I am trying to quantize InternVL3-8B. My steps:

  1. Download `chat_template.jinja` from InternVL3_5-8B and place it in the local directory of InternVL3-8B.
  2. Replace `int()` in `modeling_internvl_chat.py` with `math.floor()` (see the sketch after this list).
  3. Run the script below.
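
A minimal sketch of the substitution in step 2, assuming the remote-code file casts a scaled dimension with `int()`; the variable names and values here are illustrative, not the actual contents of `modeling_internvl_chat.py`:

```python
import math

scale_factor = 0.5   # illustrative value only
h = 31               # illustrative feature-map size

# Before (roughly what step 2 replaces): h = int(h * scale_factor)
# After: math.floor() gives the same rounding-down behaviour for positive values
h = math.floor(h * scale_factor)
print(h)  # 15
```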
```python
import torch
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as T
from PIL import Image
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
# Load model.
model_id = "/root/workspace/models/InternVL3-8B"
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

# Load datasets
DATASET_ID = "/root/workspace/datasets/flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
#### end of code copied from https://hf-mirror.com/OpenGVLab/InternVL3-8B#inference-with-transformers ####


def load_image_from_PIL(image_obj, input_size=448, max_num=12):
    image = image_obj.convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

def preprocess_and_tokenize(example):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image", 
                    "image": ""
                },
                {
                    "type": "text", 
                    "text":  "What does the image show?"
                },
            ],
        }
    ]
    text = tokenizer.apply_chat_template(messages)
    image_flags = torch.ones(1, dtype=torch.long)

    example["input_ids"] = text
    example['image_flags'] = image_flags
    return example

ds = ds.map(preprocess_and_tokenize)

def data_collator(batch):
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    item["pixel_values"] = load_image_from_PIL(item["image"]).to(torch.bfloat16)
    item["input_ids"] = torch.LongTensor([item["input_ids"]])
    item["labels"] = item["input_ids"].clone()
    item["image_flags"] = torch.LongTensor([item["image_flags"]])
    return item



# Recipe
recipe = GPTQModifier(
        targets="Linear", 
        scheme="FP8", 
        ignore=["re:.*lm_head","re:mlp1.*", "re:.*vision_model.*"]
    )


# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator
)


# Save to disk compressed.
SAVE_DIR = "/root/workspace/models/InternVL3-8B-FP8-GPTQ"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

But I got the error:

```
Traceback (most recent call last):
  File "/root/workspace/src/w8a8/ezcloud_fp8.py", line 153, in <module>
    oneshot(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 348, in oneshot
    one_shot()
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 172, in __call__
    self.apply_recipe_modifiers(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 220, in apply_recipe_modifiers
    pipeline(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/independent/pipeline.py", line 45, in __call__
    pipeline(model, dataloader, dataset_args)
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/sequential/pipeline.py", line 105, in __call__
    subgraph.forward(model, **inputs)
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/sequential/helpers.py", line 74, in forward
    return forward_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 13, in forward
AttributeError: 'dict' object has no attribute 'last_hidden_state'
```
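
For what it's worth, the failing access looks like the pattern InternVL's `extract_feature` uses, where the vision tower output is expected to be a ModelOutput exposing a `last_hidden_state` attribute; inside the traced subgraph it apparently arrives as a plain dict instead. A standalone illustration of that failure mode (not the actual llm-compressor or InternVL code):

```python
# Plain dicts only support key access, not attribute access,
# which reproduces the AttributeError in the traceback above.
out = {"last_hidden_state": "dummy tensor"}
try:
    _ = out.last_hidden_state        # what ModelOutput-style code expects
except AttributeError as err:
    print(err)                       # 'dict' object has no attribute 'last_hidden_state'
print(out["last_hidden_state"])      # dict-style access works fine
```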

🛠️ Steps to reproduce

No response
