Skip to content

Commit 6957d74

Browse files
committed
Fix for deadlock in python callback
1 parent eabc2de commit 6957d74

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

src/python/py_utils.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,22 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
374374
return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
375375
} else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
376376
auto py_callback = py::cast<py::function>(py_obj);
377+
auto shared_callback = std::shared_ptr<py::function>(
378+
new py::function(py_callback),
379+
[](py::function* f) {
380+
if (Py_IsInitialized()) {
381+
py::gil_scoped_acquire acquire;
382+
delete f;
383+
} else {
384+
delete f;
385+
}
386+
}
387+
);
388+
377389
return std::function<bool(size_t, size_t, ov::Tensor&)>(
378-
[py_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
390+
[shared_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
379391
py::gil_scoped_acquire acquire;
380-
return py_callback(step, num_steps, latent).cast<bool>();
392+
return (*shared_callback)(step, num_steps, latent).cast<bool>();
381393
}
382394
);
383395
} else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
@@ -443,21 +455,40 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
443455

444456
std::visit(overloaded {
445457
[&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
446-
// Wrap python streamer with manual utf-8 decoding. Do not rely
447-
// on pybind automatic decoding since it raises exceptions on incomplete strings.
448-
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamingStatus {
458+
auto shared_callback = std::shared_ptr<std::function<std::optional<uint16_t>(py::str)>>(
459+
new std::function<std::optional<uint16_t>(py::str)>(py_callback),
460+
[](std::function<std::optional<uint16_t>(py::str)>* f) {
461+
if (Py_IsInitialized()) {
462+
py::gil_scoped_acquire acquire;
463+
delete f;
464+
} else {
465+
delete f;
466+
}
467+
}
468+
);
469+
470+
auto callback_wrapped = [shared_callback](std::string subword) -> ov::genai::StreamingStatus {
449471
py::gil_scoped_acquire acquire;
450-
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
451-
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
472+
PyObject* py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
473+
if (!py_str) {
474+
PyErr_WriteUnraisable(nullptr);
475+
return StreamingStatus::RUNNING;
476+
}
477+
auto py_str_obj = py::reinterpret_steal<py::str>(py_str);
478+
std::optional<uint16_t> callback_output;
479+
try {
480+
callback_output = (*shared_callback)(py_str_obj);
481+
} catch (const py::error_already_set&) {
482+
return StreamingStatus::RUNNING;
483+
}
452484
if (callback_output.has_value()) {
453-
if (*callback_output == (uint16_t)StreamingStatus::RUNNING)
485+
if (*callback_output == static_cast<uint16_t>(StreamingStatus::RUNNING))
454486
return StreamingStatus::RUNNING;
455-
else if (*callback_output == (uint16_t)StreamingStatus::CANCEL)
487+
else if (*callback_output == static_cast<uint16_t>(StreamingStatus::CANCEL))
456488
return StreamingStatus::CANCEL;
457489
return StreamingStatus::STOP;
458-
} else {
459-
return StreamingStatus::RUNNING;
460490
}
491+
return StreamingStatus::RUNNING;
461492
};
462493
streamer = callback_wrapped;
463494
},

src/python/py_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace ov::genai::pybind::utils {
1919
// When StreamerVariant is used, UTF-8 decoding is done by pybind and can raise an exception on incomplete text.
2020
// Therefore string decoding should be handled with PyUnicode_DecodeUTF8(..., "replace") so that no errors are thrown.
2121
using PyBindStreamerVariant = std::variant<
22-
std::function<std::optional<uint16_t>(std::string)>,
22+
std::function<std::optional<uint16_t>(py::str)>,
2323
std::shared_ptr<StreamerBase>,
2424
std::monostate>;
2525

0 commit comments

Comments
 (0)