diff --git a/src/cpp/include/openvino/genai/json_container.hpp b/src/cpp/include/openvino/genai/json_container.hpp index d109241eef..68368cebcc 100644 --- a/src/cpp/include/openvino/genai/json_container.hpp +++ b/src/cpp/include/openvino/genai/json_container.hpp @@ -80,6 +80,14 @@ class OPENVINO_GENAI_EXPORTS JsonContainer { */ static JsonContainer from_json_string(const std::string& json_str); + /** + * @brief Concatenate two JsonContainers. + * @param dst Destination JsonContainer to append to + * @param src Source JsonContainer to append from + * @throw ov::Exception if keys in both containers are not strings. + */ + static void concatenate(JsonContainer& dst, const JsonContainer& src); + /** * @brief Create JsonContainer as an empty JSON object. */ diff --git a/src/cpp/include/openvino/genai/parsers.hpp b/src/cpp/include/openvino/genai/parsers.hpp index 156d158aca..96187d043c 100644 --- a/src/cpp/include/openvino/genai/parsers.hpp +++ b/src/cpp/include/openvino/genai/parsers.hpp @@ -175,7 +175,7 @@ class OPENVINO_GENAI_EXPORTS IncrementalParser { * @return std::string Filtered text that should be added to the content */ virtual std::string parse( - JsonContainer& message, + JsonContainer& delta_message, std::string& delta_text, const std::optional>& delta_tokens = std::nullopt ) = 0; @@ -222,7 +222,7 @@ class OPENVINO_GENAI_EXPORTS ReasoningIncrementalParser : public IncrementalPars * @return std::string Filtered text with reasoning content processed according to configuration */ std::string parse( - JsonContainer& message, + JsonContainer& delta_message, std::string& delta_text, const std::optional>& delta_tokens = std::nullopt ) override; diff --git a/src/cpp/src/json_container.cpp b/src/cpp/src/json_container.cpp index 0ef004a8b6..3d027f321f 100644 --- a/src/cpp/src/json_container.cpp +++ b/src/cpp/src/json_container.cpp @@ -409,5 +409,29 @@ void* JsonContainer::_get_json_value_ptr() const { return m_impl->get_json_value_ptr(m_path, AccessMode::Read); } 
+void JsonContainer::concatenate(JsonContainer& dst, const JsonContainer& src) { + auto dst_ = static_cast<nlohmann::ordered_json*>(dst._get_json_value_ptr()); + auto src_ = static_cast<nlohmann::ordered_json*>(src._get_json_value_ptr()); + + for (auto it = src_->begin(); it != src_->end(); ++it) { + const auto& src_val = it.value(); + // Check if both values are of string type only if we need to concatenate them. + // Otherwise just write the source value to destination. Extra check is not needed. + + if (!dst_->contains(it.key())) { + (*dst_)[it.key()] = src_val; + continue; + } + + auto& dst_val = (*dst_)[it.key()]; + OPENVINO_ASSERT( + src_val.is_string() && dst_val.is_string(), + "JsonContainer concatenate supports only string concatenation for object values. " + "Key: '", it.key(), "', src_val type: '", src_val.type_name(), "', dst_val type: '", dst_val.type_name(), "'." + ); + dst_val = dst_val.get<std::string>() + src_val.get<std::string>(); + } +} + } // namespace genai } // namespace ov diff --git a/src/cpp/src/parsers.cpp b/src/cpp/src/parsers.cpp index d359fad2fc..9283347220 100644 --- a/src/cpp/src/parsers.cpp +++ b/src/cpp/src/parsers.cpp @@ -20,18 +20,6 @@ class ReasoningIncrementalParser::ReasoningParserImpl { bool m_think_tag_opened = false; std::string m_text_cache = ""; bool m_deactivated = false; - - /** - * @brief Ensure required fields exist in the message container. - */ - void ensure_message_fields(JsonContainer& message) { - if (!message.contains("reasoning_content")) { - message["reasoning_content"] = ""; - } - if (!message.contains("content")) { - message["content"] = ""; - } - } /** * @brief Find the longest suffix of text that is a prefix of the close tag. 
@@ -61,8 +49,8 @@ class ReasoningIncrementalParser::ReasoningParserImpl { void handle_complete_reasoning(JsonContainer& message, std::string_view txt_chunk, size_t open_idx, size_t close_idx, std::string& delta_text) { // Extract reasoning content between tags - message["reasoning_content"] = std::string(txt_chunk.substr(open_idx + m_open_tag.size(), - close_idx - (open_idx + m_open_tag.size()))); + message["reasoning_content"] = std::string(txt_chunk.substr(open_idx + m_open_tag.size(), close_idx - (open_idx + m_open_tag.size()))); + message["content"] = std::string(txt_chunk.substr(close_idx + m_close_tag.size())); if (!m_keep_original_content) { delta_text = std::string(txt_chunk.substr(close_idx + m_close_tag.size())); @@ -76,16 +64,16 @@ class ReasoningIncrementalParser::ReasoningParserImpl { /** * @brief Handle the case where only the open tag is found. */ - void handle_open_tag(JsonContainer& message, std::string& reason_str, - std::string_view txt_chunk, size_t open_idx, std::string& delta_text) { + void handle_open_tag(JsonContainer& delta_message, std::string_view txt_chunk, size_t open_idx, std::string& delta_text) { // Start accumulating reasoning content - reason_str.append(txt_chunk.substr(open_idx + m_open_tag.size())); - message["reasoning_content"] = std::move(reason_str); + delta_message["reasoning_content"] = std::string(txt_chunk.substr(open_idx + m_open_tag.size())); if (!m_keep_original_content) { delta_text.clear(); + } else { + delta_text = txt_chunk; } - + m_think_tag_opened = true; m_text_cache.clear(); } @@ -93,14 +81,20 @@ class ReasoningIncrementalParser::ReasoningParserImpl { /** * @brief Handle the case where the close tag is found. 
 */ - void handle_close_tag(JsonContainer& message, std::string& reason_str, - std::string_view txt_chunk, size_t close_idx, std::string& delta_text) { + void handle_close_tag(JsonContainer& delta_message, std::string_view txt_chunk, size_t close_idx, std::string& delta_text) { // Append text before close tag to reasoning content - reason_str.append(txt_chunk.substr(0, close_idx)); - message["reasoning_content"] = std::move(reason_str); + delta_message["reasoning_content"] = std::move(std::string(txt_chunk.substr(0, close_idx))); + auto content = std::string(txt_chunk.substr(close_idx + m_close_tag.size())); + delta_message["content"] = content; if (!m_keep_original_content) { - delta_text = std::string(txt_chunk.substr(close_idx + m_close_tag.size())); + // Despite the fact that we put txt_chunk into delta_text, it's correct. + // txt_chunk contains some cached parts from the previous calls that were not yet processed, + // and we kept them in cache until we decide what to do with them. Here we definitely know that those cached parts + // belonged to reasoning_content so we can discard them. + delta_text = content; + } else { + delta_text = txt_chunk; } m_text_cache.clear(); @@ -111,25 +105,27 @@ class ReasoningIncrementalParser::ReasoningParserImpl { /** * @brief Handle accumulating text while inside reasoning tags. 
*/ - void handle_inside_reasoning(JsonContainer& message, std::string& reason_str, - std::string_view txt_chunk, std::string& delta_text) { + void handle_inside_reasoning(JsonContainer& delta_message, std::string_view txt_chunk, std::string& delta_text) { // Find if the end of txt_chunk might be the start of a close tag const size_t num_chars_to_keep = find_close_tag_prefix_length(txt_chunk); + std::string reason_str; if (num_chars_to_keep > 0) { // Keep potential partial close tag in cache m_text_cache = std::string(txt_chunk.substr(txt_chunk.size() - num_chars_to_keep)); - reason_str.append(txt_chunk.substr(0, txt_chunk.size() - num_chars_to_keep)); + reason_str = txt_chunk.substr(0, txt_chunk.size() - num_chars_to_keep); + if (m_keep_original_content) { + delta_text = std::string(txt_chunk.substr(0, txt_chunk.size() - num_chars_to_keep)); + } } else { // No partial close tag, accumulate all text - reason_str.append(txt_chunk); + reason_str = txt_chunk; m_text_cache.clear(); } - + delta_message["reasoning_content"] = std::move(reason_str); if (!m_keep_original_content) { delta_text.clear(); } - message["reasoning_content"] = std::move(reason_str); } public: @@ -145,7 +141,7 @@ class ReasoningIncrementalParser::ReasoningParserImpl { m_close_tag(close_tag) {} std::string parse( - JsonContainer& message, + JsonContainer& delta_message, std::string& delta_text, const std::optional>& delta_tokens ) { @@ -157,13 +153,8 @@ class ReasoningIncrementalParser::ReasoningParserImpl { } m_first_run = false; - ensure_message_fields(message); - const std::string txt_chunk = m_text_cache + delta_text; - std::string reason_str; - if (message.contains("reasoning_content")) { - reason_str = std::move(message["reasoning_content"].get_string()); - } + std::string txt_chunk = m_text_cache + delta_text; // Cache find() results to avoid redundant searches const auto open_idx = txt_chunk.find(m_open_tag); @@ -175,20 +166,22 @@ class ReasoningIncrementalParser::ReasoningParserImpl { ? 
close_idx : std::string::npos; if (close_idx_after_open != std::string::npos) { - handle_complete_reasoning(message, txt_chunk, open_idx, close_idx_after_open, delta_text); + handle_complete_reasoning(delta_message, txt_chunk, open_idx, close_idx_after_open, delta_text); } else { - handle_open_tag(message, reason_str, txt_chunk, open_idx, delta_text); + handle_open_tag(delta_message, txt_chunk, open_idx, delta_text); } } else if (m_think_tag_opened && close_idx != std::string::npos) { - handle_close_tag(message, reason_str, txt_chunk, close_idx, delta_text); + handle_close_tag(delta_message, txt_chunk, close_idx, delta_text); } else if (m_think_tag_opened) { - handle_inside_reasoning(message, reason_str, txt_chunk, delta_text); + handle_inside_reasoning(delta_message, txt_chunk, delta_text); } else { // Think tag was not opened yet and not found in the current delta_text. // Accumulate text in the cache to detect if is split between several delta_text pieces. m_text_cache += delta_text; + // Intentionally clear delta_text: no delta content is returned to the user during this phase + // (we are waiting for the tag to be fully detected in the cache). 
+ delta_text.clear(); } - return delta_text; } @@ -207,11 +200,11 @@ ReasoningIncrementalParser::ReasoningIncrementalParser(bool expect_open_tag, boo ReasoningIncrementalParser::~ReasoningIncrementalParser() = default; std::string ReasoningIncrementalParser::parse( - JsonContainer& message, + JsonContainer& delta_message, std::string& delta_text, const std::optional>& delta_tokens ) { - return m_impl->parse(message, delta_text, delta_tokens); + return m_impl->parse(delta_message, delta_text, delta_tokens); } void ReasoningIncrementalParser::reset() { diff --git a/src/cpp/src/text_streamer.cpp b/src/cpp/src/text_streamer.cpp index 56a32fe35b..9edfb0ca6e 100644 --- a/src/cpp/src/text_streamer.cpp +++ b/src/cpp/src/text_streamer.cpp @@ -141,18 +141,21 @@ std::vector> m_parsers; JsonContainer m_parsed_message; TextParserStreamerImpl(std::vector> parsers) : m_parsers{parsers} {} + }; TextParserStreamer::TextParserStreamer(const Tokenizer& tokenizer, std::vector> parsers) : TextStreamer(tokenizer, [this](std::string s) -> CallbackTypeVariant { return this->write(s); - }), m_pimpl{std::make_unique(parsers)} {} + }), m_pimpl{std::make_unique(parsers)} { + m_pimpl->m_parsed_message["content"] = ""; + } -CallbackTypeVariant TextParserStreamer::write(std::string message) { +CallbackTypeVariant TextParserStreamer::write(std::string delta_text) { // When 'write' is called with string, it means new chunk of tokens is decoded into text auto flushed_tokens = std::vector(); - if (message.back() == '\n') { + if (delta_text.back() == '\n') { // Flush all tokens flushed_tokens.assign(m_tokens_cache.begin(), m_tokens_cache.end()); } else if (m_decoded_lengths.size() >= delay_n_tokens) { @@ -177,13 +180,19 @@ CallbackTypeVariant TextParserStreamer::write(std::string message) { } } + // Every time we start to cycle through iterative parsers we create a new delta_message. + // Parsers should neither delete fields nor rewrite; they should only append or add new fields. 
+ // The only field that is updated automatically is "content": delta_text is put there. + JsonContainer delta_message; // Iterate over all parsers and apply them to the message for (auto& parser: m_pimpl->m_parsers) { - message = parser->parse(m_pimpl->m_parsed_message, message, flushed_tokens); + delta_text = parser->parse(delta_message, delta_text, flushed_tokens); // Message can be modified inside parser, if parser for example extracted tool calling from message content - m_pimpl->m_parsed_message["content"] = m_pimpl->m_parsed_message["content"].get_string() + message; } - return write(m_pimpl->m_parsed_message); + delta_message["content"] = delta_text; + + JsonContainer::concatenate(m_pimpl->m_parsed_message, delta_message); + return write(delta_message); } JsonContainer TextParserStreamer::get_parsed_message() const { @@ -192,6 +201,7 @@ JsonContainer TextParserStreamer::get_parsed_message() const { void TextParserStreamer::reset() { m_pimpl->m_parsed_message = JsonContainer(); + m_pimpl->m_parsed_message["content"] = ""; for (auto& parser : m_pimpl->m_parsers) { parser->reset(); } diff --git a/tests/cpp/parser.cpp b/tests/cpp/parser.cpp index e4db4da3f2..af4be108c5 100644 --- a/tests/cpp/parser.cpp +++ b/tests/cpp/parser.cpp @@ -92,12 +92,13 @@ TEST_F(DeepSeekR1ReasoningParserTest, ReasoningContentAccumulatesAcrossCalls) { std::string ref_res = "First, I recognize that the question is asking for the sum of 2 and 1.\n\nI know that addition involves combining two numbers to find their total.\n\nStarting with 2, I add 1 to it.\n\n2 plus 1 equals 3.\n"; JsonContainer msg; - + JsonContainer accumulated_msg; for (int i = 1; i < input_stream.size(); i++) { std::string delta_text = input_stream[i]; delta_text = parser.parse(msg, delta_text); + JsonContainer::concatenate(accumulated_msg, msg); } - ASSERT_EQ(msg["reasoning_content"], ref_res); + ASSERT_EQ(accumulated_msg["reasoning_content"], ref_res); } TEST(ParserTest, test_custom_parser) { diff --git 
a/tests/python_tests/test_parsers.py b/tests/python_tests/test_parsers.py index 7828eb77a7..6cf8d6aa28 100644 --- a/tests/python_tests/test_parsers.py +++ b/tests/python_tests/test_parsers.py @@ -7,8 +7,20 @@ from openvino_genai import Tokenizer, IncrementalParser, Parser, TextParserStreamer, StreamingStatus, Llama3JsonToolParser, Phi4ReasoningParser, Phi4ReasoningIncrementalParser, DeepSeekR1ReasoningIncrementalParser, GenerationConfig, ReasoningIncrementalParser from transformers import AutoTokenizer import re +from io import StringIO +def concatenate_dicts(dst_dict, src_dict): + # keys that exist in both dictionaries + keys = set(dst_dict.keys()).intersection(set(src_dict.keys())) + for key in keys: + dst_dict[key] += src_dict[key] + + # keys that exist in src_dict but missing in dst_dict + missing_keys = set(src_dict.keys()) - set(dst_dict.keys()) + for key in missing_keys: + dst_dict[key] = src_dict[key] + @pytest.fixture(scope="module") def hf_ov_genai_models(request, tmp_path_factory): model_id = request.param @@ -23,6 +35,83 @@ def hf_ov_genai_models(request, tmp_path_factory): return hf_tokenizer, genai_tokenizer +@pytest.mark.parametrize( + "hf_ov_genai_models", + ["katuni4ka/tiny-random-phi3"], # this tokenizer is used as a stub only + indirect=True +) +def test_several_incremental_parsers(hf_ov_genai_models): + hf_tokenizer, genai_tokenizer = hf_ov_genai_models + + class CustomReasonParser(IncrementalParser): + thinking_started: bool = False + deactivated: bool = False + + def parse(self, message: dict, delta_text: str, delta_tokens = None) -> str: + + if self.deactivated: + return delta_text + + if not self.thinking_started and delta_text == '<think>': + self.thinking_started = True + elif self.thinking_started and delta_text != '</think>': + message["reasoning_content"] = delta_text + elif self.thinking_started and delta_text == '</think>': + self.deactivated = True + + return delta_text + + + class IncrementalToolParser(IncrementalParser): + started_tool_call: bool = 
False + accumulated_tool_call: StringIO = StringIO() + deactivated: bool = False + + def parse(self, delta_msg: dict, delta_text: str, delta_tokens = None) -> str: + if self.deactivated: + return delta_text + + if delta_text == '{' and not self.started_tool_call: + self.started_tool_call = True + self.accumulated_tool_call.write(delta_text) + + # If we don't want to keep the tool call in the resulting string + # delta_text = '' + elif self.started_tool_call and delta_text == '}': + self.started_tool_call = False + self.accumulated_tool_call.write(delta_text) + self.deactivated = True + delta_msg["tool_calls"] = [json.loads(self.accumulated_tool_call.getvalue())] + # If we don't want to keep the tool call in the resulting string + # delta_text = '' + elif self.started_tool_call: + self.accumulated_tool_call.write(delta_text) + + return delta_text + + + class CustomStreamer(TextParserStreamer): + def write(self, message): + print(message) + return StreamingStatus.RUNNING + + streamer = CustomStreamer(genai_tokenizer, parsers=[IncrementalToolParser(), CustomReasonParser()]) + + stream_string = ["Hello", "<think>", " ", "world", " ", "</think>", "!", "{", '"func_name": ', '"weather", ' '"location": "New York"', "}"] + think_content = " world " + # content = ''.join(stream_string).replace("<think>", "").replace("</think>", "") + content = ''.join(stream_string) + tool_call = {"func_name": "weather", "location": "New York"} + + for subword in stream_string: + streamer._write(subword) + + final_msg = streamer.get_parsed_message() + assert final_msg["reasoning_content"] == think_content + assert final_msg["content"] == content + assert final_msg["tool_calls"][0] == tool_call + + @pytest.mark.parametrize( "hf_ov_genai_models", ["katuni4ka/tiny-random-phi3"], # this tokenizer is used as a stub only indirect=True ) @@ -45,21 +134,25 @@ def test_incremental_phi4_reason_parser_1(hf_ov_genai_models, answer): stream_string = re.split(r"(\s+)", answer) + # manually accumulate content from streamer + content = StringIO() + class CustomStreamer(TextParserStreamer): def 
write(self, message): - msg.update(message) + nonlocal content + content.write(message["content"]) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[Phi4ReasoningIncrementalParser()]) - msg = {} for subword in stream_string: streamer._write(subword) think_content = answer.split("")[0].replace("", "") - content = answer - - assert msg['reasoning_content'] == think_content - assert msg['content'] == content + + msg = streamer.get_parsed_message() + assert msg["reasoning_content"] == think_content + assert msg["content"] == answer + assert msg["content"].endswith(content.getvalue()) @pytest.mark.parametrize( @@ -70,13 +163,13 @@ def write(self, message): def test_incremental_phi4_reason_integer_token_ids(hf_ov_genai_models): hf_tokenizer, genai_tokenizer = hf_ov_genai_models + accumulated_message = {} class CustomStreamer(TextParserStreamer): - def write(self, message): - msg.update(message) + def write(self, delta_message): + concatenate_dicts(accumulated_message, delta_message) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[Phi4ReasoningIncrementalParser()]) - msg = {} answer = "\nOkay, the user is asking for the answer to 2 + 1.\n\nThe answer to 2 + 1 is \boxed{3}." 
encoded_tokens = genai_tokenizer.encode(answer).input_ids.data.tolist()[0] for token in encoded_tokens: @@ -84,10 +177,12 @@ def write(self, message): streamer.end() think_content = answer.split("")[0].replace("", "") - content = answer - assert msg['reasoning_content'] == think_content - assert msg['content'] == content + msg = streamer.get_parsed_message() + assert msg["reasoning_content"] == think_content + assert msg["content"] == answer + assert accumulated_message["reasoning_content"] == think_content + assert answer.endswith(accumulated_message["content"]) @pytest.mark.parametrize( @@ -101,35 +196,30 @@ def test_incremental_integer_token_ids(hf_ov_genai_models): class CustomIncrementalParser(IncrementalParser): started_reasoning: bool = False - def parse(self, msg: dict, delta_text: str, delta_tokens = None) -> str: - if 'content' not in msg: - msg['content'] = '' - if 'reasoning_content' not in msg: - msg['reasoning_content'] = '' - + def parse(self, delta_message: dict, delta_text: str, delta_tokens = None) -> str: if 1 in delta_tokens and not self.started_reasoning: self.started_reasoning = True - msg['reasoning_content'] += delta_text + delta_message["reasoning_content"] = delta_text delta_text = '' elif 1 in delta_tokens and self.started_reasoning: self.started_reasoning = False delta_text = '' elif self.started_reasoning: - msg['reasoning_content'] += delta_text + delta_message["reasoning_content"] = delta_text delta_text = '' - + # # Here we are only collecting ordinary text, therefore leave delta_text unchanged. 
- # # msg['content'] += delta_text will happen under the hood + delta_message["content"] = delta_text # will happen under the hood return delta_text - + + accumulated_message = {} class CustomStreamer(TextParserStreamer): - def write(self, message): - msg.update(message) + def write(self, delta_message): + concatenate_dicts(accumulated_message, delta_message) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[CustomIncrementalParser()]) - msg = {} - # All closing tags , <|/inst|>, <|endoftext|>, ent. in tiny-random-phi3 add strange \x0c\x0c characters + # All closing tags , <|/inst|>, <|endoftext|>, etc. in tiny-random-phi3 add strange \x0c\x0c characters # so we avoid them in this test. answer = "\nOkay, the user is asking for the answer to 2 + 1.The answer to 2 + 1 is 3." encoded_tokens = genai_tokenizer.encode(answer, add_special_tokens=False).input_ids.data.tolist()[0] @@ -137,9 +227,9 @@ def write(self, message): for token in encoded_tokens: streamer._write([token]) streamer.end() - - assert msg['reasoning_content'] == "\nOkay, the user is asking for the answer to 2 + 1" - assert msg['content'] == " The answer to 2 + 1 is 3." + + assert accumulated_message["reasoning_content"] == "\nOkay, the user is asking for the answer to 2 + 1" + assert accumulated_message["content"] == " The answer to 2 + 1 is 3." 
@pytest.mark.parametrize( @@ -159,21 +249,26 @@ def test_incremental_phi4_reason_parser_2(hf_ov_genai_models, split_answer): # check that if thinking opening and closing tags are in the middle of the subword, it is still parsed correctly hf_tokenizer, genai_tokenizer = hf_ov_genai_models + msg_manual = {} class CustomStreamer(TextParserStreamer): def write(self, message): - msg.update(message) + # will be accumulated automatically inside streamer + concatenate_dicts(msg_manual, message) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[Phi4ReasoningIncrementalParser()]) - msg = {} for subword in split_answer: streamer._write(subword) think_content = (''.join(split_answer)).split("")[0].replace("", "") content = ''.join(split_answer) - assert msg['reasoning_content'] == think_content - assert msg['content'] == content + msg = streamer.get_parsed_message() + assert msg["reasoning_content"] == think_content + assert msg["content"].endswith(content) # since msg contains all accumulated content + assert msg_manual["reasoning_content"] == think_content + assert msg_manual["content"] == content + @pytest.mark.parametrize("answer", [ @@ -184,18 +279,16 @@ def test_incremental_phi4_reason_parser_nostreamer(answer): parser = Phi4ReasoningIncrementalParser() stream_string = re.split(r"(\s+)", answer) - msg = {} + + accumulated_message = {} for subword in stream_string: - parser.parse(msg, subword) - # When parser is called from streamer, it is expected that content is accumulated inside streamer. - # Here we are calling parser manually therefore we need to accumulate content manually. 
- msg['content'] += subword + delta_message = {} # msg when the first parser is called should be empty + parser.parse(delta_message, subword) + concatenate_dicts(accumulated_message, delta_message) think_content = answer.split("")[0].replace("", "") - content = answer - assert msg['reasoning_content'] == think_content - assert msg['content'] == content + assert accumulated_message["reasoning_content"] == think_content @pytest.mark.parametrize("keep_original_content", [True, False]) @@ -213,14 +306,14 @@ def test_reasoning_parser_cut_content(hf_ov_genai_models, answer, keep_original_ stream_string = re.split(r"(\s+)", answer) + msg = {} class CustomStreamer(TextParserStreamer): def write(self, message): - msg.update(message) + concatenate_dicts(msg, message) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[ReasoningIncrementalParser(expect_open_tag=True, keep_original_content=keep_original_content)]) num_runs = 2 - msg = {} for i in range(num_runs): if do_reset: streamer.reset() @@ -232,15 +325,14 @@ def write(self, message): if do_reset: # If has been reset, check that content is parsed correctly - assert msg['reasoning_content'] == think_content - assert msg['content'] == (answer if keep_original_content else "\n\nThe answer to 2 + 1 is \boxed{3}.") + assert msg["reasoning_content"] == think_content + assert msg["content"] == (answer if keep_original_content else "\n\nThe answer to 2 + 1 is \boxed{3}.") else: - # If has not been reset(), then content msg['content'] will continue to accumulate thinking parts from the next runs - assert msg['content'].find("") >= 0 + # If has not been reset(), then content msg["content"] will continue to accumulate thinking parts from the next runs + assert msg["content"].find("") >= 0 def test_incremental_deepseek_parser(): - msg = {} stream_string = [ "<|begin▁of▁sentence|>", "First", ",", " I", " recognize", " that", " the", " question", " is", " asking", " for", " the", " sum", " of", " ", "2", 
" and", " ", "1", ".\n\n", "I", " know", " that", " addition", @@ -254,14 +346,15 @@ def test_incremental_deepseek_parser(): full_str = ''.join(stream_string) think_content = full_str.split("")[0] - content = full_str.split("")[1] + delta_message = {} + accumulated_message = {} parser = DeepSeekR1ReasoningIncrementalParser() for subword in stream_string: - msg = parser.parse(msg, subword) + parser.parse(delta_message, subword) + concatenate_dicts(accumulated_message, delta_message) - assert msg['reasoning_content'] == think_content - assert msg['content'] == content + assert accumulated_message["reasoning_content"] == think_content @pytest.mark.parametrize( @@ -275,26 +368,21 @@ def test_custom_incremental_parser(hf_ov_genai_models): class CustomParser(IncrementalParser): main_part_started: bool = False - def parse(self, msg: dict, delta_text: str, delta_tokens = None) -> str: - if 'content' not in msg: - msg['content'] = '' - if 'main_text' not in msg: - msg['main_text'] = '' - + def parse(self, delta_message: dict, delta_text: str, delta_tokens = None) -> str: if not self.main_part_started and delta_text == '': self.main_part_started = True elif self.main_part_started and delta_text == '': self.main_part_started = False else: if self.main_part_started: - msg['main_text'] += delta_text - + delta_message["main_text"] = delta_text + delta_message["content"] = delta_text return delta_text - msg = {} + accumulated_message = {} class CustomStreamer(TextParserStreamer): - def write(self, message): - msg.update(message) + def write(self, delta_message): + concatenate_dicts(accumulated_message, delta_message) return StreamingStatus.RUNNING streamer = CustomStreamer(genai_tokenizer, parsers=[CustomParser()]) @@ -302,8 +390,7 @@ def write(self, message): for subword in stream_string: streamer._write(subword) - - assert msg['main_text'] == " world " + assert accumulated_message["main_text"] == " world " @pytest.mark.parametrize( @@ -321,7 +408,7 @@ def 
test_final_parser_llama_32_json(hf_ov_genai_models): parser = Llama3JsonToolParser() parser.parse(content_json) - assert content_json['tool_calls'][0] == json.loads(json_str) + assert content_json["tool_calls"][0] == json.loads(json_str) @pytest.mark.parametrize("model_id", ["microsoft/Phi-4-mini-reasoning"]) @@ -329,13 +416,12 @@ def test_final_parser_llama_32_json(hf_ov_genai_models): def test_custom_parser(tmp_path, model_id): models_path = download_and_convert_model(model_id, padding_side="left").models_path pipe = create_ov_pipeline(models_path) - tok = pipe.get_tokenizer() class CustomParser(Parser): def parse(self, msg: dict): content = None if 'content' in msg: - content = msg['content'] + content = msg["content"] if not content: return @@ -344,7 +430,7 @@ def parse(self, msg: dict): think_end = content.find("") if think_start != -1 and think_end != -1 and think_end > think_start: think_text = content[think_start + len(""):think_end].strip() - msg['reasoning_content'] = think_text + msg["reasoning_content"] = think_text parser = CustomParser() config = GenerationConfig() @@ -361,8 +447,8 @@ def parse(self, msg: dict): think_text = content[think_start + len(""):think_end].strip() assert 'reasoning_content' in res.parsed[0] - assert res.parsed[0]['reasoning_content'] != "" - assert res.parsed[0]['reasoning_content'] == think_text + assert res.parsed[0]["reasoning_content"] != "" + assert res.parsed[0]["reasoning_content"] == think_text @pytest.mark.parametrize("model_id", ["microsoft/Phi-4-mini-reasoning"]) @@ -388,8 +474,8 @@ def write(self, message): think_text = content[think_start + len(""):think_end] assert 'reasoning_content' in res.parsed[0] - assert res.parsed[0]['reasoning_content'] != "" - assert res.parsed[0]['reasoning_content'] == think_text + assert res.parsed[0]["reasoning_content"] != "" + assert res.parsed[0]["reasoning_content"] == think_text res_streamer_1 = pipe.generate([prompt], max_new_tokens=600, streamer=streamer) res_streamer_2 = 
pipe.generate([prompt], max_new_tokens=600, streamer=streamer)