Skip to content

Commit ef16d49

Browse files
committed
Fix for deadlock in python callback
1 parent 65b9356 commit ef16d49

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/python/py_utils.cpp

Lines changed: 31 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -374,10 +374,21 @@ ov::Any py_object_to_any(const py::object& py_obj, std::string property_name) {
374374
return py::cast<std::shared_ptr<ov::genai::Generator>>(py_obj);
375375
} else if (py::isinstance<py::function>(py_obj) && property_name == "callback") {
376376
auto py_callback = py::cast<py::function>(py_obj);
377+
auto shared_callback = std::shared_ptr<py::function>(
378+
new py::function(py_callback),
379+
[](py::function* f) {
380+
if (Py_IsInitialized()) {
381+
PyGILState_STATE gstate = PyGILState_Ensure();
382+
delete f;
383+
PyGILState_Release(gstate);
384+
}
385+
}
386+
);
387+
377388
return std::function<bool(size_t, size_t, ov::Tensor&)>(
378-
[py_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
389+
[shared_callback](size_t step, size_t num_steps, ov::Tensor& latent) -> bool {
379390
py::gil_scoped_acquire acquire;
380-
return py_callback(step, num_steps, latent).cast<bool>();
391+
return (*shared_callback)(step, num_steps, latent).cast<bool>();
381392
}
382393
);
383394
} else if ((py::isinstance<py::function>(py_obj) || py::isinstance<ov::genai::StreamerBase>(py_obj) || py::isinstance<std::monostate>(py_obj)) && property_name == "streamer") {
@@ -443,12 +454,25 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
443454

444455
std::visit(overloaded {
445456
[&streamer](const std::function<std::optional<uint16_t>(py::str)>& py_callback){
446-
// Wrap python streamer with manual utf-8 decoding. Do not rely
447-
// on pybind automatic decoding since it raises exceptions on incomplete strings.
448-
auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamingStatus {
457+
auto shared_callback = std::shared_ptr<std::function<std::optional<uint16_t>(py::str)>>(
458+
new std::function<std::optional<uint16_t>(py::str)>(py_callback),
459+
[](std::function<std::optional<uint16_t>(py::str)>* f) {
460+
if (Py_IsInitialized()) {
461+
PyGILState_STATE gstate = PyGILState_Ensure();
462+
delete f;
463+
PyGILState_Release(gstate);
464+
}
465+
}
466+
);
467+
468+
auto callback_wrapped = [shared_callback](std::string subword) -> ov::genai::StreamingStatus {
449469
py::gil_scoped_acquire acquire;
450-
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
451-
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
470+
PyObject* py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
471+
if (!py_str) {
472+
PyErr_Clear();
473+
return StreamingStatus::RUNNING;
474+
}
475+
std::optional<uint16_t> callback_output = (*shared_callback)(py::reinterpret_steal<py::str>(py_str));
452476
if (callback_output.has_value()) {
453477
if (*callback_output == (uint16_t)StreamingStatus::RUNNING)
454478
return StreamingStatus::RUNNING;

0 commit comments

Comments
 (0)