|
11 | 11 | #import "ExecuTorchLLMError.h" |
12 | 12 |
|
13 | 13 | #import <executorch/extension/llm/runner/text_llm_runner.h> |
| 14 | +#import <memory> |
14 | 15 |
|
15 | 16 | using namespace executorch::extension; |
16 | 17 | using namespace executorch::runtime; |
17 | 18 |
|
| 19 | +namespace { |
| 20 | + |
/// A streaming UTF-8 buffer that accumulates bytes until complete UTF-8
/// sequences are formed. This handles the case where BPE tokenizers output
/// partial multi-byte UTF-8 sequences across token boundaries.
///
/// For example, the Chinese character "清" (UTF-8: E6 B8 85) might be split
/// across two tokens: "æ¸" (E6 B8) and "ħ" (85). This buffer accumulates
/// bytes and only emits complete, well-formed UTF-8 sequences.
///
/// "Well-formed" follows the Unicode definition: besides the lead/continuation
/// bit patterns, overlong encodings, UTF-16 surrogates (U+D800..U+DFFF) and
/// code points above U+10FFFF are rejected, so that a strict decoder consuming
/// the output does not reject an entire chunk because of one bad sequence.
class UTF8StreamingBuffer {
public:
  UTF8StreamingBuffer() = default;

  /// Appends `token` to the pending bytes and returns every complete,
  /// well-formed UTF-8 sequence that can be emitted so far.
  ///
  /// Returns an empty string if more bytes are needed to complete a sequence.
  /// Invalid bytes are silently skipped. A truncated sequence is retained for
  /// the next call only while all of its bytes seen so far are valid; as soon
  /// as a wrong byte shows up the lead byte is dropped and scanning resumes at
  /// the following byte, so valid text is never held back behind garbage.
  std::string process(const std::string& token) {
    buffer_.append(token);
    return drain(/*atEnd=*/false);
  }

  /// Emits whatever well-formed sequences remain and clears the buffer.
  ///
  /// Called at the end of generation. A truncated trailing sequence can no
  /// longer be completed, so its bytes are discarded; any valid sequences
  /// located after an invalid or truncated lead byte are still returned.
  std::string flush() {
    return drain(/*atEnd=*/true);
  }

private:
  /// Bytes received but not yet emitted (at most one truncated sequence
  /// after process() returns).
  std::string buffer_;

  /// Shared scanner behind process() and flush().
  ///
  /// Walks `buffer_` from the front, copying well-formed sequences into the
  /// result and skipping invalid bytes. When `atEnd` is false, scanning stops
  /// at a truncated-but-so-far-valid trailing sequence, which stays buffered.
  /// When `atEnd` is true, such a sequence is discarded instead. Consumed
  /// bytes are removed from `buffer_` before returning.
  std::string drain(bool atEnd) {
    std::string result;
    const size_t size = buffer_.size();
    size_t i = 0;

    while (i < size) {
      const unsigned char lead = static_cast<unsigned char>(buffer_[i]);
      const size_t seqLen = utf8SequenceLength(lead);

      if (seqLen == 0) {
        // Lone continuation byte or illegal lead byte - skip it.
        ++i;
        continue;
      }

      // Validate the trailing bytes that are already present, even when the
      // sequence is still truncated, so a bad byte is detected immediately.
      const size_t remaining = size - i;
      const size_t available = remaining < seqLen ? remaining : seqLen;
      bool valid = true;
      for (size_t j = 1; j < available; ++j) {
        if (!isValidTrailByte(lead, j, static_cast<unsigned char>(buffer_[i + j]))) {
          valid = false;
          break;
        }
      }

      if (!valid) {
        // Skip only the lead byte and resynchronise on the next byte.
        ++i;
        continue;
      }

      if (available < seqLen) {
        if (!atEnd) {
          // Valid prefix of a sequence - keep it for the next call.
          break;
        }
        // No more input will arrive: drop the lead byte; the continuation
        // bytes that follow are skipped by the seqLen == 0 branch.
        ++i;
        continue;
      }

      result.append(buffer_, i, seqLen);
      i += seqLen;
    }

    // Drop everything consumed; only a pending truncated sequence remains.
    buffer_.erase(0, i);
    return result;
  }

  /// Returns the number of bytes expected for a UTF-8 sequence starting with
  /// the given byte. Returns 0 for invalid start bytes, including overlong
  /// encodings (0xC0, 0xC1) and out-of-range bytes (0xF5-0xFF).
  static size_t utf8SequenceLength(unsigned char byte) {
    if ((byte & 0x80) == 0x00) return 1;        // 0xxxxxxx - ASCII
    if (byte == 0xC0 || byte == 0xC1) return 0; // Overlong encoding - invalid
    if ((byte & 0xE0) == 0xC0) return 2;        // 110xxxxx
    if ((byte & 0xF0) == 0xE0) return 3;        // 1110xxxx
    if (byte >= 0xF5) return 0;                 // Out of Unicode range - invalid
    if ((byte & 0xF8) == 0xF0) return 4;        // 11110xxx
    return 0; // Continuation byte (10xxxxxx) or other invalid
  }

  /// Returns true if the byte is a valid UTF-8 continuation byte (10xxxxxx).
  static bool isUTF8Continuation(unsigned char byte) {
    return (byte & 0xC0) == 0x80;
  }

  /// Returns true if `byte` may appear at offset `index` (1-based from the
  /// lead byte) of a sequence starting with `lead`.
  ///
  /// The first trailing byte has a narrower range for four lead bytes:
  ///   E0 -> A0..BF (excludes overlong 3-byte forms)
  ///   ED -> 80..9F (excludes UTF-16 surrogates)
  ///   F0 -> 90..BF (excludes overlong 4-byte forms)
  ///   F4 -> 80..8F (excludes code points above U+10FFFF)
  /// Every other position accepts any continuation byte (80..BF).
  static bool isValidTrailByte(unsigned char lead, size_t index, unsigned char byte) {
    if (!isUTF8Continuation(byte)) {
      return false;
    }
    if (index != 1) {
      return true;
    }
    switch (lead) {
      case 0xE0: return byte >= 0xA0;
      case 0xED: return byte <= 0x9F;
      case 0xF0: return byte >= 0x90;
      case 0xF4: return byte <= 0x8F;
      default:   return true;
    }
  }
};
| 146 | + |
| 147 | +} // anonymous namespace |
| 148 | + |
18 | 149 | @interface ExecuTorchLLMConfig () |
19 | 150 |
|
20 | 151 | - (const llm::GenerationConfig &)nativeConfig; |
@@ -88,15 +219,47 @@ - (BOOL)generateWithPrompt:(NSString*)prompt |
88 | 219 | if (![self loadWithError:error]) { |
89 | 220 | return NO; |
90 | 221 | } |
| 222 | + |
| 223 | + // Create a UTF-8 streaming buffer to handle partial multi-byte sequences. |
| 224 | + // BPE tokenizers (especially ByteLevel like GPT-2/SmolLM) can output tokens |
| 225 | + // that split UTF-8 characters at byte boundaries. This buffer accumulates |
| 226 | + // bytes until complete UTF-8 sequences are formed before calling the callback. |
| 227 | + auto utf8Buffer = std::make_shared<UTF8StreamingBuffer>(); |
| 228 | + |
91 | 229 | auto status = _runner->generate( |
92 | 230 | prompt.UTF8String, |
93 | 231 | config.nativeConfig, |
94 | | - [callback](const std::string& token) { |
| 232 | + [callback, utf8Buffer](const std::string& token) { |
95 | 233 | if (callback) { |
96 | | - callback(@(token.c_str())); |
| 234 | + // Process token through UTF-8 buffer |
| 235 | + std::string validUTF8 = utf8Buffer->process(token); |
| 236 | + |
| 237 | + // Only call callback when we have complete UTF-8 sequences |
| 238 | + if (!validUTF8.empty()) { |
| 239 | + NSString *tokenString = [[NSString alloc] initWithBytes:validUTF8.data() |
| 240 | + length:validUTF8.size() |
| 241 | + encoding:NSUTF8StringEncoding]; |
| 242 | + if (tokenString) { |
| 243 | + callback(tokenString); |
| 244 | + } |
| 245 | + } |
97 | 246 | } |
98 | 247 | } |
99 | 248 | ); |
| 249 | + |
| 250 | + // Flush any remaining bytes in the buffer |
| 251 | + if (callback) { |
| 252 | + std::string remaining = utf8Buffer->flush(); |
| 253 | + if (!remaining.empty()) { |
| 254 | + NSString *remainingString = [[NSString alloc] initWithBytes:remaining.data() |
| 255 | + length:remaining.size() |
| 256 | + encoding:NSUTF8StringEncoding]; |
| 257 | + if (remainingString) { |
| 258 | + callback(remainingString); |
| 259 | + } |
| 260 | + } |
| 261 | + } |
| 262 | + |
100 | 263 | if (status != Error::Ok) { |
101 | 264 | if (error) { |
102 | 265 | *error = [NSError errorWithDomain:ExecuTorchLLMErrorDomain |
|
0 commit comments