|
27 | 27 | } |
28 | 28 | ], |
29 | 29 | "source": [ |
30 | | - "import os\n", |
31 | 30 | "import torch\n", |
32 | | - "import transformers\n", |
33 | | - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", |
| 31 | + "from transformer_lens import HookedTransformer\n", |
| 32 | + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
34 | 33 | "\n", |
35 | | - "from lm_saes import SparseAutoEncoder, SAEConfig\n", |
36 | | - "\n", |
37 | | - "from tqdm import tqdm\n", |
38 | | - "from transformer_lens import HookedTransformer, HookedTransformerConfig" |
| 34 | + "from lm_saes import SparseAutoEncoder" |
39 | 35 | ] |
40 | 36 | }, |
41 | 37 | { |
|
61 | 57 | ], |
62 | 58 | "source": [ |
63 | 59 | "model_name = \"meta-llama/Llama-3.1-8B\"\n", |
64 | | - " \n", |
65 | | - "hf_model = AutoModelForCausalLM.from_pretrained(\n", |
66 | | - " (\n", |
67 | | - " model_name\n", |
68 | | - " if model_from_pretrained_path is None\n", |
69 | | - " else model_from_pretrained_path\n", |
70 | | - " ),\n", |
71 | | - ")\n", |
| 60 | + "\n", |
| 61 | + "hf_model = AutoModelForCausalLM.from_pretrained(model_name)\n", |
72 | 62 | "\n", |
73 | 63 | "hf_tokenizer = AutoTokenizer.from_pretrained(\n", |
74 | | - " (\n", |
75 | | - " model_name\n", |
76 | | - " if model_from_pretrained_path is None\n", |
77 | | - " else model_from_pretrained_path\n", |
78 | | - " ),\n", |
| 64 | + " model_name,\n", |
79 | 65 | " trust_remote_code=True,\n", |
80 | 66 | " use_fast=True,\n", |
81 | 67 | " add_bos_token=True,\n", |
82 | 68 | ")\n", |
83 | 69 | "model = HookedTransformer.from_pretrained_no_processing(\n", |
84 | 70 | " model_name,\n", |
85 | | - " device='cuda',\n", |
| 71 | + " device=\"cuda\",\n", |
86 | 72 | " hf_model=hf_model,\n", |
87 | 73 | " tokenizer=hf_tokenizer,\n", |
88 | 74 | " dtype=torch.bfloat16,\n", |
|
127 | 113 | } |
128 | 114 | ], |
129 | 115 | "source": [ |
130 | | - "sae = SparseAutoEncoder.from_pretrained('fnlp/Llama3_1-8B-Base-L15R-8x')" |
| 116 | + "sae = SparseAutoEncoder.from_pretrained(\"fnlp/Llama3_1-8B-Base-L15R-8x\")" |
131 | 117 | ] |
132 | 118 | }, |
133 | 119 | { |
|
179 | 165 | ], |
180 | 166 | "source": [ |
181 | 167 | "# L0 Sparsity. The first token is <bos>, which is extremely out-of-distribution.\n",
182 | | - "(sae.compute_loss(cache['blocks.15.hook_resid_post'])[1][1]['feature_acts'] > 0).sum(-1)" |
| 168 | + "(sae.compute_loss(cache[\"blocks.15.hook_resid_post\"])[1][1][\"feature_acts\"] > 0).sum(-1)" |
183 | 169 | ] |
184 | 170 | }, |
185 | 171 | { |
|
201 | 187 | ], |
202 | 188 | "source": [ |
203 | 189 | "# Reconstruction loss\n", |
204 | | - "(sae.compute_loss(cache['blocks.15.hook_resid_post'][:, 1:])[1][1]['reconstructed'] - cache['blocks.15.hook_resid_post'][:, 1:]).pow(2).mean()" |
| 190 | + "(\n", |
| 191 | + " sae.compute_loss(cache[\"blocks.15.hook_resid_post\"][:, 1:])[1][1][\"reconstructed\"]\n", |
| 192 | + " - cache[\"blocks.15.hook_resid_post\"][:, 1:]\n", |
| 193 | + ").pow(2).mean()" |
205 | 194 | ] |
206 | 195 | }, |
207 | 196 | { |
|