Commit 4cb3dc3

huge refactor

1 parent 3d5dd1a commit 4cb3dc3

23 files changed: +5794 additions, -5120 deletions

test/services/test_python_executor_service.py
Lines changed: 9 additions & 7 deletions

@@ -73,7 +73,7 @@ def test_service_execution(self, ray_init):
     result = x + y
     print(f"Result: {result}")
 """
-result = ray.get(executor.execute.remote(code), timeout=2)
+result = ray.get(executor.execute.remote(code), timeout=10)

 assert result["success"] is True
 assert "Result: 30" in result["stdout"]

@@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init):

 # Execute code with an error
 code = "raise ValueError('Test error')"
-result = ray.get(executor.execute.remote(code), timeout=2)
+result = ray.get(executor.execute.remote(code), timeout=10)

 assert result["success"] is False
 assert "ValueError: Test error" in result["stderr"]

@@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init):
     "python_executor",
     PythonExecutorService,
     pool_size=4,
-    timeout=5.0,
+    timeout=10.0,
     num_cpus=4,
     max_concurrency=4,
 )

@@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init):
     code = f"print('Execution {i}')"
     futures.append(executor.execute.remote(code))

-# Wait for all to complete
-results = ray.get(futures, timeout=5)
+# Wait for all to complete with longer timeout
+results = ray.get(futures, timeout=30)

 # All should succeed
 assert len(results) == 8
 for i, result in enumerate(results):
-    assert result["success"] is True
-    assert f"Execution {i}" in result["stdout"]
+    assert result["success"] is True, f"Execution {i} failed: {result}"
+    assert (
+        f"Execution {i}" in result["stdout"]
+    ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}"

 finally:
     services.reset()

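The changes above mostly raise Ray timeouts and attach diagnostic messages to the assertions. As a reminder of the underlying pattern, here is a minimal, self-contained sketch that gathers several Ray futures with an explicit timeout and asserts with informative messages; the run_snippet task is a hypothetical stand-in for PythonExecutorService.execute, not code from this commit.

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def run_snippet(i: int) -> dict:
    # Hypothetical stand-in for PythonExecutorService.execute: pretend to run
    # a code snippet and report its output plus a success flag.
    return {"success": True, "stdout": f"Execution {i}"}


# Submit several executions, then gather them with a generous timeout so that
# slow scheduling does not show up as a spurious failure.
futures = [run_snippet.remote(i) for i in range(8)]
results = ray.get(futures, timeout=30)

assert len(results) == 8
for i, result in enumerate(results):
    # Attach the offending payload to the assertion message to ease debugging.
    assert result["success"] is True, f"Execution {i} failed: {result}"
    assert f"Execution {i}" in result["stdout"], result["stdout"]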
test/test_collector.py
Lines changed: 112 additions & 13 deletions

@@ -13,11 +13,14 @@
 import subprocess
 import sys
 import time
+from contextlib import nullcontext
 from unittest.mock import patch

 import numpy as np
 import pytest
 import torch
+
+import torchrl.collectors._runner
 from packaging import version
 from tensordict import (
     assert_allclose_td,

@@ -33,7 +36,6 @@
 TensorDictSequential,
 )
 from torch import nn
-
 from torchrl._utils import (
     _make_ordinal_device,
     _replace_last,

@@ -48,7 +50,7 @@
 SyncDataCollector,
 WeightUpdaterBase,
 )
-from torchrl.collectors.collectors import _Interruptor
+from torchrl.collectors._constants import _Interruptor

 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (

@@ -1487,12 +1489,14 @@ def env_fn(seed):
 assert_allclose_td(data10, data20)

 @pytest.mark.parametrize("use_async", [False, True])
-@pytest.mark.parametrize("cudagraph", [False, True])
+@pytest.mark.parametrize(
+    "cudagraph", [False, True] if torch.cuda.is_available() else [False]
+)
 @pytest.mark.parametrize(
     "weight_sync_scheme",
     [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
 )
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
+# @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found")
 def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
     def create_env():
         return ContinuousActionVecMockEnv()

@@ -1509,11 +1513,12 @@ def create_env():
 kwargs = {}
 if weight_sync_scheme is not None:
     kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 collector = collector_class(
     [create_env] * 3,
     policy=policy,
-    device=[torch.device("cuda:0")] * 3,
-    storing_device=[torch.device("cuda:0")] * 3,
+    device=[torch.device(device)] * 3,
+    storing_device=[torch.device(device)] * 3,
     frames_per_batch=20,
     cat_results="stack",
     cudagraph_policy=cudagraph,

@@ -1544,7 +1549,9 @@ def create_env():
 # check they don't match
 for worker in range(3):
     for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-        with pytest.raises(AssertionError):
+        with pytest.raises(
+            AssertionError
+        ) if torch.cuda.is_available() else nullcontext():
             torch.testing.assert_close(
                 state_dict[f"worker{worker}"]["policy_state_dict"][k],
                 policy_state_dict[k].cpu(),

@@ -2401,7 +2408,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs):
 policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
 with pytest.raises(
     TypeError,
-    match=("Arguments to policy.forward are incompatible with entries in"),
+    match=(
+        "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
+    ),
 ):
     collector_class(
         **self._create_collector_kwargs(

@@ -2980,6 +2989,94 @@ def test_param_sync_mixed_device(
 col.shutdown()
 del col

+@pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 3,
+    reason="requires at least 3 CUDA devices",
+)
+@pytest.mark.parametrize(
+    "weight_sync_scheme",
+    [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme],
+)
+def test_shared_device_weight_update(self, weight_sync_scheme):
+    """Test that weight updates work correctly when multiple workers share the same device.
+
+    This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme.
+    When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through
+    dedicated queues, preventing race conditions that could occur with a single shared queue.
+    """
+    # Create policy on cuda:0
+    policy = TensorDictModule(
+        nn.Linear(7, 7, device="cuda:0"),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+    def make_env():
+        return ContinuousActionVecMockEnv()
+
+    # Create collector with workers on cuda:2, cuda:1, cuda:2
+    # Workers 0 and 2 share cuda:2 - this is the key test case
+    collector = MultiaSyncDataCollector(
+        [make_env, make_env, make_env],
+        policy=policy,
+        frames_per_batch=30,
+        total_frames=300,
+        device=["cuda:2", "cuda:1", "cuda:2"],
+        storing_device=["cuda:2", "cuda:1", "cuda:2"],
+        weight_sync_schemes={"policy": weight_sync_scheme()},
+    )
+
+    try:
+        # Collect first batch to initialize workers
+        for _ in collector:
+            break
+
+        # Get initial weights
+        old_weight = policy.module.weight.data.clone()
+
+        # Modify policy weights on cuda:0
+        for p in policy.parameters():
+            p.data += torch.randn_like(p)
+
+        new_weight = policy.module.weight.data.clone()
+        assert not torch.allclose(
+            old_weight, new_weight
+        ), "Weights should have changed"
+
+        # Update weights - this should propagate to all workers via their dedicated queues
+        collector.update_policy_weights_()
+
+        # Collect more batches to ensure weights are propagated
+        for i, _ in enumerate(collector):
+            if i >= 2:
+                break
+
+        # Get state dict from all workers
+        state_dict = collector.state_dict()
+
+        # Verify all workers have the new weights, including both workers on cuda:2
+        for worker_idx in range(3):
+            worker_key = f"worker{worker_idx}"
+            assert (
+                "policy_state_dict" in state_dict[worker_key]
+            ), f"Worker {worker_idx} should have policy_state_dict"
+            worker_weight = state_dict[worker_key]["policy_state_dict"][
+                "module.weight"
+            ]
+            torch.testing.assert_close(
+                worker_weight.cpu(),
+                new_weight.cpu(),
+                msg=(
+                    f"Worker {worker_idx} weights don't match expected weights. "
+                    f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. "
+                    f"This test validates that the per-worker queue system correctly "
+                    f"distributes weights even when multiple workers share a device."
+                ),
+            )
+    finally:
+        collector.shutdown()
+        del collector
+

 class TestAggregateReset:
     def test_aggregate_reset_to_root(self):

@@ -3176,11 +3273,11 @@ class TestLibThreading:
 reason="setting different threads across workers can randomly fail on OSX.",
 )
 def test_num_threads(self):
-    from torchrl.collectors import collectors
+    pass

-    _main_async_collector_saved = collectors._main_async_collector
-    collectors._main_async_collector = decorate_thread_sub_func(
-        collectors._main_async_collector, num_threads=3
+    _main_async_collector_saved = torchrl.collectors._runner._main_async_collector
+    torchrl.collectors._runner._main_async_collector = decorate_thread_sub_func(
+        torchrl.collectors._runner._main_async_collector, num_threads=3
     )
     num_threads = torch.get_num_threads()
     try:

@@ -3204,7 +3301,9 @@ def test_num_threads(self):
 except Exception:
     torchrl_logger.info("Failed to shut down collector")
 # reset vals
-collectors._main_async_collector = _main_async_collector_saved
+torchrl.collectors._runner._main_async_collector = (
+    _main_async_collector_saved
+)
 torch.set_num_threads(num_threads)

 @pytest.mark.skipif(

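The new nullcontext import supports the pattern used in test_update_weights above: the same block either expects an AssertionError (on CUDA machines, where the collector copies live on another device) or no error at all (on CPU-only machines), with the expectation chosen at runtime. Below is a minimal, self-contained sketch of that conditional-expectation idiom; check_update and its toy tensors are illustrative, not taken from the test suite.

from contextlib import nullcontext

import pytest
import torch


def check_update(cuda_available: bool) -> None:
    # On CUDA machines the two tensors are expected to differ, so the
    # comparison should raise; on CPU-only machines they coincide and the
    # comparison should pass, hence the no-op context manager.
    expectation = pytest.raises(AssertionError) if cuda_available else nullcontext()
    reference = torch.zeros(3)
    candidate = torch.ones(3) if cuda_available else torch.zeros(3)
    with expectation:
        torch.testing.assert_close(candidate, reference)


check_update(cuda_available=torch.cuda.is_available())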
torchrl/collectors/__init__.py
Lines changed: 6 additions & 7 deletions

@@ -5,13 +5,12 @@

 from torchrl.envs.utils import RandomPolicy

-from .collectors import (
-    aSyncDataCollector,
-    DataCollectorBase,
-    MultiaSyncDataCollector,
-    MultiSyncDataCollector,
-    SyncDataCollector,
-)
+from ._multi_async import MultiaSyncDataCollector
+from ._multi_sync import MultiSyncDataCollector
+from ._single import SyncDataCollector
+
+from ._single_async import aSyncDataCollector
+from .base import DataCollectorBase
 from .weight_update import (
     MultiProcessedWeightUpdater,
     RayWeightUpdater,

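The monolithic collectors module is split into per-class modules, but the package __init__ keeps re-exporting the same names, so, judging from this diff alone, user code importing from torchrl.collectors should be unaffected. For example:

# Public imports are unchanged even though the classes now live in
# _single.py, _single_async.py, _multi_sync.py and _multi_async.py.
from torchrl.collectors import (
    aSyncDataCollector,
    DataCollectorBase,
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    SyncDataCollector,
)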
torchrl/collectors/_constants.py
Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Constants and helper classes for collectors."""
from __future__ import annotations

import os
import sys
from multiprocessing.managers import SyncManager

import torch
from torch import multiprocessing as mp

from torchrl.envs.utils import ExplorationType

try:
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:

    def cudagraph_mark_step_begin():
        """Placeholder for missing cudagraph_mark_step_begin method."""
        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")


__all__ = [
    "_TIMEOUT",
    "INSTANTIATE_TIMEOUT",
    "_MIN_TIMEOUT",
    "_MAX_IDLE_COUNT",
    "DEFAULT_EXPLORATION_TYPE",
    "_is_osx",
    "_Interruptor",
    "_InterruptorManager",
    "cudagraph_mark_step_begin",
]

_TIMEOUT = 1.0
INSTANTIATE_TIMEOUT = 20
_MIN_TIMEOUT = 1e-3  # should be several orders of magnitude inferior wrt time spent collecting a trajectory
# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue.
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))

DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM

_is_osx = sys.platform.startswith("darwin")


class _Interruptor:
    """A class for managing the collection state of a process.

    This class provides methods to start and stop collection, and to check
    whether collection has been stopped. The collection state is protected
    by a lock to ensure thread-safety.
    """

    # interrupter vs interruptor: google trends seems to indicate that "or" is more
    # widely used than "er" even if my IDE complains about that...
    def __init__(self):
        self._collect = True
        self._lock = mp.Lock()

    def start_collection(self):
        with self._lock:
            self._collect = True

    def stop_collection(self):
        with self._lock:
            self._collect = False

    def collection_stopped(self):
        with self._lock:
            return self._collect is False


class _InterruptorManager(SyncManager):
    """A custom SyncManager for managing the collection state of a process.

    This class extends the SyncManager class and allows to share an Interruptor object
    between processes.
    """


_InterruptorManager.register("_Interruptor", _Interruptor)

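For context, here is a rough sketch of how the manager-backed interruptor could be shared with a worker process, a usage pattern inferred from the docstrings above rather than taken from this commit; both classes are private, so treat this as illustrative only.

import time

from torch import multiprocessing as mp

from torchrl.collectors._constants import _InterruptorManager


def worker(interruptor):
    # Spin until the parent flips the shared flag through the proxy.
    while not interruptor.collection_stopped():
        time.sleep(0.01)


if __name__ == "__main__":
    manager = _InterruptorManager()
    manager.start()
    # The registered typeid returns a proxy to a shared _Interruptor instance.
    interruptor = manager._Interruptor()
    interruptor.start_collection()

    proc = mp.Process(target=worker, args=(interruptor,))
    proc.start()
    time.sleep(0.1)
    interruptor.stop_collection()  # the worker observes this and exits
    proc.join()
    manager.shutdown()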