We ran into a perf corner case in Harmony's tokenizer. When tokenizing a long sequence of dashes (a regular ASCII dash `-`, nothing fancy), it exhibits O(N^2) behavior and gets really slow. For comparison, the HF tokenizer doesn't exhibit this issue.
Repro:
```python
from openai_harmony import load_harmony_encoding
import time
from transformers import AutoTokenizer

encoding = load_harmony_encoding("HarmonyGptOss")
encoding.encode("Hello")  # warm-up

hf_tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b")
hf_tokenizer.encode("Hello")  # warm-up

for n in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]:
    t = time.time()
    harmony_tokens = encoding.encode("Hello" + "-" * n)
    harmony_time = time.time() - t

    t = time.time()
    hf_tokens = hf_tokenizer.encode("Hello" + "-" * n)
    hf_time = time.time() - t

    assert harmony_tokens == hf_tokens, f"harmony_tokens: {harmony_tokens}, hf_tokens: {hf_tokens}"
    print(f"n={n:10d}, harmony_time={harmony_time:5.6f}, hf_time={hf_time:5.6f}")
```
```
n=         1, harmony_time=0.001993, hf_time=0.000117
n=         2, harmony_time=0.001793, hf_time=0.000313
n=         4, harmony_time=0.001720, hf_time=0.000055
n=         8, harmony_time=0.001693, hf_time=0.000046
n=        16, harmony_time=0.001701, hf_time=0.000046
n=        32, harmony_time=0.001707, hf_time=0.000049
n=        64, harmony_time=0.001692, hf_time=0.000050
n=       128, harmony_time=0.001722, hf_time=0.000094
n=       256, harmony_time=0.001751, hf_time=0.000116
n=       512, harmony_time=0.001789, hf_time=0.000209
n=      1024, harmony_time=0.002183, hf_time=0.000373
n=      2048, harmony_time=0.007312, hf_time=0.000666
n=      4096, harmony_time=0.031775, hf_time=0.001345
n=      8192, harmony_time=0.133721, hf_time=0.002831
n=     16384, harmony_time=0.561266, hf_time=0.006466
n=     32768, harmony_time=2.223819, hf_time=0.013950
```
O(N^2) is clearly visible.
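To make the quadratic trend explicit, here's a quick sanity check over the timings above (not part of the original run): for O(N^2) growth, the time should roughly quadruple with each doubling of n.

```python
# Sketch: estimate the empirical scaling exponent from adjacent harmony
# timings above; O(N^2) behavior should give exponents close to 2.
import math

timings = {4096: 0.031775, 8192: 0.133721, 16384: 0.561266, 32768: 2.223819}
ns = sorted(timings)
for a, b in zip(ns, ns[1:]):
    print(f"n={a} -> n={b}: time ratio = 2^{math.log2(timings[b] / timings[a]):.2f}")
# n=4096 -> n=8192: time ratio = 2^2.07
# n=8192 -> n=16384: time ratio = 2^2.07
# n=16384 -> n=32768: time ratio = 2^1.99
```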
I'm using the latest version:

```
$ pip show openai-harmony
Name: openai-harmony
Version: 0.0.4
Summary: OpenAI's response format for its open-weight model series gpt-oss
Home-page:
Author:
Author-email:
License:
```
I suspect it has something to do with very long tokens and the encoding being non-incremental. Watch how the token IDs change as the dash run grows:
```
n= 192 [7535, 7535, 7535]
n= 193 [7535, 7535, 7535, 12]
n= 194 [7535, 7535, 7535, 375]
n= 195 [7535, 7535, 7535, 10356]
n= 196 [7535, 7535, 7535, 518]
n= 197 [7535, 7535, 7535, 26067]
n= 198 [7535, 7535, 98699]
```
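(For reference, a sketch of the loop behind this output; the exact script wasn't included above, but it only needs the `encoding` object from the repro:)

```python
# Hypothetical reconstruction: print the token IDs for dash runs of
# increasing length, reusing `encoding` from the repro script above.
for n in range(192, 199):
    print(f"n={n:4d}", encoding.encode("-" * n))
```

The trailing token keeps changing as the run grows, which is what you'd expect if each call redoes the merges over the entire dash run instead of extending the previous result.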