Skip to content

Commit d03abb1

Browse files
committed
refactor(timer): enhance hierarchical timer functionality
1 parent c66ede8 commit d03abb1

File tree

1 file changed

+217
-61
lines changed

1 file changed

+217
-61
lines changed

src/lm_saes/utils/timer.py

Lines changed: 217 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,98 @@
11
import time
2-
from collections import defaultdict
32
from contextlib import contextmanager
4-
from typing import Dict, List
3+
from typing import Dict, List, Optional, Set
54

65
import torch
76

87

8+
class TimerNode:
    """One entry in the hierarchical timer tree.

    Attributes:
        name: Identifier of this timer.
        total_time: Time accumulated by this node so far.
        count: How many times this timer has been recorded.
        parent: The enclosing node, or ``None`` when there is none.
        children: Names of the nodes registered directly beneath this one.
        start_time: Start timestamp while the timer is running, else ``None``.
    """

    def __init__(self, name: str, parent: Optional["TimerNode"] = None):
        # A child announces itself to its parent by name; the parent only
        # stores names, never node objects.
        if parent:
            parent.children.add(name)

        self.name = name
        self.parent = parent
        self.children: Set[str] = set()
        self.start_time: Optional[float] = None
        self.count = 0
        self.total_time = 0.0
30+
31+
932
class Timer:
10-
"""A singleton timer class to track time usage in different parts of the training process.
33+
"""A singleton timer class to track time usage hierarchically in different parts of the training process.
1134
1235
This class provides methods to track time usage in different parts of the training process,
13-
such as communication vs computation. It is designed as a singleton to be accessible
14-
from anywhere in the codebase.
36+
organized in a hierarchical structure where nested timers show percentages relative to their
37+
parent timers rather than the total time. An implicit root timer captures the entire session.
1538
1639
Attributes:
1740
_instance: The singleton instance of the Timer class.
18-
_timers: Dictionary mapping timer names to their accumulated time.
19-
_start_times: Dictionary mapping timer names to their start times.
20-
_counts: Dictionary mapping timer names to the number of times they've been called.
21-
_active_timers: List of currently active timers.
41+
_nodes: Dictionary mapping timer names to their TimerNode objects.
42+
_active_stack: Stack of currently active timer names (for hierarchy).
2243
_enabled: Whether the timer is enabled.
44+
_root_start_time: Start time of the implicit root timer.
45+
_session_started: Whether a timing session has been started.
2346
"""
2447

2548
_instance = None
49+
ROOT_NAME = "__root__"
2650

2751
def __new__(cls):
    """Return the shared instance, building and initialising it on first use.

    The instance is cached on ``cls._instance``. Its state attributes are
    assigned only when that cache is still empty, so later calls hand back the
    same object with whatever state it has accumulated.
    """
    if cls._instance is not None:
        return cls._instance

    instance = super(Timer, cls).__new__(cls)
    instance._nodes = {}
    instance._active_stack = []
    instance._enabled = False
    instance._root_start_time = None
    instance._session_started = False

    cls._instance = instance
    return instance
3660

61+
def _ensure_root_started(self):
    """Start the implicit root timer the first time it is needed.

    Does nothing when the timer is disabled or a session has already been
    started. Otherwise pending device work is synchronized, the session start
    timestamp is recorded, the session is flagged as started and a root node is
    registered under ``ROOT_NAME``.
    """
    if self._session_started or not self._enabled:
        return

    # Synchronize device operations before timing.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    else:
        # ``torch.npu`` is not an attribute of every PyTorch installation
        # (presumably it is provided by an NPU backend plugin), so look it up
        # defensively instead of raising AttributeError on CPU-only machines.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.is_available():
            npu.synchronize()

    self._root_start_time = time.perf_counter()
    self._session_started = True

    # Create root node
    self._nodes[self.ROOT_NAME] = TimerNode(self.ROOT_NAME)
75+
76+
def _finalize_root(self):
    """Record the session's elapsed time on the root node.

    Acts only when a session has been started, no timer is currently active,
    a root start timestamp exists and the root node is registered; otherwise
    it returns without touching anything. When it acts, the root node's
    ``total_time`` is overwritten with the time elapsed since the session
    started and its ``count`` is set to 1.
    """
    if not self._session_started or self._active_stack:
        return
    if self._root_start_time is None or self.ROOT_NAME not in self._nodes:
        return

    # Synchronize device operations before timing.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    else:
        # ``torch.npu`` is not an attribute of every PyTorch installation
        # (presumably it is provided by an NPU backend plugin), so look it up
        # defensively instead of raising AttributeError on CPU-only machines.
        npu = getattr(torch, "npu", None)
        if npu is not None and npu.is_available():
            npu.synchronize()

    root_node = self._nodes[self.ROOT_NAME]
    # NOTE(review): nothing in this method assigns the root's start_time, so
    # this guard looks like it always passes and the total is simply refreshed
    # on every call - confirm no other code sets it.
    if root_node.start_time is None:
        root_node.total_time = time.perf_counter() - self._root_start_time
        root_node.count = 1
95+
3796
@contextmanager
3897
def time(self, name: str):
3998
"""Context manager to time a block of code.
@@ -60,17 +119,38 @@ def start(self, name: str):
60119
if not self._enabled:
61120
return
62121

63-
if name in self._start_times:
122+
if name == self.ROOT_NAME:
123+
raise ValueError(f"Timer name '{self.ROOT_NAME}' is reserved for the root timer")
124+
125+
if name in self._nodes and self._nodes[name].start_time is not None:
64126
raise ValueError(f"Timer {name} is already running")
65127

128+
# Ensure root timer is started
129+
self._ensure_root_started()
130+
66131
# Synchronize CUDA operations before timing
67132
if torch.cuda.is_available():
68133
torch.cuda.synchronize()
69-
elif torch.npu.is_available():
70-
torch.npu.synchronize()
71-
72-
self._start_times[name] = time.perf_counter()
73-
self._active_timers.append(name)
134+
elif torch.npu.is_available(): # type: ignore
135+
torch.npu.synchronize() # type: ignore
136+
137+
# Determine parent
138+
parent_node = None
139+
if self._active_stack:
140+
parent_name = self._active_stack[-1]
141+
parent_node = self._nodes[parent_name]
142+
else:
143+
# If no active stack, make this a child of root
144+
parent_node = self._nodes[self.ROOT_NAME]
145+
146+
# Create or get node
147+
if name not in self._nodes:
148+
self._nodes[name] = TimerNode(name, parent_node)
149+
else:
150+
assert self._nodes[name].parent == parent_node, f"Timer {name} has a different parent"
151+
152+
self._nodes[name].start_time = time.perf_counter()
153+
self._active_stack.append(name)
74154

75155
def stop(self, name: str):
76156
"""Stop a timer.
@@ -81,42 +161,65 @@ def stop(self, name: str):
81161
if not self._enabled:
82162
return
83163

84-
if name not in self._start_times:
164+
if name not in self._nodes or self._nodes[name].start_time is None:
85165
raise ValueError(f"Timer {name} is not running")
86166

167+
if not self._active_stack or self._active_stack[-1] != name:
168+
raise ValueError(f"Timer {name} is not the most recently started timer")
169+
87170
# Synchronize CUDA operations before timing
88171
if torch.cuda.is_available():
89172
torch.cuda.synchronize()
90-
elif torch.npu.is_available():
91-
torch.npu.synchronize()
173+
elif torch.npu.is_available(): # type: ignore
174+
torch.npu.synchronize() # type: ignore
175+
176+
node = self._nodes[name]
177+
elapsed = time.perf_counter() - node.start_time
178+
node.total_time += elapsed
179+
node.count += 1
180+
node.start_time = None
181+
self._active_stack.pop()
92182

93-
elapsed = time.perf_counter() - self._start_times[name]
94-
self._timers[name] += elapsed
95-
self._counts[name] += 1
96-
del self._start_times[name]
97-
self._active_timers.remove(name)
183+
# Finalize root timer if this was the last active timer
184+
self._finalize_root()
98185

99186
def reset(self):
    """Discard every recorded timer and return to the "no session" state."""
    # Session bookkeeping first: no root timestamp, session not started.
    self._session_started = False
    self._root_start_time = None
    # Fresh containers replace the node table and the active-timer stack.
    self._nodes, self._active_stack = {}, []
105192

106193
def reset_timer(self, name: str):
    """Remove one timer, together with everything nested beneath it.

    Passing ``ROOT_NAME`` resets the whole session via ``reset()``. An unknown
    name is ignored.

    Args:
        name: The name of the timer.
    """
    if name == self.ROOT_NAME:
        # Reset the entire session
        self.reset()
        return

    if name not in self._nodes:
        return
    target = self._nodes[name]

    # Iterate over a snapshot: every recursive call unregisters the child
    # from ``target.children`` while we are still walking it.
    for descendant in tuple(target.children):
        self.reset_timer(descendant)

    # Detach from the parent so it no longer lists this timer.
    owner = target.parent
    if owner:
        owner.children.discard(name)

    # Drop it from the active stack when it happens to be running.
    try:
        self._active_stack.remove(name)
    except ValueError:
        pass

    del self._nodes[name]
120223

121224
def get_time(self, name: str) -> float:
122225
"""Get the accumulated time for a timer.
@@ -127,7 +230,7 @@ def get_time(self, name: str) -> float:
127230
Returns:
128231
The accumulated time in seconds.
129232
"""
130-
return self._timers.get(name, 0.0)
233+
return self._nodes[name].total_time if name in self._nodes else 0.0
131234

132235
def get_count(self, name: str) -> int:
133236
"""Get the number of times a timer has been called.
@@ -138,7 +241,7 @@ def get_count(self, name: str) -> int:
138241
Returns:
139242
The number of times the timer has been called.
140243
"""
141-
return self._counts.get(name, 0)
244+
return self._nodes[name].count if name in self._nodes else 0
142245

143246
def get_average_time(self, name: str) -> float:
144247
"""Get the average time for a timer.
@@ -149,60 +252,113 @@ def get_average_time(self, name: str) -> float:
149252
Returns:
150253
The average time in seconds.
151254
"""
152-
count = self._counts.get(name, 0)
153-
if count == 0:
255+
if name not in self._nodes:
154256
return 0.0
155-
return self._timers.get(name, 0.0) / count
257+
node = self._nodes[name]
258+
return node.total_time / node.count if node.count > 0 else 0.0
156259

157260
def get_all_timers(self) -> Dict[str, float]:
    """Collect the accumulated time of every registered timer.

    Returns:
        A new dictionary mapping each timer name to its accumulated time.
    """
    totals: Dict[str, float] = {}
    for timer_name, timer_node in self._nodes.items():
        totals[timer_name] = timer_node.total_time
    return totals
164267

165268
def get_all_counts(self) -> Dict[str, int]:
    """Collect the call count of every registered timer.

    Returns:
        A new dictionary mapping each timer name to its call count.
    """
    counts: Dict[str, int] = {}
    for timer_name, timer_node in self._nodes.items():
        counts[timer_name] = timer_node.count
    return counts
172275

173276
def get_all_average_times(self) -> Dict[str, float]:
    """Collect the average time of every registered timer.

    Each value is obtained from ``get_average_time`` for that timer name.

    Returns:
        A new dictionary mapping each timer name to its average time.
    """
    averages: Dict[str, float] = {}
    for timer_name in self._nodes:
        averages[timer_name] = self.get_average_time(timer_name)
    return averages
180283

181284
def get_active_timers(self) -> List[str]:
    """List the timers that are currently running.

    Returns:
        A copy of the active stack, so callers may modify the result without
        affecting the timer's own bookkeeping.
    """
    return list(self._active_stack)
291+
292+
def _format_node(self, name: str, depth: int = 0, parent_time: Optional[float] = None) -> List[str]:
    """Render one timer node and everything beneath it as display lines.

    The root node contributes no line of its own; its children are rendered at
    the depth that was passed in. Every other node yields one line, followed
    by its children one level deeper. Children are ordered by total time,
    largest first.

    Args:
        name: The name of the timer node.
        depth: Current depth in the hierarchy; sets the indentation.
        parent_time: Total time of the parent node, used for the percentage.
            When it is missing or not positive the line is labelled "root".

    Returns:
        The formatted lines for this node and its descendants, or an empty
        list when ``name`` is not a registered node.
    """
    if name not in self._nodes:
        return []

    node = self._nodes[name]
    is_root = name == self.ROOT_NAME
    lines: List[str] = []

    if not is_root:
        # Share of the parent's time, or a plain label when there is no
        # usable parent total.
        if parent_time and parent_time > 0:
            share = (node.total_time / parent_time) * 100
            share_label = f"{share:.2f}% of parent"
        else:
            share_label = "root"

        average = node.total_time / node.count if node.count > 0 else 0
        prefix = " " * depth
        lines.append(
            f"{prefix}{name}: {node.total_time:.4f}s total, {average:.6f}s avg ({node.count} calls), {share_label}"
        )

    # The root is invisible, so its children stay at the current depth.
    child_depth = depth if is_root else depth + 1
    ranked = sorted(node.children, key=lambda child: self._nodes[child].total_time, reverse=True)
    for child in ranked:
        lines.extend(self._format_node(child, child_depth, node.total_time))

    return lines
188343

189344
def summary(self) -> str:
    """Build a hierarchical, human-readable report of all timers.

    Returns:
        A fixed message when no timers exist or no root node is registered.
        Otherwise a header line with the total session time, followed by one
        line per timer as produced by ``_format_node``, joined with newlines.
    """
    if not self._nodes:
        return "No timers recorded."

    # Refresh the root's total before reading it.
    self._finalize_root()

    try:
        root_node = self._nodes[self.ROOT_NAME]
    except KeyError:
        return "No root timer found."

    header = f"Total session time: {root_node.total_time:.4f}s"
    return "\n".join([header, *self._format_node(self.ROOT_NAME)])
208364

0 commit comments

Comments
 (0)