
Commit 30e4cfa

Merge pull request #127 from OpenMOSS/dev
Docs update
2 parents ab8c9f0 + 6cb1392 commit 30e4cfa

File tree

12 files changed: +892 −3 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ Please cite this library as:
 ```
 @misc{Ge2024OpenMossSAEs,
     title = {OpenMoss Language Model Sparse Autoencoders},
-    author = {Xuyang Ge, Fukang Zhu, Junxuan Wang, Wentao Shu, Lingjie Chen, Zhengfu He},
+    author = {Xuyang Ge, Wentao Shu, Junxuan Wang, Guancheng Zhou, Jiaxing Wu, Fukang Zhu, Lingjie Chen, Zhengfu He},
     url = {https://github.com/OpenMOSS/Language-Model-SAEs},
     year = {2024}
 }

docs/assets/images/lm-saes-overview.svg

Lines changed: 4 additions & 0 deletions

docs/concepts.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# Key Concepts
+
+![Overview of Language Model SAEs Pipeline](assets/images/lm-saes-overview.svg)

docs/index.md

Lines changed: 107 additions & 2 deletions
@@ -10,7 +10,112 @@ This library provides:

 - **Scalability**: Our framework is fully distributed with arbitrary combinations of data, model, and head parallelism for both training and analysis. Enjoy training SAEs with millions of features!
 - **Flexibility**: We support a wide range of SAE variants, including vanilla SAEs, Lorsa (Low-rank Sparse Attention), CLT (Cross-layer Transcoder), MoLT (Mixture of Linear Transforms), CrossCoder, and more. Each variant can be combined with different activation functions (e.g., ReLU, JumpReLU, TopK, BatchTopK) and sparsity penalties (e.g., L1, Tanh).
-- **Easy to Use**: We provide high-level `runners` APIs to quickly launch experiments with simple configurations. Check our [examples](examples) for verified hyperparameters.
+- **Easy to Use**: We provide high-level `runners` APIs to quickly launch experiments with simple configurations. Check our [examples](https://github.com/OpenMOSS/Language-Model-SAEs/tree/main/examples) for verified hyperparameters.
 - **Visualization**: We provide a unified web interface to visualize learned SAE variants and their features.

-## Getting Started
+## Quick Start

### Installation

=== "Astral uv"

    We strongly recommend using [uv](https://docs.astral.sh/uv/) for dependency management. uv is a modern drop-in replacement for Poetry or PDM, with lightning-fast dependency resolution and package installation. See their [instructions](https://docs.astral.sh/uv/getting-started/) on how to initialize a Python project with uv.

    To add our library as a project dependency, run:

    ```bash
    uv add lm-saes
    ```

    We also support [Ascend NPU](https://github.com/Ascend/pytorch) as an accelerator backend. To add our library as a project dependency with NPU dependency constraints, run:

    ```bash
    uv add lm-saes[npu]
    ```

=== "Pip"

    Of course, you can also install our library directly with [pip](https://pypi.org/project/pip/):

    ```bash
    pip install lm-saes
    ```

    Note that we use a forked version of [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), so it is best to install the package in a separate environment created by [conda](https://github.com/conda-forge/miniforge) or [virtualenv](https://virtualenv.pypa.io/en/latest/) to avoid conflicts.

    We also support [Ascend NPU](https://github.com/Ascend/pytorch) as an accelerator backend. To install our library with NPU dependency constraints, run:

    ```bash
    pip install lm-saes[npu]
    ```

### Load a trained Sparse Autoencoder from HuggingFace

WIP

### Training a Sparse Autoencoder

To train a simple Sparse Autoencoder on `blocks.5.hook_resid_post` of a Pythia-160M model with $768 \times 8$ features, you can use the following:

```python
from pathlib import Path

import torch

# The configuration classes below are assumed to be importable from the top-level
# `lm_saes` package, as in the analysis scripts shipped with the repository.
from lm_saes import (
    ActivationFactoryActivationsSource,
    ActivationFactoryConfig,
    ActivationFactoryTarget,
    InitializerConfig,
    SAEConfig,
    TrainerConfig,
    TrainSAESettings,
    WandbConfig,
    train_sae,
)

# Placeholders: point these at your own pre-generated 1D activations and experiment name.
activation_path = "~/activations/pythia-160m-1d"
exp_name = "L5R-pythia-160m"

settings = TrainSAESettings(
    sae=SAEConfig(
        hook_point_in="blocks.5.hook_resid_post",
        d_model=768,
        expansion_factor=8,
        act_fn="jumprelu",
    ),
    initializer=InitializerConfig(
        grid_search_init_norm=True,
    ),
    trainer=TrainerConfig(
        lr=5e-5,
        l1_coefficient=0.3,
        total_training_tokens=800_000_000,
        sparsity_loss_type="tanh-quad",
        jumprelu_lr_factor=0.1,
    ),
    wandb=WandbConfig(
        wandb_project="lm-saes",
        exp_name=exp_name,
    ),
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryActivationsSource(
                path=Path(activation_path).expanduser(),
                name="pythia-160m-1d",
                device="cuda",
                dtype=torch.float32,
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.5.hook_resid_post"],
        batch_size=4096,
        buffer_size=None,
    ),
    sae_name="L5R",
    sae_series="pythia-sae",
)
train_sae(settings)
```
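As a quick sanity check on the configuration above (plain arithmetic only, not an `lm_saes` API call), the dictionary size and the number of optimizer steps implied by these settings are:

```python
# Derived from the settings shown above; nothing here imports the library.
d_model = 768
expansion_factor = 8
n_features = d_model * expansion_factor        # 6144 dictionary features

total_training_tokens = 800_000_000
batch_size = 4096                              # token activations per optimizer step
n_steps = total_training_tokens // batch_size  # 195312 optimizer steps
print(n_features, n_steps)
```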

### Analyze a trained Sparse Autoencoder

WIP

### Convert trained Sparse Autoencoder to SAELens format

WIP

## Citation

If you find this library useful in your research, please cite:

```
@misc{Ge2024OpenMossSAEs,
    title = {OpenMoss Language Model Sparse Autoencoders},
    author = {Xuyang Ge, Wentao Shu, Junxuan Wang, Guancheng Zhou, Jiaxing Wu, Fukang Zhu, Lingjie Chen, Zhengfu He},
    url = {https://github.com/OpenMOSS/Language-Model-SAEs},
    year = {2024}
}
```
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# Code for Evolution of Concepts in Language Model Pre-Training

## Install the Environment

We use [uv](https://docs.astral.sh/uv/getting-started/installation/) as the dependency manager. Install `uv`, then run:

```bash
uv sync --extra default
```

to fetch all dependencies.

## Replicate the Crosscoders

To replicate our key results, you need to generate Pythia model activations, train the crosscoders, and analyze them.

### Requirements

The following instructions assume you have access to a GPU cluster with at least 16 NVIDIA A100/H100 or better GPUs, with CUDA version 12.8. With some simple modifications (e.g. changing all `"cuda"` to `"npu"` and installing the environment with `uv sync --extra npu`), this code can also run on an NPU cluster with at least 32 Ascend 910B or better NPUs. The cluster should also have a large amount of disk space (>200 TB) to store all model activations.

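A hypothetical helper, not part of the released scripts, sketching one way to pick the accelerator backend at runtime instead of hard-coding `"cuda"`; it assumes the Ascend `torch_npu` adapter is installed on NPU machines:

```python
import torch


def pick_device() -> str:
    """Return "cuda", "npu", or "cpu" depending on what is available on this machine."""
    if torch.cuda.is_available():
        return "cuda"
    try:
        # The Ascend adapter patches torch with an `npu` backend on import.
        import torch_npu  # noqa: F401

        if torch.npu.is_available():
            return "npu"
    except ImportError:
        pass
    return "cpu"
```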
Our scripts also require that you have a [subset](https://huggingface.co/datasets/Hzfinfdu/SlimPajama-3B) of the SlimPajama dataset, saved with `dataset.save_to_disk()` at `~/data/SlimPajama-3B`, and all Pythia model checkpoints at `~/models/pythia-{size}-all/step{step}`, where `size` can be `160m` or `6.9b`. You can change the paths in the scripts to your own paths.

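A minimal sketch of one way to put these files in place, assuming the Hugging Face `datasets` and `huggingface_hub` packages are installed; the revision names follow the public Pythia repositories, and whether the scripts expect a single split or the full dataset dictionary is not specified here:

```python
import os

from datasets import load_dataset
from huggingface_hub import snapshot_download

# Save the SlimPajama subset where the scripts expect it.
dataset = load_dataset("Hzfinfdu/SlimPajama-3B")
dataset.save_to_disk(os.path.expanduser("~/data/SlimPajama-3B"))

# Fetch one Pythia checkpoint into the expected directory layout; repeat for every step you analyze.
snapshot_download(
    repo_id="EleutherAI/pythia-160m",
    revision="step143000",
    local_dir=os.path.expanduser("~/models/pythia-160m-all/step143000"),
)
```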
### Generate Activations

Two types of model activations are required for training and analyzing crosscoders:

1. **1D Activations:** Activations where the context dimension is folded into the batch dimension and the samples are re-shuffled. Typically of shape `(batch, d_model)`. Used for crosscoder training.
2. **2D Activations:** Activations where the context dimension is preserved. Typically of shape `(batch, n_context, d_model)`. Used for crosscoder analysis.

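A minimal PyTorch sketch of how the two formats relate (illustrative only; the generation scripts below handle shuffling and storage for you):

```python
import torch

batch, n_context, d_model = 4, 128, 768
acts_2d = torch.randn(batch, n_context, d_model)     # 2D: context dimension preserved

acts_1d = acts_2d.reshape(-1, d_model)               # fold context into the batch dimension
acts_1d = acts_1d[torch.randperm(acts_1d.shape[0])]  # re-shuffle token activations for training
print(acts_2d.shape, acts_1d.shape)                  # (4, 128, 768) -> (512, 768)
```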
To generate 1D activations of Pythia-160M, run:

```bash
uv run torchrun --nproc-per-node=8 generate-pythia-activations-1d.py --size 160m --layer 6
```

This will take up ~40 TB of disk space.

To generate 2D activations of Pythia-160M, run:

```bash
uv run torchrun --nproc-per-node=8 generate-pythia-activations-2d.py --size 160m --layer 6
```

To generate 1D activations of Pythia-6.9B, run:

```bash
uv run torchrun --nproc-per-node=8 generate-pythia-activations-1d.py --size 6.9b --layer 16
```

This will take up ~170 TB of disk space.

To generate 2D activations of Pythia-6.9B, run:

```bash
uv run torchrun --nproc-per-node=8 generate-pythia-activations-2d.py --size 6.9b --layer 16
```

### Training Crosscoders

To train crosscoders on Pythia-160M, run:

```bash
uv run torchrun --nproc-per-node=8 train-pythia-crosscoders.py --init_encoder_factor 1 --lr 5e-5 --l1_coefficient 0.3 --jumprelu_lr_factor 0.1 --layer 6 --expansion_factor 32 --batch_size 2048
```

To train crosscoders on Pythia-6.9B, run:

```bash
uv run torchrun --nproc-per-node=8 --nnodes=2 train-pythia-crosscoders.py --init_encoder_factor 1 --lr 1e-5 --l1_coefficient 0.3 --jumprelu_lr_factor 0.3 --layer 16 --expansion_factor 8 --batch_size 2048 --size 6.9b # Requires 2 nodes
```

You can modify `expansion_factor` to get crosscoders with different dictionary sizes, and modify `l1_coefficient` to adjust the trade-off between sparsity and reconstruction fidelity.

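For reference (plain arithmetic, not part of the scripts), and assuming the dictionary size is `d_model * expansion_factor` as in the SAE example in the docs, the expansion factors above imply:

```python
# Dictionary size implied by --expansion_factor for the two configurations above.
d_model = {"160m": 768, "6.9b": 4096}
for size, expansion_factor in [("160m", 32), ("6.9b", 8)]:
    print(size, d_model[size] * expansion_factor, "features")  # 160m: 24576, 6.9b: 32768
```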
### Analyze Crosscoders

To analyze trained crosscoders, you should first have a MongoDB instance running at `localhost:27017` (a quick connectivity check is sketched at the end of this section), then run

```bash
uv run analyze-pythia-crosscoder.py --name <crosscoder-name> --batch-size 16
```

where `<crosscoder-name>` is the name of your trained crosscoder. Results are saved to MongoDB. Afterwards, you can use our visualization tool to view the features:

```bash
cd ../../ui
bun install
bun run dev
```
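If the analysis job cannot reach the database, a minimal connectivity check for the MongoDB instance mentioned above (assuming `pymongo` is installed) is:

```python
from pymongo import MongoClient

# Ping the MongoDB instance the analysis results are written to.
client = MongoClient("localhost", 27017, serverSelectionTimeoutMS=2000)
client.admin.command("ping")
print("MongoDB is reachable at localhost:27017")
```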
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import argparse
import os
import re
from pathlib import Path

import torch
from more_itertools import batched

from lm_saes import (
    ActivationFactoryActivationsSource,
    ActivationFactoryConfig,
    ActivationFactoryTarget,
    AnalyzeCrossCoderSettings,
    CrossCoderConfig,
    FeatureAnalyzerConfig,
    MongoDBConfig,
    analyze_crosscoder,
)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--size", type=str, default="160m")
    parser.add_argument("--name", type=str, default="L6R-lr5e-05-l1c0.5-32heads-8x-jlr0.1")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--analysis-name", type=str, default="default")
    return parser.parse_args()

# Hidden sizes and layer counts of the Pythia model suite, keyed by model size.
d_model_map = {
    "70m": 512,
    "160m": 768,
    "410m": 1024,
    "1b": 2048,
    "1.4b": 2048,
    "2.8b": 2560,
    "6.9b": 4096,
    "12b": 5120,
}

n_layers_map = {
    "70m": 6,
    "160m": 12,
    "410m": 24,
    "1b": 16,
    "1.4b": 24,
    "2.8b": 32,
    "6.9b": 32,
    "12b": 36,
}

# Pythia pre-training checkpoint steps to analyze (one crosscoder head per checkpoint);
# the list is split evenly across ranks.
steps = [
    0,
    2,
    4,
    8,
    16,
    32,
    64,
    128,
    256,
    512,
    1000,
    2000,
    3000,
    4000,
    5000,
    6000,
    7000,
    8000,
    9000,
    10000,
    14000,
    20000,
    27000,
    34000,
    47000,
    60000,
    74000,
    87000,
    100000,
    114000,
    127000,
    143000,
]

if __name__ == "__main__":
    args = parse_args()

    # WORLD_SIZE is set by torchrun; read it as a string first so a missing value
    # raises a clear error instead of a TypeError from int(None).
    world_size_env = os.environ.get("WORLD_SIZE")
    if world_size_env is None:
        raise ValueError("WORLD_SIZE is not set")
    world_size = int(world_size_env)
    assert len(steps) % world_size == 0, f"Head count {len(steps)} is not divisible by world size {world_size}"

    head_per_device = len(steps) // world_size
    # The layer index is encoded in the crosscoder name, e.g. "L6R-..." -> layer 6.
    match = re.search(r"L(\d+)R", args.name)
    if match is None:
        raise ValueError(f"Cannot parse layer from crosscoder name {args.name!r}")
    layer = int(match.group(1))
    print(f"Analyzing {args.name} at layer {layer}")
    settings = AnalyzeCrossCoderSettings(
        sae=CrossCoderConfig.from_pretrained(
            os.path.expanduser(f"~/results/{args.name}"),
            device="cuda",
            dtype=torch.float16,
        ),
        analyzer=FeatureAnalyzerConfig(
            total_analyzing_tokens=100_000_000,
            subsamples={
                "top_activations": {"proportion": 1.0, "n_samples": 20},
                "non_activating": {"proportion": 0.3, "n_samples": 20, "max_length": 50},
            },
            ignore_token_ids=[0],
        ),
        sae_name=args.name,
        sae_series="pythia-crosscoder",
        # One activation factory per chunk of checkpoint steps (len(steps) // world_size steps each).
        activation_factories=[
            ActivationFactoryConfig(
                sources=[
                    ActivationFactoryActivationsSource(
                        path={
                            f"step{step}": Path(
                                os.path.expanduser(
                                    f"~/activations/SlimPajama-3B-activations-pythia-{args.size}-2d-all-fp16/step{step}/blocks.{layer}.hook_resid_post"
                                )
                            )
                            for step in per_device_steps
                        },
                        sample_weights=1.0,
                        name="SlimPajama-3B",
                        device="cuda",
                        dtype=torch.float16,
                    )
                ],
                target=ActivationFactoryTarget.ACTIVATIONS_2D,
                hook_points=[f"step{step}" for step in per_device_steps],
                batch_size=args.batch_size,
            )
            for per_device_steps in batched(steps, head_per_device)
        ],
        mongo=MongoDBConfig(),
        feature_analysis_name=args.analysis_name,
        device_type="cuda",
    )
    analyze_crosscoder(settings)
