Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.internal.OpenAiUtils.aiMessageFrom;
import static dev.langchain4j.model.openai.internal.OpenAiUtils.finishReasonFrom;
import static dev.langchain4j.model.openai.internal.OpenAiUtils.toFunctions;
Expand All @@ -13,9 +14,11 @@
import java.net.Proxy;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.jboss.logging.Logger;
Expand All @@ -24,6 +27,7 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.ModelProvider;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
Expand Down Expand Up @@ -74,6 +78,7 @@ public class AzureOpenAiChatModel implements ChatModel {
private final TokenCountEstimator tokenizer;
private final ResponseFormat responseFormat;
private final List<ChatModelListener> listeners;
private final Set<Capability> supportedCapabilities;

public AzureOpenAiChatModel(String endpoint,
String apiVersion,
Expand Down Expand Up @@ -128,9 +133,14 @@ public AzureOpenAiChatModel(String endpoint,
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();

// Azure OpenAI supports JSON schema for models like gpt-4o-2024-08-06+
this.supportedCapabilities = new HashSet<>();
if (this.responseFormat != null && ResponseFormatType.JSON_SCHEMA.equals(this.responseFormat.type())) {
this.supportedCapabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
}
}

@Override
public ChatResponse doChat(ChatRequest chatRequest) {
List<ChatMessage> messages = chatRequest.messages();
List<ToolSpecification> toolSpecifications = chatRequest.toolSpecifications();
Expand All @@ -143,7 +153,7 @@ public ChatResponse doChat(ChatRequest chatRequest) {
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.responseFormat(responseFormat);
.responseFormat(this.responseFormat);

if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.functions(toFunctions(toolSpecifications));
Expand Down Expand Up @@ -239,6 +249,11 @@ private ChatResponse createModelListenerResponse(String responseId,
.build();
}

/**
 * Returns the set of capabilities supported by this model instance.
 * <p>
 * Contains {@code RESPONSE_FORMAT_JSON_SCHEMA} only when this model was configured
 * with a JSON-schema response format in the constructor; otherwise the set is empty.
 * The set is wrapped in an unmodifiable view so callers cannot mutate the model's
 * internal state (the backing field is a mutable {@code HashSet}).
 */
@Override
public Set<Capability> supportedCapabilities() {
    return Collections.unmodifiableSet(supportedCapabilities);
}

/**
 * Creates a new {@link Builder} for fluently configuring and constructing
 * an instance of this model.
 *
 * @return a fresh, unconfigured builder
 */
public static Builder builder() {
    return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.internal.OpenAiUtils.toFunctions;
import static dev.langchain4j.model.openai.internal.OpenAiUtils.toOpenAiMessages;
import static io.quarkiverse.langchain4j.azure.openai.Consts.DEFAULT_USER_AGENT;
Expand All @@ -12,9 +13,11 @@
import java.net.Proxy;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

Expand All @@ -25,6 +28,7 @@
import dev.langchain4j.model.ModelProvider;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
Expand Down Expand Up @@ -78,6 +82,7 @@ public class AzureOpenAiStreamingChatModel implements StreamingChatModel {
private final TokenCountEstimator tokenizer;
private final ResponseFormat responseFormat;
private final List<ChatModelListener> listeners;
private final Set<Capability> supportedCapabilities;

public AzureOpenAiStreamingChatModel(String endpoint,
String apiVersion,
Expand Down Expand Up @@ -124,11 +129,16 @@ public AzureOpenAiStreamingChatModel(String endpoint,
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();

// Azure OpenAI supports JSON schema for models like gpt-4o-2024-08-06+
this.supportedCapabilities = new HashSet<>();
if (this.responseFormat != null && ResponseFormatType.JSON_SCHEMA.equals(this.responseFormat.type())) {
this.supportedCapabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
}
}

@Override
public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {

List<ChatMessage> messages = chatRequest.messages();
List<ToolSpecification> toolSpecifications = chatRequest.toolSpecifications();

Expand All @@ -140,7 +150,7 @@ public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.responseFormat(responseFormat);
.responseFormat(this.responseFormat);

Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);

Expand Down Expand Up @@ -272,6 +282,11 @@ private ChatResponse createModelListenerResponse(String responseId,
.build();
}

/**
 * Returns the set of capabilities supported by this streaming model instance.
 * <p>
 * Contains {@code RESPONSE_FORMAT_JSON_SCHEMA} only when this model was configured
 * with a JSON-schema response format in the constructor; otherwise the set is empty.
 * The set is wrapped in an unmodifiable view so callers cannot mutate the model's
 * internal state (the backing field is a mutable {@code HashSet}).
 */
@Override
public Set<Capability> supportedCapabilities() {
    return Collections.unmodifiableSet(supportedCapabilities);
}

/**
 * Creates a new {@link Builder} for fluently configuring and constructing
 * an instance of this model.
 *
 * @return a fresh, unconfigured builder
 */
public static Builder builder() {
    return new Builder();
}
Expand Down