Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ jobs:
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
timeout: 360
- name: 'LLM & VLM'
cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
timeout: 180
- name: 'GGUF Reader tests'
cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ jobs:
# timeout: 240
# Only supported on X64 or ARM with SVE support
# - name: 'LLM & VLM'
# cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py'
# run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
# cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py'
# run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
# timeout: 180
- name: 'GGUF Reader tests'
cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,8 @@ jobs:
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
timeout: 360
- name: 'LLM & VLM'
cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
timeout: 180
- name: 'GGUF Reader tests'
cmd: 'python -m pytest -s -v tests/python_tests/test_gguf_reader.py'
Expand Down
36 changes: 18 additions & 18 deletions src/python/py_image_generation_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
}

float next() override {
py::gil_scoped_acquire acquire;
return m_torch.attr("randn")(1, "generator"_a=m_torch_generator, "dtype"_a=m_float32).attr("item")().cast<float>();
}

ov::Tensor randn_tensor(const ov::Shape& shape) override {
py::gil_scoped_acquire acquire;
py::object torch_tensor = m_torch.attr("randn")(to_py_list(shape), "generator"_a=m_torch_generator, "dtype"_a=m_float32);
py::object numpy_tensor = torch_tensor.attr("numpy")();
py::array numpy_array = py::cast<py::array>(numpy_tensor);
Expand All @@ -201,6 +203,18 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
TorchTensorAllocator(size_t total_size, void * mutable_data, py::object torch_tensor) :
m_total_size(total_size), m_mutable_data(mutable_data), m_torch_tensor(torch_tensor) { }

~TorchTensorAllocator() {
if (m_torch_tensor && Py_IsInitialized()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_torch_tensor is always set in constructor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but when we're using move it may be an empty Python object as far as i understand. So it's a little bit defensive here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is TorchTensorAllocator movable? It defines a constructor which should disable default move constructot

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it was movable before i've added destructor (constructor doesn't disable generation of default move methods). When i've added destructor it became unmovable, so i've had to specify default move constructor and copy/assign methods.

py::gil_scoped_acquire acquire;
m_torch_tensor = py::object();
}
}

TorchTensorAllocator(const TorchTensorAllocator&) = default;
TorchTensorAllocator& operator=(const TorchTensorAllocator&) = default;
TorchTensorAllocator(TorchTensorAllocator&&) = default;
TorchTensorAllocator& operator=(TorchTensorAllocator&&) = default;

void* allocate(size_t bytes, size_t) const {
if (m_total_size == bytes) {
return m_mutable_data;
Expand All @@ -221,6 +235,7 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
}

void seed(size_t new_seed) override {
py::gil_scoped_acquire acquire;
create_torch_generator(new_seed);
}
};
Expand Down Expand Up @@ -448,12 +463,7 @@ void init_image_generation_pipelines(py::module_& m) {
) -> py::typing::Union<ov::Tensor> {
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
ov::Tensor res;
if (params_have_torch_generator(params)) {
// TorchGenerator stores python object which causes segfault after gil_scoped_release
// so if it was passed, we don't release GIL
res = pipe.generate(prompt, params);
}
else {
{
py::gil_scoped_release rel;
res = pipe.generate(prompt, params);
}
Expand Down Expand Up @@ -565,12 +575,7 @@ void init_image_generation_pipelines(py::module_& m) {
) -> py::typing::Union<ov::Tensor> {
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
ov::Tensor res;
if (params_have_torch_generator(params)) {
// TorchGenerator stores python object which causes segfault after gil_scoped_release
// so if it was passed, we don't release GIL
res = pipe.generate(prompt, image, params);
}
else {
{
py::gil_scoped_release rel;
res = pipe.generate(prompt, image, params);
}
Expand Down Expand Up @@ -676,12 +681,7 @@ void init_image_generation_pipelines(py::module_& m) {
) -> py::typing::Union<ov::Tensor> {
ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
ov::Tensor res;
if (params_have_torch_generator(params)) {
// TorchGenerator stores python object which causes segfault after gil_scoped_release
// so if it was passed, we don't release GIL
res = pipe.generate(prompt, image, mask_image, params);
}
else {
{
py::gil_scoped_release rel;
res = pipe.generate(prompt, image, mask_image, params);
}
Expand Down
53 changes: 42 additions & 11 deletions src/python/py_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,22 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
} else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
auto py_callback = py::cast<py::function>(py_obj);
auto shared_callback = std::shared_ptr<py::function>(
new py::function(py_callback),
[](py::function* f) {
if (Py_IsInitialized()) {
py::gil_scoped_acquire acquire;
delete f;
} else {
delete f;
}
}
);

return std::function<bool(size_t, size_t, ov::Tensor&)>(
[py_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
[shared_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
py::gil_scoped_acquire acquire;
return py_callback(step, num_steps, latent).cast<bool>();
return (*shared_callback)(step, num_steps, latent).cast<bool>();
}
);
} else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
Expand Down Expand Up @@ -443,21 +455,40 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p

std::visit(overloaded {
[&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
// Wrap python streamer with manual utf-8 decoding. Do not rely
// on pybind automatic decoding since it raises exceptions on incomplete strings.
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamingStatus {
auto shared_callback = std::shared_ptr<std::function<std::optional<uint16_t>(py::str)>>(
new std::function<std::optional<uint16_t>(py::str)>(py_callback),
[](std::function<std::optional<uint16_t>(py::str)>* f) {
if (Py_IsInitialized()) {
py::gil_scoped_acquire acquire;
delete f;
} else {
delete f;
}
}
);

auto callback_wrapped = [shared_callback](std::string subword) -> ov::genai::StreamingStatus {
py::gil_scoped_acquire acquire;
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
PyObject* py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
if (!py_str) {
PyErr_WriteUnraisable(nullptr);
return StreamingStatus::RUNNING;
}
auto py_str_obj = py::reinterpret_steal<py::str>(py_str);
std::optional<uint16_t> callback_output;
try {
callback_output = (*shared_callback)(py_str_obj);
} catch (const py::error_already_set&) {
return StreamingStatus::RUNNING;
}
if (callback_output.has_value()) {
if (*callback_output == (uint16_t)StreamingStatus::RUNNING)
if (*callback_output == static_cast<uint16_t>(StreamingStatus::RUNNING))
return StreamingStatus::RUNNING;
else if (*callback_output == (uint16_t)StreamingStatus::CANCEL)
else if (*callback_output == static_cast<uint16_t>(StreamingStatus::CANCEL))
return StreamingStatus::CANCEL;
return StreamingStatus::STOP;
} else {
return StreamingStatus::RUNNING;
}
return StreamingStatus::RUNNING;
};
streamer = callback_wrapped;
},
Expand Down
2 changes: 1 addition & 1 deletion src/python/py_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace ov::genai::pybind::utils {
// When StreamerVariant is used utf-8 decoding is done by pybind and can lead to exception on incomplete texts.
// Therefore strings decoding should be handled with PyUnicode_DecodeUTF8(..., "replace") to not throw errors.
using PyBindStreamerVariant = std::variant<
std::function<std::optional<uint16_t>(std::string)>,
std::function<std::optional<uint16_t>(py::str)>,
std::shared_ptr<StreamerBase>,
std::monostate>;

Expand Down
198 changes: 198 additions & 0 deletions tests/python_tests/test_image_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import subprocess # nosec B404
import logging
from pathlib import Path
import numpy as np
import openvino as ov
import openvino_genai as ov_genai

from utils.constants import get_ov_cache_converted_models_dir
from utils.atomic_download import AtomicDownloadManager
from utils.network import retry_request

logger = logging.getLogger(__name__)

MODEL_ID = "tiny-random-latent-consistency"
MODEL_NAME = "echarlaix/tiny-random-latent-consistency"


@pytest.fixture(scope="module")
def image_generation_model():
models_dir = get_ov_cache_converted_models_dir()
model_path = Path(models_dir) / MODEL_ID / MODEL_NAME

manager = AtomicDownloadManager(model_path)

def convert_model(temp_path: Path) -> None:
command = [
"optimum-cli", "export", "openvino",
"--model", MODEL_NAME,
"--trust-remote-code",
"--weight-format", "fp16",
str(temp_path)
]
logger.info(f"Conversion command: {' '.join(command)}")
retry_request(lambda: subprocess.run(command, check=True, text=True, capture_output=True))

try:
manager.execute(convert_model)
except subprocess.CalledProcessError as error:
logger.exception(f"optimum-cli returned {error.returncode}. Output:\n{error.output}")
raise

return str(model_path)


def get_random_image(height: int = 64, width: int = 64) -> ov.Tensor:
image_data = np.random.randint(0, 255, (1, height, width, 3), dtype=np.uint8)
return ov.Tensor(image_data)


def get_mask_image(height: int = 64, width: int = 64) -> ov.Tensor:
mask_data = np.zeros((1, height, width, 3), dtype=np.uint8)
mask_data[:, height//4:3*height//4, width//4:3*width//4, :] = 255
return ov.Tensor(mask_data)


class TestImageGenerationCallback:

def test_text2image_with_simple_callback(self, image_generation_model):
pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

callback_calls = []

def callback(step, num_steps, latent):
callback_calls.append((step, num_steps))
return False

image = pipe.generate(
"test prompt",
width=64,
height=64,
num_inference_steps=2,
callback=callback
)

assert len(callback_calls) > 0, "Callback should be called at least once"
assert image is not None

def test_text2image_with_stateful_callback(self, image_generation_model):
pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

class ProgressTracker:
def __init__(self):
self.steps = []
self.total = 0

def reset(self, total):
self.total = total
self.steps = []

def update(self, step):
self.steps.append(step)

tracker = ProgressTracker()

def callback(step, num_steps, latent):
if tracker.total != num_steps:
tracker.reset(num_steps)
tracker.update(step)
return False

image = pipe.generate(
"test prompt",
width=64,
height=64,
num_inference_steps=2,
callback=callback
)

assert len(tracker.steps) > 0, "Callback should track steps"
assert image is not None

def test_text2image_callback_early_stop(self, image_generation_model):
pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

callback_calls = []

def callback(step, num_steps, latent):
callback_calls.append(step)
return step >= 1

image = pipe.generate(
"test prompt",
width=64,
height=64,
num_inference_steps=5,
callback=callback
)

assert len(callback_calls) <= 3, "Callback should stop early"
assert image is not None

def test_text2image_multiple_generates_with_callback(self, image_generation_model):
pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

for i in range(3):
callback_calls = []

def callback(step, num_steps, latent):
callback_calls.append(step)
return False

image = pipe.generate(
f"test prompt {i}",
width=64,
height=64,
num_inference_steps=2,
callback=callback
)

assert len(callback_calls) > 0
assert image is not None

def test_image2image_with_callback(self, image_generation_model):
pipe = ov_genai.Image2ImagePipeline(image_generation_model, "CPU")

callback_calls = []

def callback(step, num_steps, latent):
callback_calls.append((step, num_steps))
return False

input_image = get_random_image()

image = pipe.generate(
"test prompt",
input_image,
strength=0.8,
callback=callback
)

assert len(callback_calls) > 0
assert image is not None

def test_inpainting_with_callback(self, image_generation_model):
pipe = ov_genai.InpaintingPipeline(image_generation_model, "CPU")

callback_calls = []

def callback(step, num_steps, latent):
callback_calls.append((step, num_steps))
return False

input_image = get_random_image()
mask_image = get_mask_image()

image = pipe.generate(
"test prompt",
input_image,
mask_image,
callback=callback
)

assert len(callback_calls) > 0
assert image is not None
Loading