Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions samples/deployment-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ librosa==0.11.0 # For Whisper
pillow==12.0.0 # Image processing for VLMs
json5==0.12.1 # For ReAct
pydantic==2.12.4 # For Structured output json schema
opencv-python # For video-to-text VLM sample
17 changes: 13 additions & 4 deletions samples/python/visual_language_chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

This example showcases inference of text-generation Vision Language Models (VLMs): `miniCPM-V-2_6` and other models with the same signature. The application doesn't have many configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.VLMPipeline` and configures it for the chat scenario. There is also a Jupyter [notebook](https://github.com/openvinotoolkit/openvino_notebooks/tree/latest/notebooks/minicpm-v-multimodal-chatbot) which provides an example of Visual-language assistant.

There are two sample files:
There are three sample files:
- [`visual_language_chat.py`](./visual_language_chat.py) demonstrates basic usage of the VLM pipeline.
- [`video_to_text_chat.py`](./video_to_text_chat.py) demonstrates video to text usage of the VLM pipeline.
- [`benchmark_vlm.py`](./benchmark_vlm.py) shows how to benchmark a VLM in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text and calculating various performance metrics.

## Download and convert the model and tokenizers
Expand Down Expand Up @@ -38,14 +39,22 @@ tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6")
export_tokenizer(tokenizer, output_dir)
```

## Run:
Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` to run VLM samples.

[This image](https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11) can be used as a sample image.
## Run image-to-text chat sample:

Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` and then, run a sample:
[This image](https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11) can be used as a sample image.

`python visual_language_chat.py ./miniCPM-V-2_6/ 319483352-d5fbbd1a-d484-415c-88cb-9986625b7b11.jpg`

## Run video-to-text chat sample:

To run this sample a model that supports video input is required, for example `llava-hf/LLaVA-NeXT-Video-7B-hf`.

[This video](https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4) can be used as a sample video.

`python video_to_text_chat.py ./LLaVA-NeXT-Video-7B-hf/ sample_demo_1.mp4`


Discrete GPUs (dGPUs) usually provide better performance compared to CPUs. It is recommended to run larger models on a dGPU with 32GB+ RAM. # TODO: examples of larger models
Modify the source code to change the device for inference to the GPU.
Expand Down
100 changes: 100 additions & 0 deletions samples/python/visual_language_chat/video_to_text_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import argparse
import numpy as np
import cv2
import openvino_genai
from openvino import Tensor
from pathlib import Path


def streamer(subword: str) -> bool:
'''
Args:
subword: sub-word of the generated text.
Returns: Return flag corresponds whether generation should be stopped.
'''
print(subword, end='', flush=True)

# No value is returned as in this example we don't want to stop the generation in this method.
# "return None" will be treated the same as "return openvino_genai.StreamingStatus.RUNNING".


def read_video(path: str, num_frames: int = 10) -> Tensor:
'''
Args:
path: The path to the video.
Returns: the ov.Tensor containing the video.
'''
cap = cv2.VideoCapture(path)

frames = []

while cap.isOpened():
ret, frame = cap.read()
if not ret:
break

frames.append(np.array(frame))
cap.release()

indices = np.arange(0, len(frames), len(frames) / num_frames).astype(int)
frames = [frames[i] for i in indices]

return Tensor(frames)


def read_videos(path: str) -> list[Tensor]:
entry = Path(path)
if entry.is_dir():
return [read_video(str(file)) for file in sorted(entry.iterdir())]
return [read_video(path)]


def main():
parser = argparse.ArgumentParser()
parser.add_argument('model_dir', help="Path to the model directory")
parser.add_argument('video_dir', help="Path to a video file.")
parser.add_argument('device', nargs='?', default='CPU', help="Device to run the model on (default: CPU)")
args = parser.parse_args()

video = read_videos(args.video_dir)

# GPU and NPU can be used as well.
# Note: If NPU is selected, only the language model will be run on the NPU.
enable_compile_cache = dict()
if args.device == "GPU":
# Cache compiled models on disk for GPU to save time on the next run.
# It's not beneficial for CPU.
enable_compile_cache["CACHE_DIR"] = "vlm_cache"

pipe = openvino_genai.VLMPipeline(args.model_dir, args.device, **enable_compile_cache)

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 100

pipe.start_chat()
prompt = input('question:\n')
pipe.generate(prompt, videos=video, generation_config=config, streamer=streamer)

while True:
try:
prompt = input("\n----------\n"
"question:\n")
except EOFError:
break
pipe.generate(prompt, generation_config=config, streamer=streamer)
pipe.finish_chat()


if __name__ == '__main__':
main()
7 changes: 6 additions & 1 deletion tests/python_tests/samples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@
"tiny-random-SpeechT5ForTextToSpeech": {
"name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech",
"convert_args": ["--model-kwargs", json.dumps({"vocoder": "fxmarty/speecht5-hifigan-tiny"})]
},
"tiny-random-llava-next-video": {
"name": "optimum-intel-internal-testing/tiny-random-llava-next-video",
"convert_args": ["--task", "image-text-to-text"]
}
}

Expand All @@ -159,7 +163,8 @@
"cat.png": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png",
"cat": "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11",
"3283_1447_000.tar.gz": "https://huggingface.co/datasets/facebook/multilingual_librispeech/resolve/main/data/mls_polish/train/audio/3283_1447_000.tar.gz",
"cmu_us_awb_arctic-wav-arctic_a0001.bin": "https://huggingface.co/datasets/Xenova/cmu-arctic-xvectors-extracted/resolve/main/cmu_us_awb_arctic-wav-arctic_a0001.bin"
"cmu_us_awb_arctic-wav-arctic_a0001.bin": "https://huggingface.co/datasets/Xenova/cmu-arctic-xvectors-extracted/resolve/main/cmu_us_awb_arctic-wav-arctic_a0001.bin",
"videos/sample_video.mp4": "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
}

SAMPLES_PY_DIR = Path(os.environ.get("SAMPLES_PY_DIR", os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../samples/python"))))
Expand Down
36 changes: 36 additions & 0 deletions tests/python_tests/samples/test_video_to_text_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import pytest
import subprocess # nosec B404
import sys

from conftest import SAMPLES_PY_DIR, SAMPLES_CPP_DIR, SAMPLES_C_DIR
from test_utils import run_sample

class TestVisualLanguageChat:
@pytest.mark.vlm
@pytest.mark.samples
@pytest.mark.parametrize(
"convert_model, download_test_content, questions",
[
pytest.param("tiny-random-llava-next-video", "videos/sample_video.mp4", 'What is unusual on this video?\nGo on.')
],
indirect=["convert_model", "download_test_content"],
)
def test_sample_visual_language_chat(self, convert_model, download_test_content, questions):
# Test CPP sample
# TODO

# Test C sample
# TODO

# Test Python sample
py_script = os.path.join(SAMPLES_PY_DIR, "visual_language_chat/video_to_text_chat.py")
py_command = [sys.executable, py_script, convert_model, download_test_content]
py_result = run_sample(py_command, questions)

# Compare results
# assert py_result.stdout == cpp_result.stdout, f"Results should match"
# assert cpp_result.stdout == c_result.stdout, f"Results should match"
Loading