
Commit 637fcae

author zhengfuhe committed
add post_process example
1 parent c63e796 commit 637fcae

File tree: 1 file changed (+70 −0)

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from lm_saes import post_process_topk_to_jumprelu_runner, LanguageModelSAERunnerConfig, SAEConfig
import os
import torch

layer = 15

# SAE position: "R" (residual stream), "M" (MLP out), "A" (attention out),
# "TC" (transcoder: reads the MLP's normalized input, writes the MLP output),
# "Emb" (embeddings).
hook_point_in = 'R'
hook_point_out = hook_point_in if hook_point_in != 'TC' else 'M'
exp_factor = 8

HOOK_SUFFIX = {"M": "hook_mlp_out", "A": "hook_attn_out", "R": "hook_resid_post",
               "TC": "ln2.hook_normalized", "Emb": "hook_resid_pre"}

hook_suffix_in = HOOK_SUFFIX[hook_point_in]
hook_suffix_out = HOOK_SUFFIX[hook_point_out]

# Locate the trained TopK SAE checkpoint and load its config.
ckpt_path = f"<base_path>/Llama3_1Base-LX{hook_point_in}-{exp_factor}x"
ckpt_path = os.path.join(ckpt_path, f"Llama3_1Base-L{layer}{hook_point_in}-{exp_factor}x")
sae_config = SAEConfig.from_pretrained(ckpt_path).to_dict()

model_name = "meta-llama/Llama-3.1-8B"
model_from_pretrained_path = "<local_model_path>"  # required below; point this at a local copy of the model

# TransformerLens hook names the SAE reads from and writes to.
hook_points = [
    f"blocks.{layer}.{hook_suffix_in}",
    f"blocks.{layer}.{hook_suffix_out}",
]

cfg = LanguageModelSAERunnerConfig.from_flattened(
    dict(
        **sae_config,
        model_name=model_name,
        model_from_pretrained_path=model_from_pretrained_path,
        # d_model=4096,
        dataset_path="<local_dataset_path>",
        is_dataset_tokenized=False,
        is_dataset_on_disk=True,
        concat_tokens=False,
        context_size=1024,
        store_batch_size=4,
        hook_points=hook_points,
        use_cached_activations=False,
        hook_points_in=hook_points[0],
        hook_points_out=hook_points[1],
        # norm_activation="token-wise",
        decoder_exactly_unit_norm=False,
        decoder_bias_init_method="geometric_median",
        # use_glu_encoder=False,
        # use_ghost_grads=False,  # Whether to use the ghost gradients for saving dead features.
        # use_decoder_bias=True,
        # sparsity_include_decoder_norm=True,
        remove_gradient_parallel_to_decoder_directions=False,
        # apply_decoder_bias_to_pre_encoder=True,
        # init_encoder_with_decoder_transpose=True,
        # expansion_factor=8,
        train_batch_size=2048,
        log_to_wandb=False,
        # device="cuda",
        # seed=44,
        # dtype=torch.bfloat16,
        exp_name="eval",
        exp_result_dir=f"./result/{layer}_{hook_point_in}",
    )
)

# Convert the trained TopK SAE into a JumpReLU SAE.
post_process_topk_to_jumprelu_runner(cfg)
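
For intuition, here is a rough sketch of the idea behind a TopK-to-JumpReLU conversion; this is not the lm_saes implementation, which may work differently. A TopK SAE keeps the k largest pre-activations per token, while a JumpReLU SAE gates each feature on its own threshold, so one plausible conversion calibrates each feature's threshold from the cutoffs TopK actually applied. All helper names below are hypothetical; only the standard PyTorch API is used.

import torch

def topk_activation(pre, k):
    # TopK SAE nonlinearity: keep the k largest pre-activations per token, zero the rest.
    values, indices = torch.topk(pre, k, dim=-1)
    out = torch.zeros_like(pre)
    out.scatter_(-1, indices, values)
    return out

def estimate_thresholds(pre, k):
    # Hypothetical calibration: the k-th largest pre-activation of each token is the
    # cutoff TopK applied there; average that cutoff per feature over the tokens where
    # the feature survived. Features that never survive get a NaN threshold (never fire).
    kth = torch.topk(pre, k, dim=-1).values[..., -1:]        # (n_tokens, 1)
    survived = pre >= kth                                    # (n_tokens, d_sae)
    cutoffs = torch.where(survived, kth.expand_as(pre), torch.nan)
    return torch.nanmean(cutoffs, dim=0)                     # (d_sae,)

def jumprelu_activation(pre, theta):
    # JumpReLU: pass a pre-activation through unchanged iff it clears its feature's threshold.
    return pre * (pre > theta)

# Toy check on random non-negative pre-activations: the calibrated JumpReLU
# should reproduce roughly the same sparsity as the original TopK.
pre = torch.randn(4096, 512).relu()
theta = estimate_thresholds(pre, k=32)
print((jumprelu_activation(pre, theta) != 0).float().mean())  # roughly k / d_sae
print((topk_activation(pre, k=32) != 0).float().mean())       # exactly k / d_sae here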
