Commit 8161b93

Merge pull request #115 from SevKod/main
add features to top_k and various fixes
2 parents: 42a2f81 + 10aecde

File tree: 1 file changed (+13, -7)

src/sdialog/interpretability/__init__.py

Lines changed: 13 additions & 7 deletions
@@ -313,7 +313,7 @@ class ActivationHook(BaseHook):
     :meta private:
     """
     def __init__(self, cache_key, layer_key, agent, response_hook,
-                 steering_function=None, steering_interval=(0, -1), inspector=None):
+                 steering_function=None, steering_interval=("*", "*"), inspector=None):
         super().__init__(layer_key, self._hook, agent=None)
         self.cache_key = cache_key
         self.agent = agent
@@ -467,7 +467,9 @@ def __init__(self, cache_key, layer_key, agent, response_hook, inspector=None):
         self.agent = agent
         self.response_hook = response_hook
         self.inspector = inspector
-        self.register(agent.base_model, self.inspector.inspect_input)
+        # We always need a forward hook instead of a pre-hook if we want to capture the lm_head predictions.
+        # So, no pre hook here is mandatory.
+        self.register(agent.base_model, is_pre_hook=False)

         # Initialize the logits cache for this response
         _ = self.agent._hook_response_logit[len(self.response_hook.responses)] = []
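For context, the distinction the new comment relies on is PyTorch's two hook types: a pre-forward hook fires before the module runs and only sees its inputs, whereas a forward hook also receives the module's output, which for a language model head is the logits. A minimal, illustrative sketch in plain PyTorch (not sdialog's register() machinery):

import torch
import torch.nn as nn

# Stand-in for a model's lm_head projection (hidden size 16, vocab size 100).
lm_head = nn.Linear(16, 100)

def pre_hook(module, args):
    # Runs before lm_head: only the incoming hidden states are visible here.
    hidden_states, = args
    print("pre-hook input shape:", hidden_states.shape)

def forward_hook(module, args, output):
    # Runs after lm_head: the output is the vocabulary logits.
    print("forward hook output shape:", output.shape)

lm_head.register_forward_pre_hook(pre_hook)
lm_head.register_forward_hook(forward_hook)
lm_head(torch.randn(1, 16))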
@@ -1088,15 +1090,15 @@ def top_k(self):
         """
         Return top-k predicted tokens and their probabilities that led to this token being generated.

-        Returns a list of tuples (token_string, probability) sorted by probability (descending).
+        Returns a list of tuples (token_string, probability, token_id) sorted by probability (descending).
         The number of predictions (k) is determined by the Inspector's top_k parameter.
         If top_k=-1, returns all tokens in the vocabulary with their probabilities.

         Note: System prompt tokens do not have top_k predictions (they are not generated).
         For generated tokens, this returns the probabilities that predicted this token.

-        :return: List of (token_string, probability) tuples.
-        :rtype: List[Tuple[str, float]]
+        :return: List of (token_string, probability, token_id) tuples.
+        :rtype: List[Tuple[str, float, int]]
         :raises KeyError: If logits are not available (logits hook not registered).
         :raises ValueError: If called on a system prompt token.
         """
@@ -1132,6 +1134,9 @@ def top_k(self):
         if token_logits.ndim > 1:
             token_logits = token_logits.squeeze()

+        # Convert to float64 for numerical precision (matching original logit-based approach)
+        token_logits = token_logits.to(torch.float64)
+
         # Apply softmax to convert logits to probabilities
         probs = torch.softmax(token_logits, dim=-1)

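On the float64 cast: softmax normalizes over the whole vocabulary, and in lower precision the small tail probabilities can drift, so casting up before normalizing keeps the reported values closer to an exact reference. A toy illustration (synthetic numbers, not from the library):

import torch

logits = torch.randn(32_000) * 10  # float32 logits over a vocab-sized axis

probs32 = torch.softmax(logits, dim=-1)
probs64 = torch.softmax(logits.to(torch.float64), dim=-1)

# Maximum drift of the float32 probabilities from the float64 reference.
print((probs32.to(torch.float64) - probs64).abs().max())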
@@ -1158,7 +1163,8 @@ def top_k(self):
             sorted_indices = sorted_indices.unsqueeze(0)
         for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
             # Decode the token ID
-            token_str = tokenizer.decode([int(idx)], skip_special_tokens=False)
-            results.append((token_str, float(prob)))
+            token_id = int(idx)
+            token_str = tokenizer.decode([token_id], skip_special_tokens=False)
+            results.append((token_str, float(prob), token_id))

         return results
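Putting the top_k changes together: cast the cached logits up, softmax, sort descending, and decode each candidate id, keeping the id in the returned tuple. A standalone sketch of that pipeline using a Hugging Face tokenizer (the tokenizer choice and helper name are illustrative; sdialog's caching and Inspector plumbing are omitted):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer would do

def top_k_from_logits(token_logits, k=5):
    """Return [(token_string, probability, token_id), ...] sorted by probability."""
    probs = torch.softmax(token_logits.to(torch.float64), dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    if k != -1:  # mirror the "-1 means whole vocabulary" convention from the docstring
        sorted_probs, sorted_indices = sorted_probs[:k], sorted_indices[:k]
    results = []
    for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        token_id = int(idx)
        token_str = tokenizer.decode([token_id], skip_special_tokens=False)
        results.append((token_str, float(prob), token_id))
    return results

print(top_k_from_logits(torch.randn(tokenizer.vocab_size)))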
