Skip to content

Commit e4eeba2

Browse files
committed
feat(activation): add tqdm in loading cached activation
1 parent 2fcec8b commit e4eeba2

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class SequentialCachedActivationLoader(BaseCachedActivationLoader):
216216
"""Sequential implementation of cached activation loader."""
217217

218218
def _process_chunks(self, hook_chunks: dict[str, list[ChunkInfo]], total_chunks: int) -> Iterator[dict[str, Any]]:
219-
for chunk_idx in range(total_chunks):
219+
for chunk_idx in tqdm(range(total_chunks), desc="Loading chunks", smoothing=0.001, miniters=1):
220220
chunk_data = self._load_chunk_for_hooks(chunk_idx, hook_chunks)
221221
chunk_data = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()}
222222
yield chunk_data
@@ -268,7 +268,8 @@ def _process_chunks(self, hook_chunks: dict[str, list[ChunkInfo]], total_chunks:
268268
for future in tqdm(done, desc="Processing chunks", smoothing=0.001, leave=False, disable=True):
269269
chunk_data = future.result()
270270
chunk_data = {
271-
k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()
271+
k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v
272+
for k, v in chunk_data.items()
272273
}
273274
yield chunk_data
274275
pbar.update(1)

0 commit comments

Comments
 (0)