@@ -313,7 +313,7 @@ class ActivationHook(BaseHook):
     :meta private:
     """
     def __init__(self, cache_key, layer_key, agent, response_hook,
-                 steering_function=None, steering_interval=(0, -1), inspector=None):
+                 steering_function=None, steering_interval=("*", "*"), inspector=None):
         super().__init__(layer_key, self._hook, agent=None)
         self.cache_key = cache_key
         self.agent = agent
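The new default replaces the numeric interval `(0, -1)` with wildcard bounds. A minimal sketch of the assumed semantics — that `steering_interval` brackets the generation steps where `steering_function` runs, and `"*"` means "unbounded on that side"; the helper below is hypothetical, not part of this codebase:

```python
# Hypothetical helper illustrating the assumed wildcard-interval semantics.
def step_in_interval(step: int, interval, total_steps: int) -> bool:
    start, end = interval
    start = 0 if start == "*" else start        # "*" assumed to mean "from the first step"
    end = total_steps if end == "*" else end    # "*" assumed to mean "through the last step"
    if end < 0:                                 # old numeric convention: -1 ran until the end
        end = total_steps + end + 1
    return start <= step < end

# Under these assumptions, the old and new defaults cover the same steps:
assert step_in_interval(5, (0, -1), total_steps=10)
assert step_in_interval(5, ("*", "*"), total_steps=10)
```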
@@ -467,7 +467,9 @@ def __init__(self, cache_key, layer_key, agent, response_hook, inspector=None):
         self.agent = agent
         self.response_hook = response_hook
         self.inspector = inspector
-        self.register(agent.base_model, self.inspector.inspect_input)
+        # We always need a forward hook (not a pre-hook) to capture the lm_head predictions,
+        # so registering with is_pre_hook=False is mandatory here.
+        self.register(agent.base_model, is_pre_hook=False)
 
         # Initialize the logits cache for this response
         _ = self.agent._hook_response_logit[len(self.response_hook.responses)] = []
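The replacement comment relies on the standard PyTorch distinction: a forward pre-hook fires before `forward()` and sees only the inputs, while a forward hook fires afterwards and also receives the output, which is where the lm_head predictions live. A self-contained sketch of that difference (the `nn.Linear` is a stand-in for a model's lm_head):

```python
import torch
import torch.nn as nn

lm_head = nn.Linear(16, 100)  # stand-in for a language model's lm_head

def pre_hook(module, args):
    # Fires before forward(); only the inputs are visible here.
    print("pre-hook input shape:", args[0].shape)

def fwd_hook(module, args, output):
    # Fires after forward(); the output logits are available here.
    print("forward hook logits shape:", output.shape)

lm_head.register_forward_pre_hook(pre_hook)
lm_head.register_forward_hook(fwd_hook)
_ = lm_head(torch.randn(1, 16))  # prints shapes (1, 16) then (1, 100)
```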
@@ -1088,15 +1090,15 @@ def top_k(self):
         """
         Return top-k predicted tokens and their probabilities that led to this token being generated.
 
-        Returns a list of tuples (token_string, probability) sorted by probability (descending).
+        Returns a list of tuples (token_string, probability, token_id) sorted by probability (descending).
         The number of predictions (k) is determined by the Inspector's top_k parameter.
         If top_k=-1, returns all tokens in the vocabulary with their probabilities.
 
         Note: System prompt tokens do not have top_k predictions (they are not generated).
         For generated tokens, this returns the probabilities that predicted this token.
 
-        :return: List of (token_string, probability) tuples.
-        :rtype: List[Tuple[str, float]]
+        :return: List of (token_string, probability, token_id) tuples.
+        :rtype: List[Tuple[str, float, int]]
         :raises KeyError: If logits are not available (logits hook not registered).
         :raises ValueError: If called on a system prompt token.
         """
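With the third element, callers can hold on to the raw vocabulary id instead of re-encoding the decoded string, which is not guaranteed to round-trip through a byte-level BPE tokenizer. A hedged usage sketch, assuming `token` is an inspected generated token exposing this `top_k()`:

```python
# Assumed consumer of the updated return type: List[Tuple[str, float, int]].
for token_str, prob, token_id in token.top_k():
    print(f"{token_id:>6}  {prob:.4f}  {token_str!r}")
```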
@@ -1132,6 +1134,9 @@ def top_k(self):
         if token_logits.ndim > 1:
             token_logits = token_logits.squeeze()
 
+        # Convert to float64 for numerical precision (matching original logit-based approach)
+        token_logits = token_logits.to(torch.float64)
+
         # Apply softmax to convert logits to probabilities
         probs = torch.softmax(token_logits, dim=-1)
 
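The float64 upcast protects the tail of the distribution: over a vocabulary of tens of thousands of tokens, most probabilities are tiny, and lower-precision floats cannot represent them at all. A small, self-contained illustration (values chosen for effect, not taken from this codebase):

```python
import torch

logits = torch.tensor([12.0, 0.0, -8.0])

# Softmax computed after upcasting, as the patched code does:
p64 = torch.softmax(logits.to(torch.float64), dim=-1)

# The same distribution stored in half precision loses the tail entirely:
p16 = p64.to(torch.float16)

print(p64[-1].item())  # ~2.1e-09
print(p16[-1].item())  # 0.0 -- below float16's smallest subnormal (~6e-08)
```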
@@ -1158,7 +1163,8 @@ def top_k(self):
             sorted_indices = sorted_indices.unsqueeze(0)
         for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
             # Decode the token ID
-            token_str = tokenizer.decode([int(idx)], skip_special_tokens=False)
-            results.append((token_str, float(prob)))
+            token_id = int(idx)
+            token_str = tokenizer.decode([token_id], skip_special_tokens=False)
+            results.append((token_str, float(prob), token_id))
 
         return results
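A quick sanity check against the documented contract — probabilities sorted descending and, when the Inspector was built with top_k=-1, summing to ~1 over the full vocabulary; `results` here is assumed to be the list returned above:

```python
# Assuming `results` came from top_k() on an Inspector created with top_k=-1.
probs = [prob for _, prob, _ in results]
assert all(a >= b for a, b in zip(probs, probs[1:]))  # sorted descending
assert abs(sum(probs) - 1.0) < 1e-6                   # full distribution sums to 1
```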