
Commit 06d0693

Transferring fixes from release branch
1 parent 6957d74 commit 06d0693

5 files changed: +223 additions, -25 deletions


.github/workflows/linux.yml

Lines changed: 2 additions & 2 deletions
@@ -522,8 +522,8 @@ jobs:
  run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
  timeout: 360
  - name: 'LLM & VLM'
- cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
- run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
+ cmd: 'python -m pytest -v ./tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py ./tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
+ run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
  timeout: 180
  - name: 'GGUF Reader tests'
  cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'

.github/workflows/mac.yml

Lines changed: 2 additions & 2 deletions
@@ -447,8 +447,8 @@ jobs:
  # timeout: 240
  # Only supported on X64 or ARM with SVE support
  # - name: 'LLM & VLM'
- # cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py'
- # run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
+ # cmd: 'tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py'
+ # run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
  # timeout: 180
  - name: 'GGUF Reader tests'
  cmd: 'python -m pytest -v ./tests/python_tests/test_gguf_reader.py'

.github/workflows/windows.yml

Lines changed: 2 additions & 2 deletions
@@ -611,8 +611,8 @@ jobs:
  run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).continuous_batching.test }}
  timeout: 360
  - name: 'LLM & VLM'
- cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py --override-ini cache_dir=/mount/caches/pytest/'
- run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test }}
+ cmd: 'python -m pytest -s -v tests/python_tests/test_llm_pipeline.py tests/python_tests/test_llm_pipeline_static.py tests/python_tests/test_vlm_pipeline.py tests/python_tests/test_structured_output.py tests/python_tests/test_image_generation.py --override-ini cache_dir=/mount/caches/pytest/'
+ run_condition: ${{ fromJSON(needs.smart_ci.outputs.affected_components).visual_language.test || fromJSON(needs.smart_ci.outputs.affected_components).LLM.test || fromJSON(needs.smart_ci.outputs.affected_components).Image_generation.test }}
  timeout: 180
  - name: 'GGUF Reader tests'
  cmd: 'python -m pytest -s -v tests/python_tests/test_gguf_reader.py'

src/python/py_image_generation_pipelines.cpp

Lines changed: 19 additions & 19 deletions
@@ -180,10 +180,12 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
  }

  float next() override {
+     py::gil_scoped_acquire acquire;
      return m_torch.attr("randn")(1, "generator"_a=m_torch_generator, "dtype"_a=m_float32).attr("item")().cast<float>();
  }

  ov::Tensor randn_tensor(const ov::Shape& shape) override {
+     py::gil_scoped_acquire acquire;
      py::object torch_tensor = m_torch.attr("randn")(to_py_list(shape), "generator"_a=m_torch_generator, "dtype"_a=m_float32);
      py::object numpy_tensor = torch_tensor.attr("numpy")();
      py::array numpy_array = py::cast<py::array>(numpy_tensor);

@@ -195,12 +197,24 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
  class TorchTensorAllocator {
      size_t m_total_size;
      void * m_mutable_data;
-     py::object m_torch_tensor; // we need to hold torch.Tensor to avoid memory destruction
+     py::object m_torch_tensor;

  public:
      TorchTensorAllocator(size_t total_size, void * mutable_data, py::object torch_tensor) :
          m_total_size(total_size), m_mutable_data(mutable_data), m_torch_tensor(torch_tensor) { }

+     ~TorchTensorAllocator() {
+         if (m_torch_tensor && Py_IsInitialized()) {
+             py::gil_scoped_acquire acquire;
+             m_torch_tensor = py::object();
+         }
+     }
+
+     TorchTensorAllocator(const TorchTensorAllocator&) = default;
+     TorchTensorAllocator& operator=(const TorchTensorAllocator&) = default;
+     TorchTensorAllocator(TorchTensorAllocator&&) = default;
+     TorchTensorAllocator& operator=(TorchTensorAllocator&&) = default;
+
      void* allocate(size_t bytes, size_t) const {
          if (m_total_size == bytes) {
              return m_mutable_data;

@@ -221,6 +235,7 @@ class TorchGenerator : public ov::genai::CppStdGenerator {
  }

  void seed(size_t new_seed) override {
+     py::gil_scoped_acquire acquire;
      create_torch_generator(new_seed);
  }
  };

@@ -448,12 +463,7 @@ void init_image_generation_pipelines(py::module_& m) {
  ) -> py::typing::Union<ov::Tensor> {
      ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
      ov::Tensor res;
-     if (params_have_torch_generator(params)) {
-         // TorchGenerator stores python object which causes segfault after gil_scoped_release
-         // so if it was passed, we don't release GIL
-         res = pipe.generate(prompt, params);
-     }
-     else {
+     {
          py::gil_scoped_release rel;
          res = pipe.generate(prompt, params);
      }

@@ -565,12 +575,7 @@ void init_image_generation_pipelines(py::module_& m) {
  ) -> py::typing::Union<ov::Tensor> {
      ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
      ov::Tensor res;
-     if (params_have_torch_generator(params)) {
-         // TorchGenerator stores python object which causes segfault after gil_scoped_release
-         // so if it was passed, we don't release GIL
-         res = pipe.generate(prompt, image, params);
-     }
-     else {
+     {
          py::gil_scoped_release rel;
          res = pipe.generate(prompt, image, params);
      }

@@ -676,12 +681,7 @@ void init_image_generation_pipelines(py::module_& m) {
  ) -> py::typing::Union<ov::Tensor> {
      ov::AnyMap params = pyutils::kwargs_to_any_map(kwargs);
      ov::Tensor res;
-     if (params_have_torch_generator(params)) {
-         // TorchGenerator stores python object which causes segfault after gil_scoped_release
-         // so if it was passed, we don't release GIL
-         res = pipe.generate(prompt, image, mask_image, params);
-     }
-     else {
+     {
          py::gil_scoped_release rel;
          res = pipe.generate(prompt, image, mask_image, params);
      }
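
The C++ changes above share one pattern: the pipeline bindings now release the GIL around every generate() call unconditionally, so any object that keeps Python state (the torch-backed generator, the allocator holding a torch.Tensor) has to re-acquire the GIL itself wherever it touches that state, including in its destructor. The snippet below is a minimal, self-contained sketch of that pattern, not the actual OpenVINO GenAI sources; the PyBackedGenerator class and its members are illustrative names.

// Sketch only (illustrative names, not the ov::genai classes): a C++ object that
// owns Python state re-acquires the GIL in every member that touches it, so the
// caller is free to wrap the long-running call in py::gil_scoped_release.
#include <pybind11/embed.h>
#include <iostream>
#include <utility>

namespace py = pybind11;

class PyBackedGenerator {
    py::object m_py_rng;  // Python state held by a C++ object

public:
    explicit PyBackedGenerator(py::object py_rng) : m_py_rng(std::move(py_rng)) {}

    float next() {
        // The caller may have dropped the GIL, so take it back before any Python call.
        py::gil_scoped_acquire acquire;
        return m_py_rng.attr("random")().cast<float>();
    }

    ~PyBackedGenerator() {
        // Dropping a py::object touches refcounts, which also needs the GIL;
        // guard against interpreter shutdown the same way the allocator above does.
        if (m_py_rng && Py_IsInitialized()) {
            py::gil_scoped_acquire acquire;
            m_py_rng = py::object();
        }
    }
};

int main() {
    py::scoped_interpreter guard;
    PyBackedGenerator gen(py::module_::import("random").attr("Random")(42));
    {
        // What the pipeline bindings now do unconditionally: release the GIL for
        // the heavy call; the generator re-acquires it only where it needs Python.
        py::gil_scoped_release release;
        std::cout << gen.next() << "\n";
    }
    return 0;
}

Because the generator and allocator now manage the GIL themselves, the bindings no longer need the params_have_torch_generator special case that kept the GIL held for torch-backed generators.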

tests/python_tests/test_image_generation.py

Lines changed: 198 additions & 0 deletions (new file)
@@ -0,0 +1,198 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import subprocess  # nosec B404
import logging
from pathlib import Path
import numpy as np
import openvino as ov
import openvino_genai as ov_genai

from utils.constants import get_ov_cache_models_dir
from utils.atomic_download import AtomicDownloadManager
from utils.network import retry_request

logger = logging.getLogger(__name__)

MODEL_ID = "tiny-random-latent-consistency"
MODEL_NAME = "echarlaix/tiny-random-latent-consistency"


@pytest.fixture(scope="module")
def image_generation_model():
    models_dir = get_ov_cache_models_dir()
    model_path = Path(models_dir) / MODEL_ID / MODEL_NAME

    manager = AtomicDownloadManager(model_path)

    def convert_model(temp_path: Path) -> None:
        command = [
            "optimum-cli", "export", "openvino",
            "--model", MODEL_NAME,
            "--trust-remote-code",
            "--weight-format", "fp16",
            str(temp_path)
        ]
        logger.info(f"Conversion command: {' '.join(command)}")
        retry_request(lambda: subprocess.run(command, check=True, text=True, capture_output=True))

    try:
        manager.execute(convert_model)
    except subprocess.CalledProcessError as error:
        logger.exception(f"optimum-cli returned {error.returncode}. Output:\n{error.output}")
        raise

    return str(model_path)


def get_random_image(height: int = 64, width: int = 64) -> ov.Tensor:
    image_data = np.random.randint(0, 255, (1, height, width, 3), dtype=np.uint8)
    return ov.Tensor(image_data)


def get_mask_image(height: int = 64, width: int = 64) -> ov.Tensor:
    mask_data = np.zeros((1, height, width, 3), dtype=np.uint8)
    mask_data[:, height//4:3*height//4, width//4:3*width//4, :] = 255
    return ov.Tensor(mask_data)


class TestImageGenerationCallback:

    def test_text2image_with_simple_callback(self, image_generation_model):
        pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

        callback_calls = []

        def callback(step, num_steps, latent):
            callback_calls.append((step, num_steps))
            return False

        image = pipe.generate(
            "test prompt",
            width=64,
            height=64,
            num_inference_steps=2,
            callback=callback
        )

        assert len(callback_calls) > 0, "Callback should be called at least once"
        assert image is not None

    def test_text2image_with_stateful_callback(self, image_generation_model):
        pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

        class ProgressTracker:
            def __init__(self):
                self.steps = []
                self.total = 0

            def reset(self, total):
                self.total = total
                self.steps = []

            def update(self, step):
                self.steps.append(step)

        tracker = ProgressTracker()

        def callback(step, num_steps, latent):
            if tracker.total != num_steps:
                tracker.reset(num_steps)
            tracker.update(step)
            return False

        image = pipe.generate(
            "test prompt",
            width=64,
            height=64,
            num_inference_steps=2,
            callback=callback
        )

        assert len(tracker.steps) > 0, "Callback should track steps"
        assert image is not None

    def test_text2image_callback_early_stop(self, image_generation_model):
        pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

        callback_calls = []

        def callback(step, num_steps, latent):
            callback_calls.append(step)
            return step >= 1

        image = pipe.generate(
            "test prompt",
            width=64,
            height=64,
            num_inference_steps=5,
            callback=callback
        )

        assert len(callback_calls) <= 3, "Callback should stop early"
        assert image is not None

    def test_text2image_multiple_generates_with_callback(self, image_generation_model):
        pipe = ov_genai.Text2ImagePipeline(image_generation_model, "CPU")

        for i in range(3):
            callback_calls = []

            def callback(step, num_steps, latent):
                callback_calls.append(step)
                return False

            image = pipe.generate(
                f"test prompt {i}",
                width=64,
                height=64,
                num_inference_steps=2,
                callback=callback
            )

            assert len(callback_calls) > 0
            assert image is not None

    def test_image2image_with_callback(self, image_generation_model):
        pipe = ov_genai.Image2ImagePipeline(image_generation_model, "CPU")

        callback_calls = []

        def callback(step, num_steps, latent):
            callback_calls.append((step, num_steps))
            return False

        input_image = get_random_image()

        image = pipe.generate(
            "test prompt",
            input_image,
            strength=0.8,
            callback=callback
        )

        assert len(callback_calls) > 0
        assert image is not None

    def test_inpainting_with_callback(self, image_generation_model):
        pipe = ov_genai.InpaintingPipeline(image_generation_model, "CPU")

        callback_calls = []

        def callback(step, num_steps, latent):
            callback_calls.append((step, num_steps))
            return False

        input_image = get_random_image()
        mask_image = get_mask_image()

        image = pipe.generate(
            "test prompt",
            input_image,
            mask_image,
            callback=callback
        )

        assert len(callback_calls) > 0
        assert image is not None
