77from transformer_lens import HookedTransformer
88
99from lm_saes .activation .processors .core import BaseActivationProcessor
10+ from lm_saes .config import BufferShuffleConfig
1011
1112
1213@dataclass
@@ -23,6 +24,7 @@ class ActivationBuffer:
2324 """
2425
2526 buffer : list [dict [str , Any ]] = field (default_factory = list )
27+ generator : torch .Generator = torch .Generator () # Generator passed from ActivationBatchler
2628
2729 def __len__ (self ) -> int :
2830 """Get the number of samples in the buffer.
@@ -41,7 +43,7 @@ def cat(self, activations: dict[str, Any]) -> "ActivationBuffer":
4143 Returns:
4244 ActivationBuffer: New buffer containing concatenated activations
4345 """
44- return ActivationBuffer (buffer = self .buffer + [activations ])
46+ return ActivationBuffer (buffer = self .buffer + [activations ], generator = self . generator )
4547
4648 def consume (self ) -> dict [str , torch .Tensor | list [Any ]]:
4749 """Consume the buffer and return the activations as a dictionary."""
@@ -68,7 +70,7 @@ def yield_batch(self, batch_size: int) -> tuple[dict[str, torch.Tensor | list[An
6870 data = self .consume ()
6971 batch = {k : v [:batch_size ] for k , v in data .items ()}
7072 buffer = {k : v [batch_size :] for k , v in data .items ()}
71- return batch , ActivationBuffer (buffer = [buffer ])
73+ return batch , ActivationBuffer (buffer = [buffer ], generator = self . generator )
7274
7375 def shuffle (self ) -> "ActivationBuffer" :
7476 """Randomly shuffle all samples in the buffer.
@@ -81,9 +83,11 @@ def shuffle(self) -> "ActivationBuffer":
8183 isinstance (data [k ], torch .Tensor ) for k in data .keys ()
8284 ), "All data must be tensors to perform shuffling"
8385 data = cast (dict [str , torch .Tensor ], data )
84- perm = torch .randperm (data [list (data .keys ())[0 ]].shape [0 ])
86+
87+ # Use the passed generator for shuffling
88+ perm = torch .randperm (data [list (data .keys ())[0 ]].shape [0 ], generator = self .generator , device = self .generator .device )
8589 buffer = {k : v [perm ] for k , v in data .items ()}
86- return ActivationBuffer (buffer = [buffer ])
90+ return ActivationBuffer (buffer = [buffer ], generator = self . generator )
8791
8892
8993class ActivationGenerator (BaseActivationProcessor [Iterable [dict [str , Any ]], Iterable [dict [str , Any ]]]):
@@ -254,7 +258,7 @@ class ActivationBatchler(BaseActivationProcessor[Iterable[dict[str, Any]], Itera
254258 data will be refilled into the buffer whenever the buffer is less than half full, and then re-shuffled.
255259 """
256260
257- def __init__ (self , hook_points : list [str ], batch_size : int , buffer_size : Optional [int ] = None ):
261+ def __init__ (self , hook_points : list [str ], batch_size : int , buffer_size : Optional [int ] = None , buffer_shuffle_config : Optional [ BufferShuffleConfig ] = None ):
258262 """Initialize the ActivationBatchler.
259263
260264 Args:
@@ -265,6 +269,10 @@ def __init__(self, hook_points: list[str], batch_size: int, buffer_size: Optiona
265269 self .hook_points = hook_points
266270 self .batch_size = batch_size
267271 self .buffer_size = buffer_size
272+ self .perm_generator = torch .Generator ()
273+ if buffer_shuffle_config is not None :
274+ self .perm_generator = torch .Generator (buffer_shuffle_config .generator_device )
275+ self .perm_generator .manual_seed (buffer_shuffle_config .perm_seed ) # Set seed if provided
268276
269277 def process (self , data : Iterable [dict [str , Any ]], ** kwargs ) -> Iterable [dict [str , Any ]]:
270278 """Process input data by batching activations.
@@ -283,7 +291,7 @@ def process(self, data: Iterable[dict[str, Any]], **kwargs) -> Iterable[dict[str
283291 Raises:
284292 AssertionError: If hook points are missing or tensors have invalid shapes
285293 """
286- buffer = ActivationBuffer ()
294+ buffer = ActivationBuffer (generator = self . perm_generator )
287295 pbar = tqdm (total = self .buffer_size , desc = "Buffer monitor" , miniters = 1 )
288296
289297 for d in data :
0 commit comments