16 changes: 11 additions & 5 deletions docs/source/reference/collectors_weightsync.rst
@@ -49,7 +49,7 @@ Weight update schemes can be used outside of collectors for custom synchronization
The new simplified API provides four core methods for weight synchronization:

- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side
- ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side
- ``get_sender()`` - Get the configured sender instance
- ``get_receiver()`` - Get the configured receiver instance
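
A minimal sketch of how these four calls fit together in a two-process setup. The constructor arguments and the ``pipe`` keyword on the sender side are assumptions here, not the confirmed signatures; the examples below show the documented usage.

import torch.multiprocessing as mp
from torch import nn
from torchrl.weight_update import MultiProcessWeightSyncScheme

policy = nn.Linear(4, 2)
parent_pipe, child_pipe = mp.Pipe()
scheme = MultiProcessWeightSyncScheme()  # constructor kwargs omitted; assumed defaults

# Trainer (sender) side: register the transport, then push weights.
scheme.init_on_sender(model_id="policy", pipe=parent_pipe)  # kwargs are illustrative
sender = scheme.get_sender()
sender.send_async(policy.state_dict())  # weight format depends on the scheme's strategy
sender.wait_async()

# Worker (receiver) side, run in the child process: apply incoming weights to the local module.
scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
receiver = scheme.get_receiver()
if receiver.receive(timeout=0.001):  # non-blocking poll
    pass  # weights were received and applied to ``policy``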

@@ -85,16 +85,16 @@ Here's a basic example:
# or sender.send_async(weights); sender.wait_async() # Asynchronous send

# On the worker process side:
# scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
# scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
# receiver = scheme.get_receiver()
# # Non-blocking check for new weights
# if receiver.receive(timeout=0.001):
# # Weights were received and applied

# Example 2: Shared memory weight synchronization
# ------------------------------------------------
# Create shared memory scheme with auto-registration
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
# Create shared memory scheme
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict")

# Initialize with pipes for lazy registration
parent_pipe2, child_pipe2 = mp.Pipe()
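# The hunk above cuts this example short; a rough continuation is sketched
# below (the kwargs are assumptions, not the confirmed API):
shared_scheme.init_on_sender(model_id="policy", pipe=parent_pipe2)
shared_sender = shared_scheme.get_sender()

# With shared memory, every worker reads from the same buffers, so pushing
# new weights is effectively an in-place update of the shared storage.
shared_sender.send_async(TensorDict.from_module(policy))  # needs: from tensordict import TensorDict
shared_sender.wait_async()

# Worker side, registered lazily through the child end of the pipe:
# shared_scheme.init_on_receiver(model_id="policy", pipe=child_pipe2, model=policy)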
@@ -159,7 +159,7 @@ across multiple inference workers:
# Example 2: Multiple collectors with shared memory
# --------------------------------------------------
# Shared memory is more efficient for frequent updates
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict")

collector = MultiSyncDataCollector(
create_env_fn=[
@@ -198,6 +198,9 @@ Weight Senders
:template: rl_template.rst

WeightSender
MPWeightSender
RPCWeightSender
DistributedWeightSender
RayModuleTransformSender

Weight Receivers
@@ -208,6 +211,9 @@ Weight Receivers
:template: rl_template.rst

WeightReceiver
MPWeightReceiver
RPCWeightReceiver
DistributedWeightReceiver
RayModuleTransformReceiver

Transports
2 changes: 1 addition & 1 deletion examples/collectors/multi_weight_updates.py
@@ -25,7 +25,7 @@
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms.module import ModuleTransform
from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme
from torchrl.weight_update import MultiProcessWeightSyncScheme


def make_module():
2 changes: 1 addition & 1 deletion examples/collectors/weight_sync_collectors.py
@@ -90,7 +90,7 @@ def example_multi_collector_shared_memory():
env.close()

# Shared memory is more efficient for frequent updates
scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
scheme = SharedMemWeightSyncScheme(strategy="tensordict")

print("Creating multi-collector with shared memory...")
collector = MultiSyncDataCollector(
4 changes: 2 additions & 2 deletions examples/collectors/weight_sync_standalone.py
@@ -141,8 +141,8 @@ def example_shared_memory_sync():
# Create a simple policy
policy = nn.Linear(4, 2)

# Create shared memory scheme with auto-registration
scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
# Create shared memory scheme
scheme = SharedMemWeightSyncScheme(strategy="tensordict")
sender = scheme.create_sender()

# Create pipe for lazy registration
16 changes: 9 additions & 7 deletions test/services/test_python_executor_service.py
@@ -73,7 +73,7 @@ def test_service_execution(self, ray_init):
result = x + y
print(f"Result: {result}")
"""
result = ray.get(executor.execute.remote(code), timeout=2)
result = ray.get(executor.execute.remote(code), timeout=10)

assert result["success"] is True
assert "Result: 30" in result["stdout"]
@@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init):

# Execute code with an error
code = "raise ValueError('Test error')"
result = ray.get(executor.execute.remote(code), timeout=2)
result = ray.get(executor.execute.remote(code), timeout=10)

assert result["success"] is False
assert "ValueError: Test error" in result["stderr"]
@@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init):
"python_executor",
PythonExecutorService,
pool_size=4,
timeout=5.0,
timeout=10.0,
num_cpus=4,
max_concurrency=4,
)
@@ -132,14 +132,16 @@
code = f"print('Execution {i}')"
futures.append(executor.execute.remote(code))

# Wait for all to complete
results = ray.get(futures, timeout=5)
# Wait for all to complete with longer timeout
results = ray.get(futures, timeout=30)

# All should succeed
assert len(results) == 8
for i, result in enumerate(results):
assert result["success"] is True
assert f"Execution {i}" in result["stdout"]
assert result["success"] is True, f"Execution {i} failed: {result}"
assert (
f"Execution {i}" in result["stdout"]
), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}"

finally:
services.reset()