"""Convert a pretrained TopK SAE checkpoint to a JumpReLU SAE using the lm_saes post-processing runner."""

import os

import torch
from lm_saes import LanguageModelSAERunnerConfig, SAEConfig, post_process_topk_to_jumprelu_runner

# Target layer and hook position. Position codes: R = residual stream, A = attention output,
# M = MLP output, TC = transcoder (reads the ln2-normalized MLP input, writes the MLP output),
# Emb = embeddings.
layer = 15

hook_point_in = "R"
hook_point_out = hook_point_in if hook_point_in != "TC" else "M"
exp_factor = 8

# Map each position code to the corresponding TransformerLens hook name inside a block.
HOOK_SUFFIX = {"M": "hook_mlp_out", "A": "hook_attn_out", "R": "hook_resid_post",
               "TC": "ln2.hook_normalized", "Emb": "hook_resid_pre"}

hook_suffix_in = HOOK_SUFFIX[hook_point_in]
hook_suffix_out = HOOK_SUFFIX[hook_point_out]

# Locate the pretrained TopK SAE checkpoint and load its config as a flat dict.
ckpt_path = f"<base_path>/Llama3_1Base-LX{hook_point_in}-{exp_factor}x"
ckpt_path = os.path.join(ckpt_path, f"Llama3_1Base-L{layer}{hook_point_in}-{exp_factor}x")
sae_config = SAEConfig.from_pretrained(ckpt_path).to_dict()

# Model that produced the activations the SAE was trained on.
model_name = "meta-llama/Llama-3.1-8B"
model_from_pretrained_path = "<local_model_path>"

# Input and output hook points for the SAE (they differ only for transcoders).
hook_points = [
    f"blocks.{layer}.{hook_suffix_in}",
    f"blocks.{layer}.{hook_suffix_out}",
]

# Build the runner config: overlay runtime settings for the conversion on top of the saved SAE config.
cfg = LanguageModelSAERunnerConfig.from_flattened(
    dict(
        **sae_config,
        model_name=model_name,
        model_from_pretrained_path=model_from_pretrained_path,
        # d_model=4096,
        dataset_path="<local_dataset_path>",
        is_dataset_tokenized=False,
        is_dataset_on_disk=True,
        concat_tokens=False,
        context_size=1024,
        store_batch_size=4,
        hook_points=hook_points,
        use_cached_activations=False,
        hook_points_in=hook_points[0],
        hook_points_out=hook_points[1],
        # norm_activation="token-wise",
        decoder_exactly_unit_norm=False,
        decoder_bias_init_method="geometric_median",
        # use_glu_encoder=False,
        # use_ghost_grads=False,  # Whether to use ghost gradients for reviving dead features.
        # use_decoder_bias=True,
        # sparsity_include_decoder_norm=True,
        remove_gradient_parallel_to_decoder_directions=False,
        # apply_decoder_bias_to_pre_encoder=True,
        # init_encoder_with_decoder_transpose=True,
        # expansion_factor=8,
        train_batch_size=2048,
        log_to_wandb=False,
        # device="cuda",
        # seed=44,
        # dtype=torch.bfloat16,
        exp_name="eval",
        exp_result_dir=f"./result/{layer}_{hook_point_in}",
    )
)

# Convert the TopK SAE to JumpReLU.
post_process_topk_to_jumprelu_runner(cfg)
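
# A minimal usage sketch (assumption: the <base_path>, <local_model_path>, and
# <local_dataset_path> placeholders above have been replaced with real local paths,
# and this file is saved as convert_topk_to_jumprelu.py, a hypothetical name):
#
#   python convert_topk_to_jumprelu.py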