Commit eb5088a

fix(example): fix error in loading example

1 parent ba1dbcc

File tree

4 files changed: +53 -32 lines

examples/loading_llamascope_saes.ipynb

Lines changed: 13 additions & 24 deletions

@@ -27,15 +27,11 @@
 }
 ],
 "source": [
-"import os\n",
 "import torch\n",
-"import transformers\n",
-"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+"from transformer_lens import HookedTransformer\n",
+"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
 "\n",
-"from lm_saes import SparseAutoEncoder, SAEConfig\n",
-"\n",
-"from tqdm import tqdm\n",
-"from transformer_lens import HookedTransformer, HookedTransformerConfig"
+"from lm_saes import SparseAutoEncoder"
 ]
 },
 {
@@ -61,28 +57,18 @@
 ],
 "source": [
 "model_name = \"meta-llama/Llama-3.1-8B\"\n",
-" \n",
-"hf_model = AutoModelForCausalLM.from_pretrained(\n",
-"    (\n",
-"        model_name\n",
-"        if model_from_pretrained_path is None\n",
-"        else model_from_pretrained_path\n",
-"    ),\n",
-")\n",
+"\n",
+"hf_model = AutoModelForCausalLM.from_pretrained(model_name)\n",
 "\n",
 "hf_tokenizer = AutoTokenizer.from_pretrained(\n",
-"    (\n",
-"        model_name\n",
-"        if model_from_pretrained_path is None\n",
-"        else model_from_pretrained_path\n",
-"    ),\n",
+"    model_name,\n",
 "    trust_remote_code=True,\n",
 "    use_fast=True,\n",
 "    add_bos_token=True,\n",
 ")\n",
 "model = HookedTransformer.from_pretrained_no_processing(\n",
 "    model_name,\n",
-"    device='cuda',\n",
+"    device=\"cuda\",\n",
 "    hf_model=hf_model,\n",
 "    tokenizer=hf_tokenizer,\n",
 "    dtype=torch.bfloat16,\n",
@@ -127,7 +113,7 @@
 }
 ],
 "source": [
-"sae = SparseAutoEncoder.from_pretrained('fnlp/Llama3_1-8B-Base-L15R-8x')"
+"sae = SparseAutoEncoder.from_pretrained(\"fnlp/Llama3_1-8B-Base-L15R-8x\")"
 ]
 },
 {
@@ -179,7 +165,7 @@
 ],
 "source": [
 "# L0 Sparsity. The first token is <bos> which extremely out-of-distribution.\n",
-"(sae.compute_loss(cache['blocks.15.hook_resid_post'])[1][1]['feature_acts'] > 0).sum(-1)"
+"(sae.compute_loss(cache[\"blocks.15.hook_resid_post\"])[1][1][\"feature_acts\"] > 0).sum(-1)"
 ]
 },
 {
@@ -201,7 +187,10 @@
 ],
 "source": [
 "# Reconstruction loss\n",
-"(sae.compute_loss(cache['blocks.15.hook_resid_post'][:, 1:])[1][1]['reconstructed'] - cache['blocks.15.hook_resid_post'][:, 1:]).pow(2).mean()"
+"(\n",
+"    sae.compute_loss(cache[\"blocks.15.hook_resid_post\"][:, 1:])[1][1][\"reconstructed\"]\n",
+"    - cache[\"blocks.15.hook_resid_post\"][:, 1:]\n",
+").pow(2).mean()"
 ]
 },
 {
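For reference, the fixed notebook cells assemble into the following end-to-end loading sketch. This is a minimal consolidation of the hunks above, assuming a CUDA device, access to the gated meta-llama/Llama-3.1-8B weights, and that the notebook builds cache via model.run_with_cache; the prompt string is illustrative, not from the commit.

import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_saes import SparseAutoEncoder

model_name = "meta-llama/Llama-3.1-8B"

# Load the HF model/tokenizer, then wrap them in a hooked model so that
# intermediate activations can be cached.
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_fast=True,
    add_bos_token=True,
)
model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    device="cuda",
    hf_model=hf_model,
    tokenizer=hf_tokenizer,
    dtype=torch.bfloat16,
)

# Load the layer-15 residual-stream LlamaScope SAE (8x expansion).
sae = SparseAutoEncoder.from_pretrained("fnlp/Llama3_1-8B-Base-L15R-8x")

# Cache activations for an example prompt (illustrative text, not from the commit).
_, cache = model.run_with_cache("The quick brown fox jumps over the lazy dog.")
acts = cache["blocks.15.hook_resid_post"]

# L0 sparsity per token: count the features with non-zero activation.
l0 = (sae.compute_loss(acts)[1][1]["feature_acts"] > 0).sum(-1)

# Mean squared reconstruction error, skipping the <bos> position,
# which is extremely out-of-distribution.
mse = (
    sae.compute_loss(acts[:, 1:])[1][1]["reconstructed"]
    - acts[:, 1:]
).pow(2).mean()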

examples/programmatic/post_process_topk.py

Lines changed: 11 additions & 7 deletions

@@ -1,15 +1,20 @@
 import os
-from lm_saes import post_process_topk_to_jumprelu_runner, LanguageModelSAERunnerConfig, SAEConfig
 
+from lm_saes import LanguageModelSAERunnerConfig, SAEConfig, post_process_topk_to_jumprelu_runner
 
 layer = 15
 
-hook_point_in = 'R'
-hook_point_out = hook_point_in if hook_point_in != 'TC' else 'M'
+hook_point_in = "R"
+hook_point_out = hook_point_in if hook_point_in != "TC" else "M"
 exp_factor = 8
 
-HOOK_SUFFIX = {"M": "hook_mlp_out", "A": "hook_attn_out", "R": "hook_resid_post", "TC": "ln2.hook_normalized",
-               "Emb": "hook_resid_pre"}
+HOOK_SUFFIX = {
+    "M": "hook_mlp_out",
+    "A": "hook_attn_out",
+    "R": "hook_resid_post",
+    "TC": "ln2.hook_normalized",
+    "Emb": "hook_resid_pre",
+}
 
 
 hook_suffix_in = HOOK_SUFFIX[hook_point_in]
@@ -19,7 +24,6 @@
 sae_config = SAEConfig.from_pretrained(ckpt_path).to_dict()
 
 
-
 model_name = "meta-llama/Llama-3.1-8B"
 # model_from_pretrained_path = "<local_model_path>"
 
@@ -65,4 +69,4 @@
     )
 )
 
-post_process_topk_to_jumprelu_runner(cfg)
\ No newline at end of file
+post_process_topk_to_jumprelu_runner(cfg)
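The abbreviation-to-suffix mapping above pairs with the layer index to name a TransformerLens hook point. The script's exact string formatting is not shown in these hunks, so the join below is an assumption, though it is corroborated by "blocks.15.hook_resid_post" in the notebook above.

layer = 15
hook_point_in = "R"
HOOK_SUFFIX = {
    "M": "hook_mlp_out",
    "A": "hook_attn_out",
    "R": "hook_resid_post",
    "TC": "ln2.hook_normalized",
    "Emb": "hook_resid_pre",
}
# Assumed convention: join the block index with the hook suffix.
hook_point = f"blocks.{layer}.{HOOK_SUFFIX[hook_point_in]}"
print(hook_point)  # blocks.15.hook_resid_post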

pdm.lock

Lines changed: 28 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ dev = [
     "nbformat>=5.10.4",
     "kaleido==0.2.1",
     "pre-commit>=4.0.1",
+    "ruff>=0.7.1",
 ]
 
 [tool.mypy]
