Skip to content

Commit e8e0209

Browse files
authored
[AI Agents Extension] Add models to service config (#6042)
* Add models to service config Signed-off-by: trangevi <[email protected]> * Add env var set during preprovision Signed-off-by: trangevi <[email protected]> * PR comments Signed-off-by: trangevi <[email protected]> * Move containerapp handling Signed-off-by: trangevi <[email protected]> --------- Signed-off-by: trangevi <[email protected]>
1 parent 1e59052 commit e8e0209

File tree

4 files changed

+285
-156
lines changed

4 files changed

+285
-156
lines changed

cli/azd/extensions/azure.ai.agents/internal/cmd/init.go

Lines changed: 41 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"azureaiagent/internal/pkg/agents/agent_yaml"
1818
"azureaiagent/internal/pkg/agents/registry_api"
1919
"azureaiagent/internal/pkg/azure/ai"
20+
"azureaiagent/internal/project"
2021

2122
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2223
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -28,6 +29,7 @@ import (
2829
"github.com/azure/azure-dev/cli/azd/pkg/ux"
2930
"github.com/fatih/color"
3031
"github.com/spf13/cobra"
32+
"google.golang.org/protobuf/types/known/structpb"
3133
"gopkg.in/yaml.v3"
3234
)
3335

@@ -65,6 +67,7 @@ type GitHubUrlInfo struct {
6567
}
6668

6769
const AiAgentHost = "azure.ai.agent"
70+
const ContainerAppHost = "containerapp"
6871

6972
func newInitCommand() *cobra.Command {
7073
flags := &initFlags{}
@@ -201,11 +204,6 @@ func (a *InitAction) Run(ctx context.Context, flags *initFlags) error {
201204
return fmt.Errorf("failed to add agent to azure.yaml: %w", err)
202205
}
203206

204-
// Update environment with necessary env vars
205-
if err := a.updateEnvironment(ctx, agentManifest, flags.host); err != nil {
206-
return fmt.Errorf("failed to update environment: %w", err)
207-
}
208-
209207
color.Green("\nAI agent added to your project successfully!")
210208
}
211209

@@ -780,16 +778,49 @@ func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentMa
780778

781779
switch host {
782780
case "containerapp":
783-
serviceHost = "containerapp"
781+
serviceHost = ContainerAppHost
784782
default:
785783
serviceHost = AiAgentHost
786784
}
787785

786+
var agentConfig = project.ServiceTargetAgentConfig{}
787+
788+
deploymentDetails := []project.Deployment{}
789+
switch agentDef.Kind {
790+
case agent_yaml.AgentKindPrompt:
791+
agentDef := agentManifest.Template.(agent_yaml.PromptAgent)
792+
793+
modelDeployment, err := a.getModelDeploymentDetails(ctx, agentDef.Model)
794+
if err != nil {
795+
return fmt.Errorf("failed to get model deployment details: %w", err)
796+
}
797+
deploymentDetails = append(deploymentDetails, *modelDeployment)
798+
case agent_yaml.AgentKindHosted:
799+
agentDef := agentManifest.Template.(agent_yaml.HostedContainerAgent)
800+
801+
// Iterate over all models in the hosted container agent
802+
for _, model := range agentDef.Models {
803+
modelDeployment, err := a.getModelDeploymentDetails(ctx, model)
804+
if err != nil {
805+
return fmt.Errorf("failed to get model deployment details: %w", err)
806+
}
807+
deploymentDetails = append(deploymentDetails, *modelDeployment)
808+
}
809+
}
810+
811+
agentConfig.Deployments = deploymentDetails
812+
813+
var agentConfigStruct *structpb.Struct
814+
if agentConfigStruct, err = project.MarshalStruct(&agentConfig); err != nil {
815+
return fmt.Errorf("failed to marshal agent config: %w", err)
816+
}
817+
788818
serviceConfig := &azdext.ServiceConfig{
789819
Name: strings.ReplaceAll(agentDef.Name, " ", ""),
790820
RelativePath: targetDir,
791821
Host: serviceHost,
792822
Language: "python",
823+
Config: agentConfigStruct,
793824
}
794825

795826
req := &azdext.AddServiceRequest{Service: serviceConfig}
@@ -1279,90 +1310,6 @@ func (a *InitAction) selectFromList(
12791310
return options[*resp.Value], nil
12801311
}
12811312

1282-
func (a *InitAction) updateEnvironment(ctx context.Context, agentManifest *agent_yaml.AgentManifest, host string) error {
1283-
// Convert the template to bytes
1284-
templateBytes, err := json.Marshal(agentManifest.Template)
1285-
if err != nil {
1286-
return fmt.Errorf("failed to marshal agent template to JSON: %w", err)
1287-
}
1288-
1289-
// Convert the bytes to a dictionary
1290-
var templateDict map[string]interface{}
1291-
if err := json.Unmarshal(templateBytes, &templateDict); err != nil {
1292-
return fmt.Errorf("failed to unmarshal agent template from JSON: %w", err)
1293-
}
1294-
1295-
// Convert the dictionary to bytes
1296-
dictJsonBytes, err := json.Marshal(templateDict)
1297-
if err != nil {
1298-
return fmt.Errorf("failed to marshal templateDict to JSON: %w", err)
1299-
}
1300-
1301-
// Convert the bytes to an Agent Definition
1302-
var agentDef agent_yaml.AgentDefinition
1303-
if err := json.Unmarshal(dictJsonBytes, &agentDef); err != nil {
1304-
return fmt.Errorf("failed to unmarshal JSON to AgentDefinition: %w", err)
1305-
}
1306-
1307-
fmt.Printf("Updating environment variables for agent kind: %s\n", agentDef.Kind)
1308-
1309-
// Get current environment
1310-
envResponse, err := a.azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{})
1311-
if err != nil {
1312-
return fmt.Errorf("failed to get current environment: %w", err)
1313-
}
1314-
1315-
if envResponse.Environment == nil {
1316-
return fmt.Errorf("no current environment found")
1317-
}
1318-
1319-
envName := envResponse.Environment.Name
1320-
deploymentDetails := []Deployment{}
1321-
1322-
// Set environment variables based on agent kind
1323-
switch agentDef.Kind {
1324-
case agent_yaml.AgentKindPrompt:
1325-
agentDef := agentManifest.Template.(agent_yaml.PromptAgent)
1326-
1327-
modelDeployment, err := a.getModelDeploymentDetails(ctx, agentDef.Model)
1328-
if err != nil {
1329-
return fmt.Errorf("failed to get model deployment details: %w", err)
1330-
}
1331-
deploymentDetails = append(deploymentDetails, *modelDeployment)
1332-
case agent_yaml.AgentKindHosted:
1333-
agentDef := agentManifest.Template.(agent_yaml.HostedContainerAgent)
1334-
if err := a.setEnvVar(ctx, envName, "ENABLE_HOSTED_AGENTS", "true"); err != nil {
1335-
return err
1336-
}
1337-
1338-
// Iterate over all models in the hosted container agent
1339-
for _, model := range agentDef.Models {
1340-
modelDeployment, err := a.getModelDeploymentDetails(ctx, model)
1341-
if err != nil {
1342-
return fmt.Errorf("failed to get model deployment details: %w", err)
1343-
}
1344-
deploymentDetails = append(deploymentDetails, *modelDeployment)
1345-
}
1346-
}
1347-
1348-
if host == "containerapp" {
1349-
if err := a.setEnvVar(ctx, envName, "ENABLE_CONTAINER_AGENTS", "true"); err != nil {
1350-
return err
1351-
}
1352-
}
1353-
1354-
deploymentsJson, err := json.Marshal(deploymentDetails)
1355-
if err != nil {
1356-
return fmt.Errorf("failed to marshal deployment details to JSON: %w", err)
1357-
}
1358-
if err := a.setEnvVar(ctx, envName, "AI_PROJECT_DEPLOYMENTS", string(deploymentsJson)); err != nil {
1359-
return err
1360-
}
1361-
1362-
fmt.Printf("Successfully updated environment variables for agent kind: %s\n", agentDef.Kind)
1363-
return nil
1364-
}
1365-
13661313
func (a *InitAction) setEnvVar(ctx context.Context, envName, key, value string) error {
13671314
_, err := a.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{
13681315
EnvName: envName,
@@ -1377,40 +1324,7 @@ func (a *InitAction) setEnvVar(ctx context.Context, envName, key, value string)
13771324
return nil
13781325
}
13791326

1380-
// Deployment represents a single cognitive service account deployment
1381-
type Deployment struct {
1382-
// Specify the name of cognitive service account deployment.
1383-
Name string `json:"name"`
1384-
1385-
// Required. Properties of Cognitive Services account deployment model.
1386-
Model DeploymentModel `json:"model"`
1387-
1388-
// The resource model definition representing SKU.
1389-
Sku DeploymentSku `json:"sku"`
1390-
}
1391-
1392-
// DeploymentModel represents the model configuration for a cognitive services deployment
1393-
type DeploymentModel struct {
1394-
// Required. The name of Cognitive Services account deployment model.
1395-
Name string `json:"name"`
1396-
1397-
// Required. The format of Cognitive Services account deployment model.
1398-
Format string `json:"format"`
1399-
1400-
// Required. The version of Cognitive Services account deployment model.
1401-
Version string `json:"version"`
1402-
}
1403-
1404-
// DeploymentSku represents the resource model definition representing SKU
1405-
type DeploymentSku struct {
1406-
// Required. The name of the resource model definition representing SKU.
1407-
Name string `json:"name"`
1408-
1409-
// The capacity of the resource model definition representing SKU.
1410-
Capacity int `json:"capacity"`
1411-
}
1412-
1413-
func (a *InitAction) getModelDeploymentDetails(ctx context.Context, model agent_yaml.Model) (*Deployment, error) {
1327+
func (a *InitAction) getModelDeploymentDetails(ctx context.Context, model agent_yaml.Model) (*project.Deployment, error) {
14141328
version := ""
14151329
if model.Version != nil {
14161330
version = *model.Version
@@ -1440,14 +1354,14 @@ func (a *InitAction) getModelDeploymentDetails(ctx context.Context, model agent_
14401354
modelDeployment = modelDeploymentInput.Value
14411355
}
14421356

1443-
return &Deployment{
1357+
return &project.Deployment{
14441358
Name: modelDeployment,
1445-
Model: DeploymentModel{
1359+
Model: project.DeploymentModel{
14461360
Name: model.Id,
14471361
Format: modelDetails.Format,
14481362
Version: modelDetails.Version,
14491363
},
1450-
Sku: DeploymentSku{
1364+
Sku: project.DeploymentSku{
14511365
Name: modelDetails.Sku.Name,
14521366
Capacity: int(modelDetails.Sku.Capacity),
14531367
},

0 commit comments

Comments
 (0)