Skip to content

Commit fc503c8

Browse files
authored
Fix for deadlock in python callback (#3034)
Fixes deadlock [CVS-176777](https://jira.devtools.intel.com/browse/CVS-176777) and adds more samples and test cases.
1 parent 945a938 commit fc503c8

File tree

8 files changed

+55
-22
lines changed

8 files changed

+55
-22
lines changed

samples/python/image_generation/image2image.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,20 @@ def main():
2626

2727
image = read_image(args.image)
2828

29-
image_tensor = pipe.generate(args.prompt, image,
30-
strength=0.8 # controls how initial image is noised after being converted to latent space. `1` means initial image is fully noised
29+
def callback(step, num_steps, latent):
30+
print(f"Step {step + 1}/{num_steps}")
31+
return False
32+
33+
image_tensor = pipe.generate(
34+
args.prompt,
35+
image,
36+
strength=0.8,
37+
callback=callback
3138
)
3239

3340
image = Image.fromarray(image_tensor.data[0])
3441
image.save("image.bmp")
3542

3643

37-
if '__main__' == __name__:
44+
if __name__ == '__main__':
3845
main()

samples/python/image_generation/inpainting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@ def main():
2828
image = read_image(args.image)
2929
mask_image = read_image(args.mask)
3030

31-
image_tensor = pipe.generate(args.prompt, image, mask_image)
31+
def callback(step, num_steps, latent):
32+
print(f"Step {step + 1}/{num_steps}")
33+
return False
34+
35+
image_tensor = pipe.generate(args.prompt, image, mask_image, callback=callback)
3236

3337
image = Image.fromarray(image_tensor.data[0])
3438
image.save("image.bmp")
3539

3640

37-
if '__main__' == __name__:
41+
if __name__ == '__main__':
3842
main()

samples/python/image_generation/text2image.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@ def main():
1717
device = 'CPU' # GPU can be used as well
1818
pipe = openvino_genai.Text2ImagePipeline(args.model_dir, device)
1919

20+
def callback(step, num_steps, latent):
21+
print(f"Step {step + 1}/{num_steps}")
22+
return False
23+
2024
image_tensor = pipe.generate(
2125
args.prompt,
2226
width=512,
2327
height=512,
2428
num_inference_steps=20,
25-
num_images_per_prompt=1)
29+
num_images_per_prompt=1,
30+
callback=callback)
2631

2732
image = Image.fromarray(image_tensor.data[0])
2833
image.save("image.bmp")
2934

3035

31-
if '__main__' == __name__:
36+
if __name__ == '__main__':
3237
main()

src/cpp/src/image_generation/threaded_callback.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include <thread>
7+
#include <variant>
78

89
#include "synchronized_queue.hpp"
910

@@ -33,7 +34,7 @@ class ThreadedCallbackWrapper {
3334
return CallbackStatus::STOP;
3435
}
3536

36-
m_squeue.push({step, num_steps, latent});
37+
m_squeue.push(std::make_tuple(step, num_steps, latent));
3738

3839
return CallbackStatus::RUNNING;
3940
}
@@ -44,7 +45,7 @@ class ThreadedCallbackWrapper {
4445
}
4546

4647
m_status = CallbackStatus::STOP;
47-
m_squeue.empty();
48+
m_squeue.push(std::monostate());
4849

4950
if (m_worker_thread && m_worker_thread->joinable()) {
5051
m_worker_thread->join();
@@ -58,18 +59,23 @@ class ThreadedCallbackWrapper {
5859
private:
5960
std::function<bool(size_t, size_t, ov::Tensor&)> m_callback = nullptr;
6061
std::shared_ptr<std::thread> m_worker_thread = nullptr;
61-
SynchronizedQueue<std::tuple<size_t, size_t, ov::Tensor>> m_squeue;
62+
SynchronizedQueue<std::variant<std::tuple<size_t, size_t, ov::Tensor>, std::monostate>> m_squeue;
6263

6364
std::atomic<CallbackStatus> m_status = CallbackStatus::RUNNING;
6465

6566
void _worker() {
6667
while (m_status == CallbackStatus::RUNNING) {
67-
// wait for queue pull
68-
auto [step, num_steps, latent] = m_squeue.pull();
69-
70-
if (m_callback(step, num_steps, latent)) {
71-
m_status = CallbackStatus::STOP;
72-
m_squeue.empty();
68+
auto item = m_squeue.pull();
69+
70+
if (auto callback_data = std::get_if<std::tuple<size_t, size_t, ov::Tensor>>(&item)) {
71+
auto& [step, num_steps, latent] = *callback_data;
72+
const auto should_stop = m_callback(step, num_steps, latent);
73+
74+
if (should_stop) {
75+
m_status = CallbackStatus::STOP;
76+
}
77+
} else if (std::get_if<std::monostate>(&item)) {
78+
break;
7379
}
7480
}
7581
}

src/cpp/src/synchronized_queue.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class SynchronizedQueue
2222

2323
T back() {
2424
std::unique_lock<std::mutex> lock(m_mutex);
25-
m_cv.wait(lock, [this]{return !m_queue.empty(); });
25+
m_cv.wait(lock, [this]{return !m_queue.empty();});
2626
return m_queue.back();
2727
}
2828

src/python/py_utils.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,13 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
373373
} else if (py::isinstance<ov::genai::Generator>(py_obj)) {
374374
return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
375375
} else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
376-
return py::cast<std::function<bool(size_t, size_t, ov::Tensor&)>>(py_obj);
376+
auto py_callback = py::cast<py::function>(py_obj);
377+
return std::function<bool(size_t, size_t, ov::Tensor&)>(
378+
[py_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
379+
py::gil_scoped_acquire acquire;
380+
return py_callback(step, num_steps, latent).cast<bool>();
381+
}
382+
);
377383
} else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
378384
auto streamer = py::cast<ov::genai::pybind::utils::PyBindStreamerVariant>(py_obj);
379385
return ov::genai::streamer(pystreamer_to_streamer(streamer)).second;

tests/python_tests/samples/test_inpainting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,23 @@
1111
download_mask_image = download_test_content
1212

1313
class TestInpainting:
14+
PROMPT = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
15+
IMAGE_PATH = "images/image.png"
16+
MASK_PATH = "mask_image.png"
17+
1418
@pytest.mark.samples
1519
@pytest.mark.LCM_Dreamshaper_v7_int8_ov
1620
@pytest.mark.parametrize(
1721
"download_model, prompt",
1822
[
19-
pytest.param("LCM_Dreamshaper_v7-int8-ov", "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"),
23+
pytest.param("LCM_Dreamshaper_v7-int8-ov", PROMPT),
2024
],
2125
indirect=["download_model"],
2226
)
2327
@pytest.mark.parametrize(
2428
"download_test_content, download_mask_image",
2529
[
26-
pytest.param("images/image.png", "mask_image.png"),
30+
pytest.param(IMAGE_PATH, MASK_PATH),
2731
],
2832
indirect=["download_test_content", "download_mask_image"],
2933
)

tests/python_tests/samples/test_text2image.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from test_utils import run_sample
1010

1111
class TestText2Image:
12+
PROMPT = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
13+
1214
@pytest.mark.samples
1315
@pytest.mark.dreamlike_anime_1_0
1416
@pytest.mark.parametrize(
1517
"convert_model, sample_args",
1618
[
17-
pytest.param("dreamlike-anime-1.0", "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"),
19+
pytest.param("dreamlike-anime-1.0", PROMPT),
1820
],
1921
indirect=["convert_model"],
2022
)
@@ -29,7 +31,6 @@ def test_sample_text2image(self, convert_model, sample_args):
2931
cpp_command = [cpp_sample, convert_model, sample_args]
3032
run_sample(cpp_command)
3133

32-
3334
@pytest.mark.samples
3435
@pytest.mark.dreamlike_anime_1_0
3536
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)