Skip to content

Commit 3c0a6a7

Browse files
authored
feat(agent): adapt the existing LoopAgent to the new version of the Agent API (#2679)
* feat: adapt LoopAgent * feat: adapt LoopAgent * fix: add docs
1 parent 7eb5366 commit 3c0a6a7

File tree

9 files changed

+766
-796
lines changed

9 files changed

+766
-796
lines changed

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/flow/agent/LoopAgent.java

Lines changed: 117 additions & 450 deletions
Large diffs are not rendered by default.
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.agent.flow.agent.loop;
18+
19+
import com.alibaba.cloud.ai.graph.OverAllState;
20+
import org.springframework.ai.chat.messages.Message;
21+
import org.springframework.ai.chat.messages.SystemMessage;
22+
import org.springframework.ai.chat.messages.UserMessage;
23+
import org.springframework.ai.util.json.JsonParser;
24+
import org.springframework.core.convert.converter.Converter;
25+
26+
import java.util.List;
27+
import java.util.Map;
28+
29+
/**
30+
* JSON array loop strategy that retrieves a JSON array from the current message state,
31+
* sends each array element as a message to the model, and returns the result.
32+
* By default, the text of the last message is treated as a JSON array, but users can customize the converter.
33+
*
34+
* @author vlsmb
35+
* @since 2025/11/1
36+
*/
37+
public class ArrayLoopStrategy implements LoopStrategy {
38+
39+
private final Converter<List<Message>, List<?>> converter;
40+
41+
public ArrayLoopStrategy(Converter<List<Message>, List<?>> converter) {
42+
this.converter = converter;
43+
}
44+
45+
public ArrayLoopStrategy() {
46+
this(DEFAULT_MESSAGE_CONVERTER);
47+
}
48+
49+
@Override
50+
public Map<String, Object> loopInit(OverAllState state) {
51+
@SuppressWarnings("unchecked")
52+
List<Message> messages = (List<Message>) state.value(LoopStrategy.MESSAGE_KEY).orElse(List.of());
53+
List<?> list = converter.convert(messages);
54+
if(list != null) {
55+
return Map.of(loopCountKey(), 0, loopFlagKey(), true, loopListKey(), list);
56+
}
57+
return Map.of(loopCountKey(), 0, loopFlagKey(), false, loopListKey(), List.of(),
58+
LoopStrategy.MESSAGE_KEY, new SystemMessage("Invalid json array format"));
59+
}
60+
61+
@Override
62+
public Map<String, Object> loopDispatch(OverAllState state) {
63+
List<?> list = state.value(loopListKey(), List.class).orElse(List.of());
64+
int index = state.value(loopCountKey(), maxLoopCount());
65+
if(index < list.size()) {
66+
UserMessage message = new UserMessage(list.get(index).toString());
67+
return Map.of(loopCountKey(), index + 1, loopFlagKey(), true,
68+
LoopStrategy.MESSAGE_KEY, message);
69+
} else {
70+
return Map.of(loopFlagKey(), false);
71+
}
72+
}
73+
74+
/**
75+
* 默认的转换器,将最后一个消息的文本作为json数组
76+
*/
77+
private static final Converter<List<Message>, List<?>> DEFAULT_MESSAGE_CONVERTER =
78+
messages -> {
79+
String lastMessage;
80+
if(!messages.isEmpty()) {
81+
lastMessage = messages.get(messages.size() - 1).getText();
82+
} else {
83+
lastMessage = null;
84+
}
85+
if(lastMessage == null) {
86+
return null;
87+
}
88+
return JsonParser.fromJson(lastMessage, List.class);
89+
};
90+
91+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.agent.flow.agent.loop;
18+
19+
import com.alibaba.cloud.ai.graph.OverAllState;
20+
import org.springframework.ai.chat.messages.Message;
21+
import org.springframework.ai.chat.messages.SystemMessage;
22+
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.function.Predicate;
26+
27+
/**
28+
* Conditional loop strategy that retries until the Predicate is satisfied or the maximum count is reached.
29+
*
30+
* @author vlsmb
31+
* @since 2025/11/1
32+
*/
33+
public class ConditionLoopStrategy implements LoopStrategy {
34+
35+
private final Predicate<List<Message>> messagePredicate;
36+
37+
private final int maxCount = maxLoopCount();
38+
39+
public ConditionLoopStrategy(Predicate<List<Message>> messagePredicate) {
40+
this.messagePredicate = messagePredicate;
41+
}
42+
43+
@Override
44+
public Map<String, Object> loopInit(OverAllState state) {
45+
return Map.of(loopCountKey(), 0, loopFlagKey(), true);
46+
}
47+
48+
@Override
49+
public Map<String, Object> loopDispatch(OverAllState state) {
50+
@SuppressWarnings("unchecked")
51+
List<Message> messages = (List<Message>) state.value(LoopStrategy.MESSAGE_KEY).orElse(List.of());
52+
if(messagePredicate.test(messages)) {
53+
return Map.of(loopFlagKey(), false);
54+
} else {
55+
int count = state.value(loopCountKey(), maxCount);
56+
if(count < maxCount) {
57+
return Map.of(loopCountKey(), count + 1, loopFlagKey(), true);
58+
} else {
59+
return Map.of(LoopStrategy.MESSAGE_KEY, new SystemMessage("Max loop count reached"), loopFlagKey(), false);
60+
}
61+
}
62+
}
63+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.agent.flow.agent.loop;
18+
19+
import com.alibaba.cloud.ai.graph.OverAllState;
20+
21+
import java.util.Map;
22+
23+
/**
24+
* Fixed count loop strategy
25+
*
26+
* @author vlsmb
27+
* @since 2025/11/1
28+
*/
29+
public class CountLoopStrategy implements LoopStrategy {
30+
31+
private final int maxCount;
32+
33+
public CountLoopStrategy(int maxCount) {
34+
this.maxCount = Math.min(maxCount, maxLoopCount());
35+
}
36+
37+
@Override
38+
public Map<String, Object> loopInit(OverAllState state) {
39+
return Map.of(loopCountKey(), 0, loopFlagKey(), maxCount > 0);
40+
}
41+
42+
@Override
43+
public Map<String, Object> loopDispatch(OverAllState state) {
44+
int count = state.value(loopCountKey(), maxCount);
45+
if (count < maxCount) {
46+
return Map.of(loopCountKey(), count + 1, loopFlagKey(), true);
47+
} else {
48+
return Map.of(loopFlagKey(), false);
49+
}
50+
}
51+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.agent.flow.agent.loop;
18+
19+
import org.springframework.ai.chat.messages.Message;
20+
import org.springframework.core.convert.converter.Converter;
21+
22+
import java.util.List;
23+
import java.util.function.Predicate;
24+
25+
/**
26+
* Built-in loop strategies for LoopAgent
27+
*
28+
* @author vlsmb
29+
* @since 2025/11/1
30+
*/
31+
public final class LoopMode {
32+
private LoopMode() {
33+
throw new UnsupportedOperationException();
34+
}
35+
36+
public static CountLoopStrategy count(int maxCount) {
37+
return new CountLoopStrategy(maxCount);
38+
}
39+
40+
public static ArrayLoopStrategy array() {
41+
return new ArrayLoopStrategy();
42+
}
43+
44+
public static ArrayLoopStrategy array(Converter<List<Message>, List<?>> converter) {
45+
return new ArrayLoopStrategy(converter);
46+
}
47+
48+
public static ConditionLoopStrategy condition(Predicate<List<Message>> messagePredicate) {
49+
return new ConditionLoopStrategy(messagePredicate);
50+
}
51+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright 2024-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.alibaba.cloud.ai.graph.agent.flow.agent.loop;
18+
19+
import com.alibaba.cloud.ai.graph.OverAllState;
20+
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
/**
25+
* <p>Loop strategy for LoopAgent, used to control the behavior of LoopAgent.</p>
26+
* <p>This part is equivalent to defining the loopInitNode and loopDispatchNode for the StateGraph corresponding to LoopAgent.</p>
27+
* <p>Built-in strategies provided by LoopMode can be used directly when in use. If custom loop logic is required, this interface can be implemented.</p>
28+
*
29+
* @author vlsmb
30+
* @since 2025/11/1
31+
*/
32+
public interface LoopStrategy {
33+
34+
int ITERABLE_ELEMENT_COUNT = 1000;
35+
36+
String LOOP_FLAG_PREFIX = "__loop_flag__";
37+
38+
String LOOP_LIST_PREFIX = "__loop_list__";
39+
40+
String LOOP_COUNT_PREFIX = "__loop_count__";
41+
42+
String INIT_NODE_NAME = "_loop_init__";
43+
44+
String DISPATCH_NODE_NAME = "_loop_dispatch__";
45+
46+
String MESSAGE_KEY = "messages";
47+
48+
Map<String, Object> loopInit(OverAllState state);
49+
50+
Map<String, Object> loopDispatch(OverAllState state);
51+
52+
default String uniqueKey() {
53+
return String.valueOf(System.identityHashCode(this));
54+
}
55+
56+
default List<String> tempKeys() {
57+
return List.of(
58+
loopFlagKey(),
59+
loopListKey(),
60+
loopCountKey()
61+
);
62+
}
63+
64+
default int maxLoopCount() {
65+
return ITERABLE_ELEMENT_COUNT;
66+
}
67+
68+
default String loopFlagKey() {
69+
return LOOP_FLAG_PREFIX + uniqueKey();
70+
}
71+
72+
default String loopListKey() {
73+
return LOOP_LIST_PREFIX + uniqueKey();
74+
}
75+
76+
default String loopCountKey() {
77+
return LOOP_COUNT_PREFIX + uniqueKey();
78+
}
79+
80+
default String loopInitNodeName() {
81+
return INIT_NODE_NAME + uniqueKey();
82+
}
83+
84+
default String loopDispatchNodeName() {
85+
return DISPATCH_NODE_NAME + uniqueKey();
86+
}
87+
88+
}

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/flow/strategy/FlowGraphBuildingStrategyRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ private void registerDefaultStrategies() {
131131
registerStrategy(new RoutingGraphBuildingStrategy());
132132
registerStrategy(new ParallelGraphBuildingStrategy());
133133
registerStrategy(new ConditionalGraphBuildingStrategy());
134-
// registerStrategy(new LoopGraphBuildingStrategy());
134+
registerStrategy(new LoopGraphBuildingStrategy());
135135
}
136136

137137
}

0 commit comments

Comments
 (0)