@@ -7,19 +7,26 @@
 from torch.distributed.device_mesh import init_device_mesh
 
 from lm_saes.activation.factory import ActivationFactory
+from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor
 from lm_saes.analysis.feature_analyzer import FeatureAnalyzer
+from lm_saes.backend.language_model import TransformerLensLanguageModel
 from lm_saes.config import (
     ActivationFactoryConfig,
     BaseSAEConfig,
     CrossCoderConfig,
+    DirectLogitAttributorConfig,
     FeatureAnalyzerConfig,
     LanguageModelConfig,
     MongoDBConfig,
+    SAEConfig,
 )
 from lm_saes.crosscoder import CrossCoder
 from lm_saes.database import MongoClient
+from lm_saes.resource_loaders import load_model
+from lm_saes.runners.utils import load_config
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.utils.logging import get_distributed_logger, setup_logging
+from lm_saes.utils.misc import is_master
 
 logger = get_distributed_logger("runners.analyze")
 
@@ -199,3 +206,102 @@ def analyze_crosscoder(settings: AnalyzeCrossCoderSettings) -> None: |
         )
 
     logger.info("CrossCoder analysis completed successfully")
+
+
+class DirectLogitAttributeSettings(BaseSettings):
+    """Settings for running direct logit attribution on an SAE or CrossCoder model."""
+
+    sae: BaseSAEConfig
+    """Configuration for the SAE model architecture and parameters"""
+
+    sae_name: str
+    """Name of the SAE model. Used as an identifier for the SAE model in the database."""
+
+    sae_series: str
+    """Series of the SAE model. Used as an identifier for the SAE model in the database."""
+
+    model: Optional[LanguageModelConfig] = None
+    """Configuration for the language model. If not provided, it is loaded from the database via `model_name`."""
+
+    model_name: str
+    """Name of the language model."""
+
+    direct_logit_attributor: DirectLogitAttributorConfig
+    """Configuration for the direct logit attributor."""
+
+    mongo: MongoDBConfig
+    """Configuration for the MongoDB database."""
+
+    device_type: str = "cuda"
+    """Device type to use ('cuda' or 'cpu')"""
+
+    # model_parallel_size: int = 1
+    # """Size of model parallel (tensor parallel) mesh"""
+
+    # data_parallel_size: int = 1
+    # """Size of data parallel mesh"""
+
+    # head_parallel_size: int = 1
+    # """Size of head parallel mesh"""
+
+
+@torch.no_grad()
+def direct_logit_attribute(settings: DirectLogitAttributeSettings) -> None:
+    """Run direct logit attribution on an SAE model.
+
+    Args:
+        settings: Configuration settings for direct logit attribution
+    """
+    # Set up logging
+    setup_logging(level="INFO")
+
+    # device_mesh = (
+    #     init_device_mesh(
+    #         device_type=settings.device_type,
+    #         mesh_shape=(settings.head_parallel_size, settings.data_parallel_size, settings.model_parallel_size),
+    #         mesh_dim_names=("head", "data", "model"),
+    #     )
+    #     if settings.head_parallel_size > 1 or settings.data_parallel_size > 1 or settings.model_parallel_size > 1
+    #     else None
+    # )
+
+    mongo_client = MongoClient(settings.mongo)
+    logger.info("MongoDB client initialized")
+
+    logger.info("Loading SAE model")
+    if isinstance(settings.sae, CrossCoderConfig):
+        sae = CrossCoder.from_config(settings.sae)
+    elif isinstance(settings.sae, SAEConfig):
+        sae = SparseAutoEncoder.from_config(settings.sae)
+    else:
+        raise ValueError(f"Unsupported SAE config type: {type(settings.sae)}")
+
+    # Load configurations
+    model_cfg = load_config(
+        config=settings.model,
+        name=settings.model_name,
+        mongo_client=mongo_client,
+        config_type="model",
+        required=True,
+    )
+    model_cfg.device = settings.device_type
+    model_cfg.dtype = sae.cfg.dtype
+
+    model = load_model(model_cfg)
+    assert isinstance(model, TransformerLensLanguageModel), (
+        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
+    )
+
+    logger.info("Direct logit attribution")
+    direct_logit_attributor = DirectLogitAttributor(settings.direct_logit_attributor)
+    results = direct_logit_attributor.direct_logit_attribute(sae, model)
+
+    if is_master():
+        logger.info("Direct logit attribution completed, saving results to MongoDB")
+        mongo_client.update_features(
+            sae_name=settings.sae_name,
+            sae_series=settings.sae_series,
+            update_data=[{"logits": result} for result in results],
+            start_idx=0,
+        )
+    logger.info("Direct logit attribution completed successfully")
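
A minimal usage sketch of the new runner follows. It assumes the module path `lm_saes.runners.analyze` (suggested by the logger name) and that the `SAEConfig`/`CrossCoderConfig`, `DirectLogitAttributorConfig`, and `MongoDBConfig` objects (`sae_cfg`, `dla_cfg`, `mongo_cfg`) are constructed elsewhere, e.g. parsed from a project config file; all identifier strings are placeholders, not values from this diff.

```python
# Usage sketch (not part of the diff). sae_cfg, dla_cfg, and mongo_cfg are
# assumed to be built elsewhere, e.g. parsed from a project config file.
from lm_saes.runners.analyze import DirectLogitAttributeSettings, direct_logit_attribute

settings = DirectLogitAttributeSettings(
    sae=sae_cfg,                      # SAEConfig or CrossCoderConfig
    sae_name="my-sae",                # placeholder identifier in MongoDB
    sae_series="my-series",           # placeholder series in MongoDB
    model=None,                       # resolve the model config from MongoDB by name
    model_name="gpt2",                # placeholder model name
    direct_logit_attributor=dla_cfg,  # DirectLogitAttributorConfig
    mongo=mongo_cfg,                  # MongoDBConfig
    device_type="cuda",
)
direct_logit_attribute(settings)      # stores per-feature "logits" in the feature documents
```

The field names mirror `DirectLogitAttributeSettings` above; since the runner saves results via `mongo_client.update_features`, the per-feature attribution ends up under a `logits` key in the feature documents.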