Skip to content

Commit 7f4df2c

Browse files
committed
Reapply "Reapply - Add grpo loop functional test (#2411)"
This reverts commit 6cc29a2.
1 parent 6d2a123 commit 7f4df2c

File tree

8 files changed

+799
-59
lines changed

8 files changed

+799
-59
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- agent_type: examples.rl.environments.countdown.countdown_agent.CountdownAgent
2+
agent_args: {}
3+
weight: 1.0

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ dev = [
8585
"wget",
8686
"onnxscript",
8787
"fastapi~=0.50", # Forcing a little bit more recent version of fastapi to be compatible with pydantic 2.0
88+
"datasets",
8889
]
8990

9091
lts = [
@@ -103,6 +104,7 @@ lts = [
103104
"wget",
104105
"onnxscript",
105106
"fastapi~=0.50", # Forcing a little bit more recent version of fastapi to be compatible with pydantic 2.0
107+
"datasets",
106108
]
107109

108110
[dependency-groups]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
3+
import json
4+
import logging
5+
import math
6+
from statistics import median
7+
8+
logging.basicConfig(level=logging.INFO)
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def test_grpo_training_loop(golden_values_path: str, test_values_path: str) -> None:
13+
14+
with open(golden_values_path, 'r') as f1, open(test_values_path, 'r') as f2:
15+
golden_values_content = f1.read()
16+
tensorboard_content = f2.read()
17+
18+
output_groundtruth = json.loads(golden_values_content)
19+
20+
if isinstance(output_groundtruth, str):
21+
# Handle JSONL output, assume only one line in this case.
22+
output_groundtruth = json.loads(output_groundtruth)
23+
24+
output_current = json.loads(tensorboard_content)
25+
if isinstance(output_current, str):
26+
# Handle JSONL output, assume only one line in this case.
27+
output_current = json.loads(output_current)
28+
29+
assert set(output_groundtruth.keys()).issuperset(
30+
set(output_current.keys())
31+
), f"Some IDs from groundtruth are missing in current: {output_groundtruth.keys()} vs {output_current.keys()}"
32+
if set(output_groundtruth.keys()) != set(output_current.keys()):
33+
logger.warning(
34+
f"Some IDs from groundtruth are missing in output, only the subset of ids in groundtruth will be tested: {output_groundtruth.keys()} vs {output_current.keys()}"
35+
)
36+
assert len(output_groundtruth) > 0, "No test performed for output"
37+
38+
if "iteration-time" in output_groundtruth.keys():
39+
40+
# First warmup iteration is excluded from iteration-time statistics.
41+
iteration_time_sampled = median(
42+
[l for l in output_current["iteration-time"]['values'].values()][1:]
43+
)
44+
iteration_time_golden = median(
45+
[l for l in output_groundtruth["iteration-time"]['values'].values()][1:]
46+
)
47+
48+
# 10% is empirically observed to be within hardware variance.
49+
assert (
50+
0.9 * iteration_time_golden <= iteration_time_sampled <= 1.2 * iteration_time_golden
51+
), (
52+
f"Iteration time {iteration_time_sampled} ms not within 10% below or 20% above "
53+
f"golden value ~{iteration_time_golden} ms. "
54+
f"Sampled: {output_current['iteration-time']} ms. "
55+
f"Please update golden values in the functional tests if this is expected."
56+
)
57+
58+
output_groundtruth.pop('iteration-time')

tests/functional_tests/shell_test_utils/run_ci_test.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,24 @@ for i in $(seq 1 $N_REPEAT); do
314314
fi
315315
fi
316316

317+
# For rl jobs
318+
if [[ "$MODE" == "rl" && ("$TRAINING_EXIT_CODE" -eq 0 || "$TEST_TYPE" == "release") ]]; then
319+
if [[ "$TEST_TYPE" == "frozen-start" ]]; then
320+
TRAIN_ITERS=$(cat $TRAINING_PARAMS_PATH |
321+
/usr/local/bin/yq '.MODEL_ARGS."--exit-interval" // "50"')
322+
uv run --no-sync python $ROOT_DIR/tests/functional_tests/python_test_utils/get_test_results_from_tensorboard_logs.py \
323+
--logs-dir $TENSORBOARD_PATH \
324+
--train-iters $TRAIN_ITERS \
325+
--output-path ${OUTPUT_PATH}/$(basename $GOLDEN_VALUES_PATH) \
326+
"${EXTRACT_ARGS[@]}"
327+
uv run --no-sync pytest -s -o log_cli=true --log-cli-level=info $ROOT_DIR/tests/functional_tests/python_test_utils/test_grpo_training_loop.py \
328+
--golden-values-path $GOLDEN_VALUES_PATH \
329+
--test-values-path ${OUTPUT_PATH}/$(basename $GOLDEN_VALUES_PATH) \
330+
--model-config-path ${TRAINING_PARAMS_PATH} \
331+
$ALLOW_NONDETERMINISTIC_ALGO_ARG
332+
fi
333+
fi
334+
317335
# Abort if training failed
318336
if [[ "$TRAINING_EXIT_CODE" -ne 0 && "$TEST_TYPE" != "release" ]]; then
319337
echo "Training failed. Aborting."
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
{
2+
"lm loss": {
3+
"start_step": 1,
4+
"end_step": 50,
5+
"step_interval": 1,
6+
"values": {
7+
"1": 0.0,
8+
"2": 0.04415,
9+
"3": 0.0378,
10+
"4": 0.02944,
11+
"5": 0.0,
12+
"6": 0.0,
13+
"7": 0.0,
14+
"8": 0.08111,
15+
"9": 0.0,
16+
"10": 0.0,
17+
"11": 0.0,
18+
"12": 0.0,
19+
"13": 0.0,
20+
"14": 0.05935,
21+
"15": 0.0,
22+
"16": 0.05496,
23+
"17": 0.0,
24+
"18": 0.0,
25+
"19": 0.0,
26+
"20": 0.04534,
27+
"21": 0.0,
28+
"22": 0.0,
29+
"23": 0.0,
30+
"24": 0.0,
31+
"25": 0.0,
32+
"26": 0.0,
33+
"27": 0.0,
34+
"28": 0.0,
35+
"29": 0.0,
36+
"30": 0.0,
37+
"31": 0.0,
38+
"32": 0.0,
39+
"33": 0.0,
40+
"34": 0.0,
41+
"35": 0.0,
42+
"36": 0.0,
43+
"37": 0.0099,
44+
"38": 0.0,
45+
"39": 0.0,
46+
"40": 0.0,
47+
"41": 0.03221,
48+
"42": 0.0,
49+
"43": 0.0,
50+
"44": 0.0,
51+
"45": 0.0,
52+
"46": 0.0,
53+
"47": 0.0,
54+
"48": 0.0,
55+
"49": 0.0,
56+
"50": 0.0
57+
}
58+
},
59+
"num-zeros": {
60+
"start_step": 1,
61+
"end_step": 50,
62+
"step_interval": 1,
63+
"values": {
64+
"1": 583687296.0,
65+
"2": 0.0,
66+
"3": 0.0,
67+
"4": 49.0,
68+
"5": 583687296.0,
69+
"6": 583687296.0,
70+
"7": 583687296.0,
71+
"8": 12.0,
72+
"9": 583687296.0,
73+
"10": 583687296.0,
74+
"11": 583687296.0,
75+
"12": 583687296.0,
76+
"13": 583687296.0,
77+
"14": 6.0,
78+
"15": 583687296.0,
79+
"16": 62.0,
80+
"17": 583687296.0,
81+
"18": 583687296.0,
82+
"19": 583687296.0,
83+
"20": 23.0,
84+
"21": 583687296.0,
85+
"22": 583687296.0,
86+
"23": 583687296.0,
87+
"24": 583687296.0,
88+
"25": 583687296.0,
89+
"26": 583687296.0,
90+
"27": 583687296.0,
91+
"28": 583687296.0,
92+
"29": 583687296.0,
93+
"30": 583687296.0,
94+
"31": 583687296.0,
95+
"32": 583687296.0,
96+
"33": 583687296.0,
97+
"34": 583687296.0,
98+
"35": 583687296.0,
99+
"36": 583687296.0,
100+
"37": 37.0,
101+
"38": 583687296.0,
102+
"39": 583687296.0,
103+
"40": 583687296.0,
104+
"41": 53.0,
105+
"42": 583687296.0,
106+
"43": 583687296.0,
107+
"44": 583687296.0,
108+
"45": 583687296.0,
109+
"46": 583687296.0,
110+
"47": 583687296.0,
111+
"48": 583687296.0,
112+
"49": 583687296.0,
113+
"50": 583687296.0
114+
}
115+
},
116+
"mem-allocated-bytes": {
117+
"start_step": 1,
118+
"end_step": 50,
119+
"step_interval": 1,
120+
"values": {
121+
"1": 55320928256.0,
122+
"2": 55319695360.0,
123+
"3": 55319674880.0,
124+
"4": 55319638016.0,
125+
"5": 55319638016.0,
126+
"6": 55319638016.0,
127+
"7": 55319633920.0,
128+
"8": 55319625728.0,
129+
"9": 55319621632.0,
130+
"10": 55319625728.0,
131+
"11": 55319625728.0,
132+
"12": 55319629824.0,
133+
"13": 55319547904.0,
134+
"14": 55319552000.0,
135+
"15": 55319552000.0,
136+
"16": 55319552000.0,
137+
"17": 55319552000.0,
138+
"18": 55319552000.0,
139+
"19": 55319556096.0,
140+
"20": 55319556096.0,
141+
"21": 55319556096.0,
142+
"22": 55319556096.0,
143+
"23": 55319556096.0,
144+
"24": 55319560192.0,
145+
"25": 55319560192.0,
146+
"26": 55319560192.0,
147+
"27": 55319560192.0,
148+
"28": 55319552000.0,
149+
"29": 55319552000.0,
150+
"30": 55319552000.0,
151+
"31": 55319552000.0,
152+
"32": 55319552000.0,
153+
"33": 55319552000.0,
154+
"34": 55319556096.0,
155+
"35": 55319556096.0,
156+
"36": 55319556096.0,
157+
"37": 55319560192.0,
158+
"38": 55319560192.0,
159+
"39": 55319560192.0,
160+
"40": 55319556096.0,
161+
"41": 55319552000.0,
162+
"42": 55319552000.0,
163+
"43": 55319552000.0,
164+
"44": 55319552000.0,
165+
"45": 55319552000.0,
166+
"46": 55319552000.0,
167+
"47": 55319556096.0,
168+
"48": 55319556096.0,
169+
"49": 55319556096.0,
170+
"50": 55319552000.0
171+
}
172+
},
173+
"mem-max-allocated-bytes": {
174+
"start_step": 1,
175+
"end_step": 50,
176+
"step_interval": 1,
177+
"values": {
178+
"1": 64753942528.0,
179+
"2": 69804253184.0,
180+
"3": 69804253184.0,
181+
"4": 69804253184.0,
182+
"5": 69804253184.0,
183+
"6": 69804253184.0,
184+
"7": 69804253184.0,
185+
"8": 69804253184.0,
186+
"9": 69804253184.0,
187+
"10": 69804253184.0,
188+
"11": 69804253184.0,
189+
"12": 69804253184.0,
190+
"13": 69804253184.0,
191+
"14": 69804253184.0,
192+
"15": 69804253184.0,
193+
"16": 69804253184.0,
194+
"17": 69804253184.0,
195+
"18": 69804253184.0,
196+
"19": 69804253184.0,
197+
"20": 69804253184.0,
198+
"21": 69804253184.0,
199+
"22": 69804253184.0,
200+
"23": 69804253184.0,
201+
"24": 69804253184.0,
202+
"25": 69804253184.0,
203+
"26": 69804253184.0,
204+
"27": 69804253184.0,
205+
"28": 69804253184.0,
206+
"29": 69804253184.0,
207+
"30": 69804253184.0,
208+
"31": 69804253184.0,
209+
"32": 69804253184.0,
210+
"33": 69804253184.0,
211+
"34": 69804253184.0,
212+
"35": 69804253184.0,
213+
"36": 69804253184.0,
214+
"37": 69804253184.0,
215+
"38": 69804253184.0,
216+
"39": 69804253184.0,
217+
"40": 69804253184.0,
218+
"41": 69804253184.0,
219+
"42": 69804253184.0,
220+
"43": 69804253184.0,
221+
"44": 69804253184.0,
222+
"45": 69804253184.0,
223+
"46": 69804253184.0,
224+
"47": 69804253184.0,
225+
"48": 69804253184.0,
226+
"49": 69804253184.0,
227+
"50": 69804253184.0
228+
}
229+
},
230+
"iteration-time": {
231+
"start_step": 1,
232+
"end_step": 50,
233+
"step_interval": 1,
234+
"values": {
235+
"1": 74.35665,
236+
"2": 5.25731,
237+
"3": 5.75582,
238+
"4": 4.02061,
239+
"5": 3.8529,
240+
"6": 3.91732,
241+
"7": 4.14616,
242+
"8": 3.83737,
243+
"9": 3.75158,
244+
"10": 3.91902,
245+
"11": 3.96073,
246+
"12": 3.83611,
247+
"13": 3.86989,
248+
"14": 3.88658,
249+
"15": 4.46432,
250+
"16": 3.90389,
251+
"17": 3.8143,
252+
"18": 3.86593,
253+
"19": 3.78307,
254+
"20": 3.90922,
255+
"21": 3.82247,
256+
"22": 3.76037,
257+
"23": 4.00863,
258+
"24": 3.74678,
259+
"25": 3.86492,
260+
"26": 3.83492,
261+
"27": 3.86387,
262+
"28": 3.99894,
263+
"29": 3.85812,
264+
"30": 4.34066,
265+
"31": 3.88411,
266+
"32": 3.80617,
267+
"33": 3.90347,
268+
"34": 3.7771,
269+
"35": 3.84701,
270+
"36": 3.81111,
271+
"37": 3.75554,
272+
"38": 3.99552,
273+
"39": 3.87227,
274+
"40": 3.81079,
275+
"41": 3.83039,
276+
"42": 3.74567,
277+
"43": 3.82531,
278+
"44": 3.78258,
279+
"45": 3.73294,
280+
"46": 4.579,
281+
"47": 3.72516,
282+
"48": 3.8117,
283+
"49": 3.80651,
284+
"50": 3.78283
285+
}
286+
}
287+
}

0 commit comments

Comments
 (0)