
Commit 5041b1d

sgonorovababushk authored and committed
Fix for deadlock in python callback
1 parent b7a5a80 commit 5041b1d

File tree

8 files changed: +84 −27 lines


samples/python/image_generation/image2image.py

Lines changed: 10 additions & 3 deletions
@@ -26,13 +26,20 @@ def main():
 
     image = read_image(args.image)
 
-    image_tensor = pipe.generate(args.prompt, image,
-        strength=0.8 # controls how initial image is noised after being converted to latent space. `1` means initial image is fully noised
+    def callback(step, num_steps, latent):
+        print(f"Step {step + 1}/{num_steps}")
+        return False
+
+    image_tensor = pipe.generate(
+        args.prompt,
+        image,
+        strength=0.8,
+        callback=callback
     )
 
     image = Image.fromarray(image_tensor.data[0])
     image.save("image.bmp")
 
 
-if '__main__' == __name__:
+if __name__ == '__main__':
     main()

samples/python/image_generation/inpainting.py

Lines changed: 6 additions & 2 deletions
@@ -28,11 +28,15 @@ def main():
     image = read_image(args.image)
     mask_image = read_image(args.mask)
 
-    image_tensor = pipe.generate(args.prompt, image, mask_image)
+    def callback(step, num_steps, latent):
+        print(f"Step {step + 1}/{num_steps}")
+        return False
+
+    image_tensor = pipe.generate(args.prompt, image, mask_image, callback=callback)
 
     image = Image.fromarray(image_tensor.data[0])
     image.save("image.bmp")
 
 
-if '__main__' == __name__:
+if __name__ == '__main__':
     main()

samples/python/image_generation/text2image.py

Lines changed: 7 additions & 2 deletions
@@ -17,16 +17,21 @@ def main():
     device = 'CPU' # GPU can be used as well
     pipe = openvino_genai.Text2ImagePipeline(args.model_dir, device)
 
+    def callback(step, num_steps, latent):
+        print(f"Step {step + 1}/{num_steps}")
+        return False
+
     image_tensor = pipe.generate(
         args.prompt,
         width=512,
         height=512,
         num_inference_steps=20,
-        num_images_per_prompt=1)
+        num_images_per_prompt=1,
+        callback=callback)
 
     image = Image.fromarray(image_tensor.data[0])
     image.save("image.bmp")
 
 
-if '__main__' == __name__:
+if __name__ == '__main__':
     main()

src/cpp/src/image_generation/threaded_callback.hpp

Lines changed: 15 additions & 9 deletions
@@ -4,6 +4,7 @@
 #pragma once
 
 #include <thread>
+#include <variant>
 
 #include "synchronized_queue.hpp"
 
@@ -33,7 +34,7 @@ class ThreadedCallbackWrapper {
             return CallbackStatus::STOP;
         }
 
-        m_squeue.push({step, num_steps, latent});
+        m_squeue.push(std::make_tuple(step, num_steps, latent));
 
         return CallbackStatus::RUNNING;
     }
@@ -44,7 +45,7 @@ class ThreadedCallbackWrapper {
         }
 
         m_status = CallbackStatus::STOP;
-        m_squeue.empty();
+        m_squeue.push(std::monostate());
 
         if (m_worker_thread && m_worker_thread->joinable()) {
            m_worker_thread->join();
@@ -58,18 +59,23 @@ class ThreadedCallbackWrapper {
 private:
     std::function<bool(size_t, size_t, ov::Tensor&)> m_callback = nullptr;
     std::shared_ptr<std::thread> m_worker_thread = nullptr;
-    SynchronizedQueue<std::tuple<size_t, size_t, ov::Tensor>> m_squeue;
+    SynchronizedQueue<std::variant<std::tuple<size_t, size_t, ov::Tensor>, std::monostate>> m_squeue;
 
     std::atomic<CallbackStatus> m_status = CallbackStatus::RUNNING;
 
     void _worker() {
         while (m_status == CallbackStatus::RUNNING) {
-            // wait for queue pull
-            auto [step, num_steps, latent] = m_squeue.pull();
-
-            if (m_callback(step, num_steps, latent)) {
-                m_status = CallbackStatus::STOP;
-                m_squeue.empty();
+            auto item = m_squeue.pull();
+
+            if (auto callback_data = std::get_if<std::tuple<size_t, size_t, ov::Tensor>>(&item)) {
+                auto& [step, num_steps, latent] = *callback_data;
+                const auto should_stop = m_callback(step, num_steps, latent);
+
+                if (should_stop) {
+                    m_status = CallbackStatus::STOP;
+                }
+            } else if (std::get_if<std::monostate>(&item)) {
+                break;
             }
         }
     }
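
Note: the deadlock this hunk removes came from the finalize path. Previously it set m_status to STOP and called m_squeue.empty(), which only reports whether the queue is empty and never wakes a worker blocked inside m_squeue.pull(), so the following join() could hang forever. The queue element is now a std::variant, and std::monostate is pushed as an explicit shutdown sentinel: the worker wakes, recognizes the sentinel, and leaves its loop. Below is a minimal, self-contained sketch of the same wake-up pattern; MiniQueue and the int payload are simplified stand-ins (the real SynchronizedQueue and the (step, num_steps, latent) tuple live in the files above), not the repo's actual code.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <variant>

// Stripped-down stand-in for SynchronizedQueue<T>: pull() blocks until push() is called.
template <typename T>
class MiniQueue {
public:
    void push(T value) {
        { std::lock_guard<std::mutex> lock(m_mutex); m_queue.push(std::move(value)); }
        m_cv.notify_one();
    }
    T pull() {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [this] { return !m_queue.empty(); });
        T value = std::move(m_queue.front());
        m_queue.pop();
        return value;
    }
private:
    std::queue<T> m_queue;
    std::mutex m_mutex;
    std::condition_variable m_cv;
};

int main() {
    // int stands in for the real callback payload; monostate is the shutdown sentinel.
    using Item = std::variant<int, std::monostate>;
    MiniQueue<Item> queue;

    std::thread worker([&queue] {
        for (;;) {
            Item item = queue.pull();                       // blocks until an item arrives
            if (std::get_if<std::monostate>(&item)) break;  // sentinel: exit the loop
            std::cout << "callback step " << std::get<int>(item) << "\n";
        }
    });

    queue.push(Item(1));
    queue.push(Item(2));
    queue.push(Item(std::monostate{}));  // wake the worker so join() below cannot deadlock
    worker.join();
}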

src/cpp/src/synchronized_queue.hpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class SynchronizedQueue
 
     T back() {
         std::unique_lock<std::mutex> lock(m_mutex);
-        m_cv.wait(lock, [this]{return !m_queue.empty(); });
+        m_cv.wait(lock, [this]{return !m_queue.empty();});
         return m_queue.back();
     }
 

src/python/py_utils.cpp

Lines changed: 36 additions & 6 deletions
@@ -373,7 +373,24 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
     } else if (py::isinstance<ov::genai::Generator>(py_obj)) {
         return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
     } else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
-        return py::cast<std::function<bool(size_t, size_t, ov::Tensor&)>>(py_obj);
+        auto py_callback = py::cast<py::function>(py_obj);
+        auto shared_callback = std::shared_ptr<py::function>(
+            new py::function(py_callback),
+            [](py::function* f) {
+                if (Py_IsInitialized()) {
+                    PyGILState_STATE gstate = PyGILState_Ensure();
+                    delete f;
+                    PyGILState_Release(gstate);
+                }
+            }
+        );
+
+        return std::function<bool(size_t, size_t, ov::Tensor&)>(
+            [shared_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
+                py::gil_scoped_acquire acquire;
+                return (*shared_callback)(step, num_steps, latent).cast<bool>();
+            }
+        );
     } else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
         auto streamer = py::cast<ov::genai::pybind::utils::PyBindStreamerVariant>(py_obj);
         return ov::genai::streamer(pystreamer_to_streamer(streamer)).second;
@@ -437,12 +454,25 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
 
     std::visit(overloaded {
         [&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
-            // Wrap python streamer with manual utf-8 decoding. Do not rely
-            // on pybind automatic decoding since it raises exceptions on incomplete strings.
-            auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamingStatus {
+            auto shared_callback = std::shared_ptr<std::function<std::optional<uint16_t>(py::str)>>(
+                new std::function<std::optional<uint16_t>(py::str)>(py_callback),
+                [](std::function<std::optional<uint16_t>(py::str)>* f) {
+                    if (Py_IsInitialized()) {
+                        PyGILState_STATE gstate = PyGILState_Ensure();
+                        delete f;
+                        PyGILState_Release(gstate);
+                    }
+                }
+            );
+
+            auto callback_wrapped = [shared_callback](std::string subword) -> ov::genai::StreamingStatus {
                 py::gil_scoped_acquire acquire;
-                auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
-                std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
+                PyObject* py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
+                if (!py_str) {
+                    PyErr_Clear();
+                    return StreamingStatus::RUNNING;
+                }
+                std::optional<uint16_t> callback_output = (*shared_callback)(py::reinterpret_steal<py::str>(py_str));
                 if (callback_output.has_value()) {
                     if (*callback_output == (uint16_t)StreamingStatus::RUNNING)
                         return StreamingStatus::RUNNING;
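
Note: the other half of the fix. The Python callback used to be stored directly in a std::function, so the captured py::function could be copied or destroyed on a C++ thread that does not hold the GIL, or during interpreter shutdown, which can deadlock or crash. The diff keeps the py::function behind a std::shared_ptr whose custom deleter acquires the GIL (and skips deletion once the interpreter is finalized), and the invoking lambda takes py::gil_scoped_acquire before calling into Python. A minimal sketch of the same wrapping outside the repo's py_object_to_any(); the function name is illustrative and the ov::Tensor argument is omitted to keep it self-contained.

#include <functional>
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Wrap a Python callable so it can be stored, called and destroyed safely from C++ threads.
std::function<bool(size_t, size_t)> wrap_python_callback(const py::function& py_callback) {
    // Keep the py::function on the heap with a deleter that takes the GIL before
    // dropping the reference; if the interpreter is already gone, skip deletion.
    auto shared_callback = std::shared_ptr<py::function>(
        new py::function(py_callback),
        [](py::function* f) {
            if (Py_IsInitialized()) {
                PyGILState_STATE gstate = PyGILState_Ensure();
                delete f;
                PyGILState_Release(gstate);
            }
        }
    );

    return [shared_callback](size_t step, size_t num_steps) -> bool {
        py::gil_scoped_acquire acquire;  // calling into Python always requires the GIL
        return (*shared_callback)(step, num_steps).cast<bool>();
    };
}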

tests/python_tests/samples/test_inpainting.py

Lines changed: 6 additions & 2 deletions
@@ -11,19 +11,23 @@
 download_mask_image = download_test_content
 
 class TestInpainting:
+    PROMPT = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+    IMAGE_PATH = "images/image.png"
+    MASK_PATH = "mask_image.png"
+
     @pytest.mark.samples
     @pytest.mark.LCM_Dreamshaper_v7_int8_ov
     @pytest.mark.parametrize(
         "download_model, prompt",
         [
-            pytest.param("LCM_Dreamshaper_v7-int8-ov", "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"),
+            pytest.param("LCM_Dreamshaper_v7-int8-ov", PROMPT),
         ],
         indirect=["download_model"],
     )
     @pytest.mark.parametrize(
         "download_test_content, download_mask_image",
         [
-            pytest.param("images/image.png", "mask_image.png"),
+            pytest.param(IMAGE_PATH, MASK_PATH),
         ],
         indirect=["download_test_content", "download_mask_image"],
     )

tests/python_tests/samples/test_text2image.py

Lines changed: 3 additions & 2 deletions
@@ -9,12 +9,14 @@
 from test_utils import run_sample
 
 class TestText2Image:
+    PROMPT = "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
+
     @pytest.mark.samples
     @pytest.mark.dreamlike_anime_1_0
     @pytest.mark.parametrize(
         "convert_model, sample_args",
         [
-            pytest.param("dreamlike-anime-1.0", "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"),
+            pytest.param("dreamlike-anime-1.0", PROMPT),
         ],
         indirect=["convert_model"],
     )
@@ -29,7 +31,6 @@ def test_sample_text2image(self, convert_model, sample_args):
         cpp_command = [cpp_sample, convert_model, sample_args]
         run_sample(cpp_command)
 
-
     @pytest.mark.samples
     @pytest.mark.dreamlike_anime_1_0
     @pytest.mark.parametrize(

0 commit comments
