
Commit 57de36b

crabshellman authored and dest1n1s committed
feat(activation): enable aligned permutation of cross-model generation
This is achieved by giving the randperm generators the same seed and making sure they all live on the same CUDA device.
1 parent ea07cad commit 57de36b
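
A minimal sketch of the mechanism the commit message describes (illustration only, not code from this commit): two torch.Generators with the same seed draw the same randperm stream, but only if they also live on the same device, since PyTorch's CPU and CUDA RNGs use different algorithms and give different sequences for the same seed. That is why the commit both shares the seed and pins every generator to the CUDA device.

import torch

# Two generators on the same device, seeded identically, draw
# identical permutation streams (assumes a CUDA device is available).
g1 = torch.Generator(device="cuda")
g2 = torch.Generator(device="cuda")
g1.manual_seed(42)
g2.manual_seed(42)

perm1 = torch.randperm(8, generator=g1, device="cuda")
perm2 = torch.randperm(8, generator=g2, device="cuda")
assert torch.equal(perm1, perm2)  # same permutation from both generators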

5 files changed: +31 additions, −7 deletions

src/lm_saes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
     ActivationWriterConfig,
+    BufferShuffleConfig,
     CrossCoderConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
@@ -38,6 +39,7 @@
     "ActivationFactoryDatasetSource",
     "ActivationFactoryConfig",
     "ActivationWriterConfig",
+    "BufferShuffleConfig",
     "ActivationFactoryTarget",
     "load_dataset",
     "load_model",

src/lm_saes/activation/factory.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def build_batchler():
         """Create batchler for batched-activations-1d target."""
         assert cfg.batch_size is not None, "Batch size must be provided for outputting batched-activations-1d"
         return ActivationBatchler(
-            hook_points=cfg.hook_points, batch_size=cfg.batch_size, buffer_size=cfg.buffer_size
+            hook_points=cfg.hook_points, batch_size=cfg.batch_size, buffer_size=cfg.buffer_size, buffer_shuffle_config=cfg.buffer_shuffle_config
         )
 
     processors = [build_batchler()] if cfg.target >= ActivationFactoryTarget.BATCHED_ACTIVATIONS_1D else []

src/lm_saes/activation/processors/activation.py

Lines changed: 14 additions & 6 deletions
@@ -7,6 +7,7 @@
 from transformer_lens import HookedTransformer
 
 from lm_saes.activation.processors.core import BaseActivationProcessor
+from lm_saes.config import BufferShuffleConfig
 
 
 @dataclass
@@ -23,6 +24,7 @@ class ActivationBuffer:
     """
 
     buffer: list[dict[str, Any]] = field(default_factory=list)
+    generator: torch.Generator = torch.Generator()  # Generator passed from ActivationBatchler
 
     def __len__(self) -> int:
         """Get the number of samples in the buffer.
@@ -41,7 +43,7 @@ def cat(self, activations: dict[str, Any]) -> "ActivationBuffer":
         Returns:
             ActivationBuffer: New buffer containing concatenated activations
         """
-        return ActivationBuffer(buffer=self.buffer + [activations])
+        return ActivationBuffer(buffer=self.buffer + [activations], generator=self.generator)
 
     def consume(self) -> dict[str, torch.Tensor | list[Any]]:
         """Consume the buffer and return the activations as a dictionary."""
@@ -68,7 +70,7 @@ def yield_batch(self, batch_size: int) -> tuple[dict[str, torch.Tensor | list[Any]], "ActivationBuffer"]:
         data = self.consume()
         batch = {k: v[:batch_size] for k, v in data.items()}
         buffer = {k: v[batch_size:] for k, v in data.items()}
-        return batch, ActivationBuffer(buffer=[buffer])
+        return batch, ActivationBuffer(buffer=[buffer], generator=self.generator)
 
     def shuffle(self) -> "ActivationBuffer":
         """Randomly shuffle all samples in the buffer.
@@ -81,9 +83,11 @@ def shuffle(self) -> "ActivationBuffer":
             isinstance(data[k], torch.Tensor) for k in data.keys()
         ), "All data must be tensors to perform shuffling"
         data = cast(dict[str, torch.Tensor], data)
-        perm = torch.randperm(data[list(data.keys())[0]].shape[0])
+
+        # Use the passed generator for shuffling
+        perm = torch.randperm(data[list(data.keys())[0]].shape[0], generator=self.generator, device=self.generator.device)
         buffer = {k: v[perm] for k, v in data.items()}
-        return ActivationBuffer(buffer=[buffer])
+        return ActivationBuffer(buffer=[buffer], generator=self.generator)
 
 
 class ActivationGenerator(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]):
@@ -254,7 +258,7 @@ class ActivationBatchler(BaseActivationProcessor[Iterable[dict[str, Any]], Iterable[dict[str, Any]]]):
     data will be refilled into the buffer whenever the buffer is less than half full, and then re-shuffled.
     """
 
-    def __init__(self, hook_points: list[str], batch_size: int, buffer_size: Optional[int] = None):
+    def __init__(self, hook_points: list[str], batch_size: int, buffer_size: Optional[int] = None, buffer_shuffle_config: Optional[BufferShuffleConfig] = None):
         """Initialize the ActivationBatchler.
 
         Args:
@@ -265,6 +269,10 @@ def __init__(self, hook_points: list[str], batch_size: int, buffer_size: Optional[int] = None):
         self.hook_points = hook_points
         self.batch_size = batch_size
         self.buffer_size = buffer_size
+        self.perm_generator = torch.Generator()
+        if buffer_shuffle_config is not None:
+            self.perm_generator = torch.Generator(buffer_shuffle_config.generator_device)
+            self.perm_generator.manual_seed(buffer_shuffle_config.perm_seed)  # Set seed if provided
 
     def process(self, data: Iterable[dict[str, Any]], **kwargs) -> Iterable[dict[str, Any]]:
         """Process input data by batching activations.
@@ -283,7 +291,7 @@ def process(self, data: Iterable[dict[str, Any]], **kwargs) -> Iterable[dict[str, Any]]:
         Raises:
             AssertionError: If hook points are missing or tensors have invalid shapes
         """
-        buffer = ActivationBuffer()
+        buffer = ActivationBuffer(generator=self.perm_generator)
         pbar = tqdm(total=self.buffer_size, desc="Buffer monitor", miniters=1)
 
         for d in data:
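
Because cat, yield_batch, and shuffle each return a new ActivationBuffer, the generator has to be threaded through every constructor call above; otherwise each new buffer would fall back to a fresh default generator and the permutation streams of two cross-model pipelines would drift apart. A hypothetical sketch of the resulting behavior (on CPU for simplicity; shuffled_rows stands in for ActivationBuffer.shuffle):

import torch

def shuffled_rows(x: torch.Tensor, gen: torch.Generator) -> torch.Tensor:
    # Mirrors ActivationBuffer.shuffle: one randperm drawn from the shared generator.
    perm = torch.randperm(x.shape[0], generator=gen, device=gen.device)
    return x[perm]

gen_a = torch.Generator().manual_seed(42)  # pipeline for model A
gen_b = torch.Generator().manual_seed(42)  # pipeline for model B

acts_a = torch.arange(6, dtype=torch.float32).unsqueeze(1)       # stand-in activations from model A
acts_b = torch.arange(6, dtype=torch.float32).unsqueeze(1) * 10  # the same samples as seen by model B

# Each shuffle advances both generator states identically, so row i of A
# keeps corresponding to row i of B no matter how often we reshuffle.
for _ in range(3):
    acts_a = shuffled_rows(acts_a, gen_a)
    acts_b = shuffled_rows(acts_b, gen_b)
assert torch.equal(acts_a * 10, acts_b)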

src/lm_saes/config.py

Lines changed: 9 additions & 0 deletions
@@ -231,6 +231,13 @@ def __le__(self, other: "ActivationFactoryTarget") -> bool:
         return self.stage <= other.stage
 
 
+class BufferShuffleConfig(BaseConfig):
+    perm_seed: int = 42
+    """ Seed for the randperm generator, used for aligned permutation when generating activations. If `None`, no manual seed is set for the Generator. """
+    generator_device: Optional[str] = None
+    """ The device to assign to the torch.Generator. If `None`, the generator is initialized on CPU, the PyTorch default. """
+
+
 class ActivationFactoryConfig(BaseConfig):
     sources: list[ActivationFactoryDatasetSource | ActivationFactoryActivationsSource]
     """ List of sources to use for activations. Can be a dataset or a path to activations. """
@@ -254,6 +261,8 @@ class ActivationFactoryConfig(BaseConfig):
         else None
     )
     """ Buffer size for online shuffling. If `None`, no shuffling will be performed. """
+    buffer_shuffle_config: Optional[BufferShuffleConfig] = None
+    """ Manual seed and device of the generator used for randperm in the buffer. """
     ignore_token_ids: Optional[list[int]] = None
     """ Tokens to ignore in the activations. """
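
A hypothetical usage sketch of the new config (field values and the hook point are illustrative, not from this commit): every run whose activations must stay row-aligned passes the same BufferShuffleConfig, here wired directly into an ActivationBatchler the way factory.py above does.

from lm_saes import BufferShuffleConfig
from lm_saes.activation.processors.activation import ActivationBatchler

shuffle_cfg = BufferShuffleConfig(
    perm_seed=42,             # identical across all runs that must stay aligned
    generator_device="cuda",  # keep every generator on the same device (assumes CUDA is available)
)
batchler = ActivationBatchler(
    hook_points=["blocks.3.hook_resid_post"],  # hypothetical hook point
    batch_size=4096,
    buffer_size=500_000,
    buffer_shuffle_config=shuffle_cfg,
)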

src/lm_saes/runner.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,7 @@
     ActivationFactoryTarget,
     ActivationWriterConfig,
     BaseSAEConfig,
+    BufferShuffleConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
@@ -123,6 +124,9 @@ class GenerateActivationsSettings(BaseSettings):
 
     buffer_size: Optional[int] = None
     """Size of the buffer for activation generation"""
+
+    buffer_shuffle_config: Optional[BufferShuffleConfig] = None
+    """Manual seed and device of the generator used for randperm in the buffer"""
 
     total_tokens: Optional[int] = None
     """Optional total number of tokens to generate"""
@@ -198,6 +202,7 @@ def generate_activations(settings: GenerateActivationsSettings) -> None:
         model_batch_size=settings.model_batch_size,
         batch_size=settings.batch_size,
         buffer_size=settings.buffer_size,
+        buffer_shuffle_config=settings.buffer_shuffle_config,
     )
 
     # Configure activation writer
