Skip to content

Commit 13576f8

Browse files
authored
Merge pull request #120 from OpenMOSS/llada
Fix batch size validation for data parallelism and adjust total count for activation processing in CachedActivationLoader
2 parents 989f7de + 3eb0ef6 commit 13576f8

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/lm_saes/activation/processors/activation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,15 @@ def process(self, data: Iterable[dict[str, Any]], **kwargs) -> Iterable[dict[str
349349
"""
350350
buffer = ActivationBuffer(generator=self.perm_generator, device_mesh=self.device_mesh)
351351
pbar = tqdm(total=self.buffer_size, desc="Buffer monitor", miniters=1, disable=True)
352-
352+
dp_size = get_mesh_dim_size(self.device_mesh, "data")
353353
for d in data:
354+
355+
def get_batch_size(x):
356+
return len(x) if isinstance(x, DTensor) else len(x) * dp_size
357+
354358
# Validate input: ensure all tensors and lists have consistent shapes
355-
assert all(len(d[k]) == len(d[next(iter(d.keys()))]) for k in d.keys()), (
356-
"All tensors and lists must have the same batch size"
359+
assert all(get_batch_size(d[k]) == get_batch_size(d[next(iter(d.keys()))]) for k in d.keys()), (
360+
f"All tensors and lists must have the same batch size, {[(k, len(d[k])) for k in d.keys()]}"
357361
)
358362

359363
# Add new data to buffer

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _process_chunks(self, hook_chunks: dict[str, list[ChunkInfo]], num_chunks: i
240240
else:
241241
for data in tqdm(
242242
dataloader,
243-
total=len(cached_activation_dataset),
243+
total=len(cached_activation_dataset) // self.device_mesh.size(),
244244
desc="Processing activation chunks",
245245
disable=not is_master(),
246246
):

0 commit comments

Comments
 (0)