Skip to content

Commit 2be242f

Browse files
committed
write chunks in TextParserStreamer
1 parent ffa946a commit 2be242f

File tree

3 files changed

+45
-44
lines changed

3 files changed

+45
-44
lines changed

src/cpp/src/parsers.cpp

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,10 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
5959
* @brief Handle the case where both open and close tags are found in the same chunk.
6060
*/
6161
void handle_complete_reasoning(JsonContainer& message, std::string_view txt_chunk,
62-
size_t open_idx, size_t close_idx, std::string& delta_text) {
62+
size_t open_idx, size_t close_idx) {
6363
// Extract reasoning content between tags
64-
message["reasoning_content"] = std::string(txt_chunk.substr(open_idx + m_open_tag.size(),
65-
close_idx - (open_idx + m_open_tag.size())));
66-
67-
if (!m_keep_original_content) {
68-
delta_text = std::string(txt_chunk.substr(close_idx + m_close_tag.size()));
69-
}
64+
message["reasoning_content"] = std::string(txt_chunk.substr(open_idx + m_open_tag.size(), close_idx - (open_idx + m_open_tag.size())));
65+
message["content"] = std::string(txt_chunk.substr(close_idx + m_close_tag.size()));
7066

7167
m_think_tag_opened = false;
7268
m_deactivated = true;
@@ -77,31 +73,22 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
7773
* @brief Handle the case where only the open tag is found.
7874
*/
7975
void handle_open_tag(JsonContainer& message, std::string& reason_str,
80-
std::string_view txt_chunk, size_t open_idx, std::string& delta_text) {
76+
std::string_view txt_chunk, size_t open_idx) {
8177
// Start accumulating reasoning content
8278
reason_str.append(txt_chunk.substr(open_idx + m_open_tag.size()));
8379
message["reasoning_content"] = std::move(reason_str);
84-
85-
if (!m_keep_original_content) {
86-
delta_text.clear();
87-
}
88-
80+
8981
m_think_tag_opened = true;
9082
m_text_cache.clear();
9183
}
9284

9385
/**
9486
* @brief Handle the case where the close tag is found.
9587
*/
96-
void handle_close_tag(JsonContainer& message, std::string& reason_str,
97-
std::string_view txt_chunk, size_t close_idx, std::string& delta_text) {
88+
void handle_close_tag(JsonContainer& message, std::string_view txt_chunk, size_t close_idx) {
9889
// Append text before close tag to reasoning content
99-
reason_str.append(txt_chunk.substr(0, close_idx));
100-
message["reasoning_content"] = std::move(reason_str);
101-
102-
if (!m_keep_original_content) {
103-
delta_text = std::string(txt_chunk.substr(close_idx + m_close_tag.size()));
104-
}
90+
message["reasoning_content"] = std::move(std::string(txt_chunk.substr(0, close_idx)));
91+
message["content"] = std::string(txt_chunk.substr(close_idx + m_close_tag.size()));
10592

10693
m_text_cache.clear();
10794
m_think_tag_opened = false;
@@ -111,8 +98,7 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
11198
/**
11299
* @brief Handle accumulating text while inside reasoning tags.
113100
*/
114-
void handle_inside_reasoning(JsonContainer& message, std::string& reason_str,
115-
std::string_view txt_chunk, std::string& delta_text) {
101+
void handle_inside_reasoning(JsonContainer& message, std::string& reason_str, std::string_view txt_chunk) {
116102
// Find if the end of txt_chunk might be the start of a close tag
117103
const size_t num_chars_to_keep = find_close_tag_prefix_length(txt_chunk);
118104

@@ -126,9 +112,6 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
126112
m_text_cache.clear();
127113
}
128114

129-
if (!m_keep_original_content) {
130-
delta_text.clear();
131-
}
132115
message["reasoning_content"] = std::move(reason_str);
133116
}
134117

@@ -150,6 +133,7 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
150133
const std::optional<std::vector<int64_t>>& delta_tokens
151134
) {
152135
if (m_deactivated) {
136+
message["content"] = delta_text;
153137
return delta_text;
154138
}
155139
if (!m_expect_open_tag && m_first_run) {
@@ -160,10 +144,7 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
160144
ensure_message_fields(message);
161145

162146
const std::string txt_chunk = m_text_cache + delta_text;
163-
std::string reason_str;
164-
if (message.contains("reasoning_content")) {
165-
reason_str = std::move(message["reasoning_content"].get_string());
166-
}
147+
std::string reason_str = std::move(message["reasoning_content"].get_string());
167148

168149
// Cache find() results to avoid redundant searches
169150
const auto open_idx = txt_chunk.find(m_open_tag);
@@ -175,14 +156,14 @@ class ReasoningIncrementalParser::ReasoningParserImpl {
175156
? close_idx : std::string::npos;
176157

177158
if (close_idx_after_open != std::string::npos) {
178-
handle_complete_reasoning(message, txt_chunk, open_idx, close_idx_after_open, delta_text);
159+
handle_complete_reasoning(message, txt_chunk, open_idx, close_idx_after_open);
179160
} else {
180-
handle_open_tag(message, reason_str, txt_chunk, open_idx, delta_text);
161+
handle_open_tag(message, reason_str, txt_chunk, open_idx);
181162
}
182163
} else if (m_think_tag_opened && close_idx != std::string::npos) {
183-
handle_close_tag(message, reason_str, txt_chunk, close_idx, delta_text);
164+
handle_close_tag(message, txt_chunk, close_idx);
184165
} else if (m_think_tag_opened) {
185-
handle_inside_reasoning(message, reason_str, txt_chunk, delta_text);
166+
handle_inside_reasoning(message, reason_str, txt_chunk);
186167
} else {
187168
// Think tag was not opened yet and not found in the current delta_text.
188169
// Accumulate text in the cache to detect if <think> is split between several delta_text pieces.

src/cpp/src/text_streamer.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,23 @@ std::vector<std::shared_ptr<IncrementalParser>> m_parsers;
141141
JsonContainer m_parsed_message;
142142

143143
TextParserStreamerImpl(std::vector<std::shared_ptr<IncrementalParser>> parsers) : m_parsers{parsers} {}
144+
144145
};
145146

147+
void concatenate_json_containers(JsonContainer& from, const JsonContainer& to, std::vector<std::string> keys_to_concatenate) {
148+
for (const auto& key : keys_to_concatenate) {
149+
if (to.contains(key) && from.contains(key)) {
150+
// If both are strings, concatenate
151+
if (to[key].is_string() && from[key].is_string()) {
152+
to[key] = to[key].get_string() + from[key].get_string();
153+
}
154+
} else if (from.contains(key)) {
155+
auto r = from[key];
156+
to[key] = from[key];
157+
}
158+
}
159+
}
160+
146161
TextParserStreamer::TextParserStreamer(const Tokenizer& tokenizer, std::vector<std::shared_ptr<IncrementalParser>> parsers)
147162
: TextStreamer(tokenizer, [this](std::string s) -> CallbackTypeVariant {
148163
return this->write(s);
@@ -177,13 +192,18 @@ CallbackTypeVariant TextParserStreamer::write(std::string message) {
177192
}
178193
}
179194

195+
JsonContainer msg;
180196
// Iterate over all parsers and apply them to the message
181197
for (auto& parser: m_pimpl->m_parsers) {
182-
message = parser->parse(m_pimpl->m_parsed_message, message, flushed_tokens);
198+
message = parser->parse(msg, message, flushed_tokens);
183199
// Message can be modified inside parser, if parser for example extracted tool calling from message content
184-
m_pimpl->m_parsed_message["content"] = m_pimpl->m_parsed_message["content"].get_string() + message;
185200
}
186-
return write(m_pimpl->m_parsed_message);
201+
202+
// concatenate msg with m_parsed_message
203+
concatenate_json_containers(msg, m_pimpl->m_parsed_message, {"content", "reasoning_content"});
204+
205+
// return write(m_pimpl->m_parsed_message);
206+
return write(msg);
187207
}
188208

189209
JsonContainer TextParserStreamer::get_parsed_message() const {

tests/python_tests/test_parsers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from utils.hugging_face import convert_and_save_tokenizer, download_and_convert_model
55
from utils.ov_genai_pipelines import create_ov_pipeline
66
import pytest
7-
from openvino_genai import Tokenizer, IncrementalParser, Parser, TextParserStreamer, StreamingStatus, Llama3JsonToolParser, Phi4ReasoningParser, Phi4ReasoningIncrementalParser, DeepSeekR1ReasoningIncrementalParser, GenerationConfig, ReasoningIncrementalParser
7+
from openvino_genai import Tokenizer, IncrementalParser, Parser, TextParserStreamer, StreamingStatus, Llama3JsonToolParser, Phi4ReasoningParser, ReasoningParser, Phi4ReasoningIncrementalParser, DeepSeekR1ReasoningIncrementalParser, GenerationConfig, ReasoningIncrementalParser
88
from transformers import AutoTokenizer
99
import re
1010

@@ -51,13 +51,13 @@ def write(self, message):
5151
return StreamingStatus.RUNNING
5252
streamer = CustomStreamer(genai_tokenizer, parsers=[Phi4ReasoningIncrementalParser()])
5353

54-
msg = {}
5554
for subword in stream_string:
5655
streamer._write(subword)
5756

5857
think_content = answer.split("</think>")[0].replace("<think>", "")
5958
content = answer
60-
59+
60+
msg = streamer.get_parsed_message()
6161
assert msg['reasoning_content'] == think_content
6262
assert msg['content'] == content
6363

@@ -161,17 +161,17 @@ def test_incremental_phi4_reason_parser_2(hf_ov_genai_models, split_answer):
161161

162162
class CustomStreamer(TextParserStreamer):
163163
def write(self, message):
164-
msg.update(message)
164+
# will be accumulated automatically inside streamer
165165
return StreamingStatus.RUNNING
166166
streamer = CustomStreamer(genai_tokenizer, parsers=[Phi4ReasoningIncrementalParser()])
167167

168-
msg = {}
169168
for subword in split_answer:
170169
streamer._write(subword)
171170

172171
think_content = (''.join(split_answer)).split("</think>")[0].replace("<think>", "")
173-
content = ''.join(split_answer)
172+
content = (''.join(split_answer).split("</think>")[1])
174173

174+
msg = streamer.get_parsed_message()
175175
assert msg['reasoning_content'] == think_content
176176
assert msg['content'] == content
177177

@@ -378,7 +378,7 @@ def write(self, message):
378378
streamer = CustomStreamer(tok, parsers=[Phi4ReasoningIncrementalParser()])
379379

380380
prompt = "Please say \"hello\""
381-
res = pipe.generate([prompt], max_new_tokens=600, parsers=[Phi4ReasoningParser()])
381+
res = pipe.generate([prompt], max_new_tokens=600, parsers=[ReasoningParser(keep_original_content=False)])
382382

383383
# extract manually reasoning content from the parsed result
384384
content = res.texts[0]

0 commit comments

Comments (0)