Description
⚙️ Your current environment
The output of `python collect_env.py`:
Operating System: `Linux-5.4.0-144-generic-x86_64-with-glibc2.35`
Python Version: `3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]`
llm-compressor Version: `0.8.1`
compressed-tensors Version: `0.12.2`
transformers Version: `4.57.1`
torch Version: `2.8.0`
CUDA Devices: `['NVIDIA L40']`
AMD Devices: `None`
🐛 Describe the bug
I am trying to quantize InternVL3-8B.
1. Download `chat_template.jinja` from InternVL3_5-8B and place the file in the local directory of InternVL3-8B.
2. Replace the `int()` calls in `modeling_internvl_chat.py` with `math.floor()` (a small sketch of these two changes follows this list).
3. Run the script below.
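For clarity, a minimal sketch of steps 1 and 2. The repo id and the example `int()` call site shown below are assumptions for illustration, not an exact diff; the actual edit applies to whichever `int()` conversions your copy of `modeling_internvl_chat.py` contains.

```python
import math
from huggingface_hub import hf_hub_download

# Step 1 (sketch): fetch chat_template.jinja and place it next to the local
# InternVL3-8B weights (repo_id and local_dir are illustrative assumptions).
hf_hub_download(
    repo_id="OpenGVLab/InternVL3_5-8B",
    filename="chat_template.jinja",
    local_dir="/root/workspace/models/InternVL3-8B",
)

# Step 2 (sketch): in modeling_internvl_chat.py, swap int() conversions for
# math.floor(), e.g. (illustrative call site, not a verbatim quote of the file):
#   h = w = int(vit_embeds.shape[1] ** 0.5)         # before
#   h = w = math.floor(vit_embeds.shape[1] ** 0.5)   # after
```

Both changes only prepare the local checkpoint; the quantization itself is done by the script below.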
```python
import torch
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as T
from PIL import Image
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model.
model_id = "/root/workspace/models/InternVL3-8B"
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

# Load datasets
DATASET_ID = "/root/workspace/datasets/flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
#### copy from https://hf-mirror.com/OpenGVLab/InternVL3-8B#inference-with-transformers end ####

def load_image_from_PIL(image_obj, input_size=448, max_num=12):
    image = image_obj.convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

def preprocess_and_tokenize(example):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": ""
                },
                {
                    "type": "text",
                    "text": "What does the image show?"
                },
            ],
        }
    ]
    text = tokenizer.apply_chat_template(messages)
    image_flags = torch.ones(1, dtype=torch.long)
    example["input_ids"] = text
    example['image_flags'] = image_flags
    return example

ds = ds.map(preprocess_and_tokenize)

def data_collator(batch):
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    item["pixel_values"] = load_image_from_PIL(item["image"]).to(torch.bfloat16)
    item["input_ids"] = torch.LongTensor([item["input_ids"]])
    item["labels"] = item["input_ids"].clone()
    item["image_flags"] = torch.LongTensor([item["image_flags"]])
    return item

# Recipe
recipe = GPTQModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["re:.*lm_head", "re:mlp1.*", "re:.*vision_model.*"]
)

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator
)

# Save to disk compressed.
SAVE_DIR = "/root/workspace/models/InternVL3-8B-FP8-GPTQ"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

But I got the error:
```
Traceback (most recent call last):
  File "/root/workspace/src/w8a8/ezcloud_fp8.py", line 153, in <module>
    oneshot(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 348, in oneshot
    one_shot()
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 172, in __call__
    self.apply_recipe_modifiers(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/entrypoints/oneshot.py", line 220, in apply_recipe_modifiers
    pipeline(
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/independent/pipeline.py", line 45, in __call__
    pipeline(model, dataloader, dataset_args)
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/sequential/pipeline.py", line 105, in __call__
    subgraph.forward(model, **inputs)
  File "/root/workspace/src/llmcompressor/src/llmcompressor/pipelines/sequential/helpers.py", line 74, in forward
    return forward_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 13, in forward
AttributeError: 'dict' object has no attribute 'last_hidden_state'
```
🛠️ Steps to reproduce
No response