1717import modelengine .fel .community .model .openai .entity .embed .OpenAiEmbedding ;
1818import modelengine .fel .community .model .openai .entity .embed .OpenAiEmbeddingRequest ;
1919import modelengine .fel .community .model .openai .entity .embed .OpenAiEmbeddingResponse ;
20+ import modelengine .fel .community .model .openai .entity .image .OpenAiImageRequest ;
21+ import modelengine .fel .community .model .openai .entity .image .OpenAiImageResponse ;
22+ import modelengine .fel .community .model .openai .enums .ModelProcessingState ;
2023import modelengine .fel .community .model .openai .util .HttpUtils ;
2124import modelengine .fel .core .chat .ChatMessage ;
2225import modelengine .fel .core .chat .ChatModel ;
2326import modelengine .fel .core .chat .ChatOption ;
2427import modelengine .fel .core .chat .Prompt ;
28+ import modelengine .fel .core .chat .support .AiMessage ;
2529import modelengine .fel .core .embed .EmbedModel ;
2630import modelengine .fel .core .embed .EmbedOption ;
2731import modelengine .fel .core .embed .Embedding ;
32+ import modelengine .fel .core .image .ImageModel ;
33+ import modelengine .fel .core .image .ImageOption ;
34+ import modelengine .fel .core .model .http .SecureConfig ;
35+ import modelengine .fit .http .client .HttpClassicClient ;
2836import modelengine .fit .http .client .HttpClassicClientFactory ;
2937import modelengine .fit .http .client .HttpClassicClientRequest ;
3038import modelengine .fit .http .client .HttpClassicClientResponse ;
3139import modelengine .fit .http .entity .ObjectEntity ;
3240import modelengine .fit .http .protocol .HttpRequestMethod ;
41+ import modelengine .fit .security .Decryptor ;
3342import modelengine .fitframework .annotation .Component ;
43+ import modelengine .fitframework .annotation .Fit ;
44+ import modelengine .fitframework .conf .Config ;
3445import modelengine .fitframework .exception .FitException ;
3546import modelengine .fitframework .flowable .Choir ;
47+ import modelengine .fitframework .ioc .BeanContainer ;
48+ import modelengine .fitframework .ioc .BeanFactory ;
49+ import modelengine .fitframework .log .Logger ;
3650import modelengine .fitframework .resource .UrlUtils ;
51+ import modelengine .fitframework .resource .web .Media ;
3752import modelengine .fitframework .serialization .ObjectSerializer ;
3853import modelengine .fitframework .util .CollectionUtils ;
54+ import modelengine .fitframework .util .LazyLoader ;
55+ import modelengine .fitframework .util .MapBuilder ;
56+ import modelengine .fitframework .util .ObjectUtils ;
3957import modelengine .fitframework .util .StringUtils ;
4058
4159import java .io .IOException ;
60+ import java .util .HashMap ;
4261import java .util .List ;
62+ import java .util .Map ;
63+ import java .util .concurrent .atomic .AtomicReference ;
64+ import java .util .stream .Collectors ;
4365
4466/**
4567 * 表示 openai 模型服务。
4870 * @since 2024-08-07
4971 */
5072@ Component
51- public class OpenAiModel implements EmbedModel , ChatModel {
73+ public class OpenAiModel implements EmbedModel , ChatModel , ImageModel {
74+ private static final Logger log = Logger .get (OpenAiModel .class );
75+ private static final Map <String , Boolean > HTTPS_CONFIG_KEY_MAPS = MapBuilder .<String , Boolean >get ()
76+ .put ("client.http.secure.ignore-trust" , Boolean .FALSE )
77+ .put ("client.http.secure.ignore-hostname" , Boolean .FALSE )
78+ .put ("client.http.secure.trust-store-file" , Boolean .FALSE )
79+ .put ("client.http.secure.trust-store-password" , Boolean .TRUE )
80+ .put ("client.http.secure.key-store-file" , Boolean .FALSE )
81+ .put ("client.http.secure.key-store-password" , Boolean .TRUE )
82+ .build ();
83+ private static final String RESPONSE_TEMPLATE = "<think>{0}<//think>{1}" ;
84+
5285 private final HttpClassicClientFactory httpClientFactory ;
53- private final HttpClassicClientFactory .Config config ;
86+ private final HttpClassicClientFactory .Config clientConfig ;
5487 private final String baseUrl ;
5588 private final String defaultApiKey ;
5689 private final ObjectSerializer serializer ;
90+ private final Config config ;
91+ private final Decryptor decryptor ;
92+ private final LazyLoader <HttpClassicClient > httpClient ;
5793
5894 /**
5995 * 创建 openai 嵌入模型服务的实例。
6096 *
6197 * @param httpClientFactory 表示 http 客户端工厂的 {@link HttpClassicClientFactory}。
62- * @param config 表示 openai http 配置的 {@link OpenAiConfig}。
98+ * @param clientConfig 表示 openai http 配置的 {@link OpenAiConfig}。
6399 * @param serializer 表示对象序列化器的 {@link ObjectSerializer}。
100+ * @param config 表示配置信息的 {@link Config}。
101+ * @param container 表示 bean 容器的 {@link BeanContainer}。
64102 * @throws IllegalArgumentException 当 {@code httpClientFactory}、{@code config} 为 {@code null} 时。
65103 */
66- public OpenAiModel (HttpClassicClientFactory httpClientFactory , OpenAiConfig config , ObjectSerializer serializer ) {
67- notNull (config , "The config cannot be null." );
104+ public OpenAiModel (HttpClassicClientFactory httpClientFactory , OpenAiConfig clientConfig ,
105+ @ Fit (alias = "json" ) ObjectSerializer serializer , Config config , BeanContainer container ) {
106+ notNull (clientConfig , "The config cannot be null." );
68107 this .httpClientFactory = notNull (httpClientFactory , "The http client factory cannot be null." );
69- this .config = HttpClassicClientFactory .Config .builder ()
70- .connectTimeout (config .getConnectTimeout ())
71- .socketTimeout (config .getReadTimeout ())
108+ this .clientConfig = HttpClassicClientFactory .Config .builder ()
109+ .connectTimeout (clientConfig .getConnectTimeout ())
110+ .socketTimeout (clientConfig .getReadTimeout ())
72111 .build ();
73112 this .serializer = notNull (serializer , "The serializer cannot be null." );
74- this .baseUrl = config .getApiBase ();
75- this .defaultApiKey = config .getApiKey ();
113+ this .baseUrl = clientConfig .getApiBase ();
114+ this .defaultApiKey = clientConfig .getApiKey ();
115+ this .httpClient = new LazyLoader <>(this ::getHttpClient );
116+ this .config = config ;
117+ this .decryptor = container .lookup (Decryptor .class )
118+ .map (BeanFactory ::<Decryptor >get )
119+ .orElseGet (() -> encrypted -> encrypted );
76120 }
77121
78122 @ Override
79123 public List <Embedding > generate (List <String > inputs , EmbedOption option ) {
80124 notEmpty (inputs , "The input cannot be empty." );
81125 notNull (option , "The embed option cannot be null." );
82126 notBlank (option .model (), "The embed model name cannot be null." );
83- HttpClassicClientRequest request = this .httpClientFactory . create ( this . config )
127+ HttpClassicClientRequest request = this .httpClient . get ( )
84128 .createRequest (HttpRequestMethod .POST , UrlUtils .combine (this .baseUrl , OpenAiApi .EMBEDDING_ENDPOINT ));
85129 HttpUtils .setBearerAuth (request , StringUtils .blankIf (option .apiKey (), this .defaultApiKey ));
86130 request .jsonEntity (new OpenAiEmbeddingRequest (inputs , option .model ()));
@@ -98,19 +142,61 @@ public List<Embedding> generate(List<String> inputs, EmbedOption option) {
98142 public Choir <ChatMessage > generate (Prompt prompt , ChatOption chatOption ) {
99143 notNull (prompt , "The prompt cannot be null." );
100144 notNull (chatOption , "The chat option cannot be null." );
101- HttpClassicClientRequest request = this .httpClientFactory .create (this .config )
102- .createRequest (HttpRequestMethod .POST , UrlUtils .combine (this .baseUrl , OpenAiApi .CHAT_ENDPOINT ));
145+ String modelSource = StringUtils .blankIf (chatOption .baseUrl (), this .baseUrl );
146+ HttpClassicClientRequest request = this .getHttpClient (chatOption .secureConfig ())
147+ .createRequest (HttpRequestMethod .POST , UrlUtils .combine (modelSource , OpenAiApi .CHAT_ENDPOINT ));
103148 HttpUtils .setBearerAuth (request , StringUtils .blankIf (chatOption .apiKey (), this .defaultApiKey ));
104149 request .jsonEntity (new OpenAiChatCompletionRequest (prompt , chatOption ));
105150 return chatOption .stream () ? this .createChatStream (request ) : this .createChatCompletion (request );
106151 }
107152
153+ @ Override
154+ public List <Media > generate (String prompt , ImageOption option ) {
155+ notNull (prompt , "The prompt cannot be null." );
156+ notNull (option , "The image option cannot be null." );
157+ String modelSource = StringUtils .blankIf (option .baseUrl (), this .baseUrl );
158+ HttpClassicClientRequest request = this .httpClient .get ()
159+ .createRequest (HttpRequestMethod .POST , UrlUtils .combine (modelSource , OpenAiApi .IMAGE_ENDPOINT ));
160+ HttpUtils .setBearerAuth (request , StringUtils .blankIf (option .apiKey (), this .defaultApiKey ));
161+ request .jsonEntity (new OpenAiImageRequest (option .model (), option .size (), prompt ));
162+ Class <OpenAiImageResponse > clazz = OpenAiImageResponse .class ;
163+ try (HttpClassicClientResponse <OpenAiImageResponse > response = request .exchange (clazz )) {
164+ return response .objectEntity ()
165+ .map (entity -> entity .object ().media ())
166+ .orElseThrow (() -> new FitException ("The response body is abnormal." ));
167+ } catch (IOException e ) {
168+ throw new IllegalStateException ("Failed to close response." , e );
169+ }
170+ }
171+
108172 private Choir <ChatMessage > createChatStream (HttpClassicClientRequest request ) {
173+ AtomicReference <ModelProcessingState > modelProcessingState =
174+ new AtomicReference <>(ModelProcessingState .INITIAL );
109175 return request .<String >exchangeStream (String .class )
110176 .filter (str -> !StringUtils .equals (str , "[DONE]" ))
111177 .map (str -> this .serializer .<OpenAiChatCompletionResponse >deserialize (str ,
112178 OpenAiChatCompletionResponse .class ))
113- .map (OpenAiChatCompletionResponse ::message );
179+ .map (response -> getChatMessage (response , modelProcessingState ));
180+ }
181+
182+ private ChatMessage getChatMessage (OpenAiChatCompletionResponse response ,
183+ AtomicReference <ModelProcessingState > state ) {
184+ // 适配reasoning_content格式返回的模型推理内容,模型生成内容顺序为先reasoning_content后content
185+ // 在第一个reasoning_content chunk之前增加<think>标签,并且在第一个content chunk之前增加</think>标签
186+ if (state .get () == ModelProcessingState .INITIAL && StringUtils .isNotEmpty (response .reasoningContent ().text ())) {
187+ String text = "<think>" + response .reasoningContent ().text ();
188+ state .set (ModelProcessingState .THINKING );
189+ return new AiMessage (text , response .message ().toolCalls ());
190+ }
191+ if (state .get () == ModelProcessingState .THINKING && StringUtils .isNotEmpty (response .message ().text ())) {
192+ String text = "</think>" + response .message ().text ();
193+ state .set (ModelProcessingState .RESPONDING );
194+ return new AiMessage (text , response .message ().toolCalls ());
195+ }
196+ if (state .get () == ModelProcessingState .THINKING ) {
197+ return new AiMessage (response .reasoningContent ().text (), response .message ().toolCalls ());
198+ }
199+ return response .message ();
114200 }
115201
116202 private Choir <ChatMessage > createChatCompletion (HttpClassicClientRequest request ) {
@@ -119,9 +205,64 @@ private Choir<ChatMessage> createChatCompletion(HttpClassicClientRequest request
119205 OpenAiChatCompletionResponse chatCompletionResponse = response .objectEntity ()
120206 .map (ObjectEntity ::object )
121207 .orElseThrow (() -> new FitException ("The response body is abnormal." ));
122- return Choir .just (chatCompletionResponse .message ());
208+ String finalMessage = chatCompletionResponse .message ().text ();
209+ if (StringUtils .isNotBlank (chatCompletionResponse .reasoningContent ().text ())) {
210+ finalMessage = StringUtils .format (RESPONSE_TEMPLATE ,
211+ chatCompletionResponse .reasoningContent ().text (),
212+ finalMessage );
213+ }
214+ return Choir .just (new AiMessage (finalMessage , chatCompletionResponse .message ().toolCalls ()));
123215 } catch (IOException e ) {
124216 throw new FitException (e );
125217 }
126218 }
219+
220+ private HttpClassicClient getHttpClient () {
221+ Map <String , Object > custom = HTTPS_CONFIG_KEY_MAPS .keySet ()
222+ .stream ()
223+ .filter (sslKey -> this .config .keys ().contains (Config .canonicalizeKey (sslKey )))
224+ .collect (Collectors .toMap (sslKey -> sslKey , sslKey -> {
225+ Object value = this .config .get (sslKey , Object .class );
226+ if (HTTPS_CONFIG_KEY_MAPS .get (sslKey )) {
227+ value = this .decryptor .decrypt (ObjectUtils .cast (value ));
228+ }
229+ return value ;
230+ }));
231+
232+ return this .httpClientFactory .create (HttpClassicClientFactory .Config .builder ()
233+ .socketTimeout (this .clientConfig .socketTimeout ())
234+ .connectTimeout (this .clientConfig .connectTimeout ())
235+ .custom (custom )
236+ .build ());
237+ }
238+
239+ private HttpClassicClient getHttpClient (SecureConfig secureConfig ) {
240+ if (secureConfig == null ) {
241+ return getHttpClient ();
242+ }
243+
244+ Map <String , Object > custom = buildHttpsConfig (secureConfig );
245+ return this .httpClientFactory .create (HttpClassicClientFactory .Config .builder ()
246+ .socketTimeout (this .clientConfig .socketTimeout ())
247+ .connectTimeout (this .clientConfig .connectTimeout ())
248+ .custom (custom )
249+ .build ());
250+ }
251+
252+ private Map <String , Object > buildHttpsConfig (SecureConfig secureConfig ) {
253+ Map <String , Object > result = new HashMap <>();
254+ putConfigIfNotNull (secureConfig .ignoreTrust (), "client.http.secure.ignore-trust" , result );
255+ putConfigIfNotNull (secureConfig .ignoreHostName (), "client.http.secure.ignore-hostname" , result );
256+ putConfigIfNotNull (secureConfig .trustStoreFile (), "client.http.secure.trust-store-file" , result );
257+ putConfigIfNotNull (secureConfig .trustStorePassword (), "client.http.secure.trust-store-password" , result );
258+ putConfigIfNotNull (secureConfig .keyStoreFile (), "client.http.secure.key-store-file" , result );
259+ putConfigIfNotNull (secureConfig .keyStorePassword (), "client.http.secure.key-store-password" , result );
260+ return result ;
261+ }
262+
263+ private static void putConfigIfNotNull (Object value , String key , Map <String , Object > result ) {
264+ if (value != null ) {
265+ result .put (key , value );
266+ }
267+ }
127268}
0 commit comments