Skip to content

Commit 3ee354b

Browse files
committed
feat(util):增加计算token消耗的工具类
1 parent 889004c commit 3ee354b

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package com.ashin.util;
2+
3+
import com.knuddels.jtokkit.Encodings;
4+
import com.knuddels.jtokkit.api.Encoding;
5+
import com.knuddels.jtokkit.api.EncodingRegistry;
6+
import com.knuddels.jtokkit.api.ModelType;
7+
import com.theokanning.openai.completion.chat.ChatMessage;
8+
import lombok.var;
9+
import org.springframework.stereotype.Component;
10+
11+
import java.util.List;
12+
13+
/**
14+
* 分词器
15+
*
16+
* @author ashinnotfound
17+
* @date 2023/08/06
18+
*/
19+
@Component
20+
public class Tokenizer {
21+
22+
private final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
23+
24+
/**
25+
* 计算消息token
26+
* via https://jtokkit.knuddels.de/docs/getting-started/recipes/chatml
27+
*
28+
* @param model 模型
29+
* @param messages 消息
30+
* @return int
31+
*/
32+
public int countMessageTokens(ModelType model, List<ChatMessage> messages) {
33+
Encoding encoding = registry.getEncodingForModel(model);
34+
int tokensPerMessage = 0;
35+
if (model.getName().startsWith("gpt-4")) {
36+
tokensPerMessage = 3;
37+
} else if (model.getName().startsWith("gpt-3.5-turbo")) {
38+
tokensPerMessage = 4; // every message follows <|start|>{role/name}\n{content}<|end|>\n
39+
}
40+
41+
int sum = 0;
42+
for (final var message : messages) {
43+
sum += tokensPerMessage;
44+
sum += encoding.countTokens(message.getContent());
45+
sum += encoding.countTokens(message.getRole());
46+
}
47+
48+
sum += 3; // every reply is primed with <|start|>assistant<|message|>
49+
50+
return sum;
51+
}
52+
53+
public int countMessageTokens(String modelName, List<ChatMessage> messages) {
54+
return countMessageTokens(getModelTypeByName(modelName), messages);
55+
}
56+
57+
/**
58+
* 根据名字获取模型类型
59+
*
60+
* @param modelName 模型名称
61+
* @return {@code ModelType}
62+
*/
63+
public ModelType getModelTypeByName(String modelName){
64+
if (ModelType.GPT_4.getName().equals(modelName)){
65+
return ModelType.GPT_4;
66+
} else if (ModelType.GPT_4_32K.getName().equals(modelName)){
67+
return ModelType.GPT_4_32K;
68+
} else if (ModelType.GPT_3_5_TURBO_16K.getName().equals(modelName)){
69+
return ModelType.GPT_3_5_TURBO_16K;
70+
} else {
71+
return ModelType.GPT_3_5_TURBO;
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)