82 changes: 65 additions & 17 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -762,8 +762,10 @@ def __init__(
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
# GPU stream for SSD cache eviction
self.ssd_eviction_stream = torch.cuda.Stream(priority=low_priority)
# GPU stream for SSD memory copy
# GPU stream for SSD memory copy (also reused for feature score D2H)
self.ssd_memcpy_stream = torch.cuda.Stream(priority=low_priority)
# GPU stream for async feature score metadata writes
self.feature_score_stream = torch.cuda.Stream(priority=low_priority)

# SSD get completion event
self.ssd_event_get = torch.cuda.Event()
@@ -1675,6 +1677,56 @@ def _update_cache_counter_and_pointers(
unique_indices_length_curr=curr_data.actions_count_gpu,
)

def _update_feature_score_metadata(
self,
linear_cache_indices: Tensor,
weights: Tensor,
d2h_stream: torch.cuda.Stream,
write_stream: torch.cuda.Stream,
pre_event_for_write: torch.cuda.Event,
post_event: Optional[torch.cuda.Event] = None,
) -> None:
"""
Write feature score metadata to DRAM

This method performs the D2H copy of the cache indices and feature scores on d2h_stream,
then writes the metadata to DRAM on write_stream. The caller is responsible for ensuring
that d2h_stream does not compete with other critical D2H operations.

Args:
linear_cache_indices: GPU tensor containing cache indices
weights: GPU tensor containing feature scores
d2h_stream: Stream for D2H copy operation (should already be synchronized appropriately)
write_stream: Stream for metadata write operation
pre_event_for_write: Event to wait on before writing metadata (e.g., wait for eviction)
post_event: Event to record when the operation is done
"""
# Start D2H copy on d2h_stream
with torch.cuda.stream(d2h_stream):
# Record streams to prevent premature deallocation
linear_cache_indices.record_stream(d2h_stream)
weights.record_stream(d2h_stream)
# Copy the indices and scores to pinned host memory (D2H)
linear_cache_indices_cpu = self.to_pinned_cpu(linear_cache_indices)
score_weights_cpu = self.to_pinned_cpu(weights)

# Write feature score metadata to DRAM
with record_function("## ssd_write_feature_score_metadata ##"):
with torch.cuda.stream(write_stream):
write_stream.wait_event(pre_event_for_write)
write_stream.wait_stream(d2h_stream)
self.record_function_via_dummy_profile(
"## ssd_write_feature_score_metadata ##",
self.ssd_db.set_feature_score_metadata_cuda,
linear_cache_indices_cpu,
torch.tensor(
[score_weights_cpu.shape[0]], device="cpu", dtype=torch.long
),
score_weights_cpu,
)

if post_event is not None:
write_stream.record_event(post_event)

def prefetch(
self,
indices: Tensor,
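For reference, below is a minimal, self-contained sketch of the stream/event ordering pattern that _update_feature_score_metadata relies on. The stream names, tensors, and prereq_event here are hypothetical stand-ins rather than the actual TBE objects; the sketch only illustrates how record_stream, wait_event, wait_stream, and record_event combine to order the D2H copy and the downstream write.

import torch

if torch.cuda.is_available():
    low_priority, _high_priority = torch.cuda.Stream.priority_range()
    d2h_stream = torch.cuda.Stream(priority=low_priority)    # plays the role of ssd_memcpy_stream
    write_stream = torch.cuda.Stream(priority=low_priority)  # plays the role of feature_score_stream
    prereq_event = torch.cuda.Event()                        # plays the role of ssd_event_cache_evict
    done_event = torch.cuda.Event()

    indices = torch.arange(8, device="cuda")
    scores = torch.rand(8, device="cuda")
    # Pretend the prerequisite work (e.g. cache eviction) ran on the current stream.
    torch.cuda.current_stream().record_event(prereq_event)

    with torch.cuda.stream(d2h_stream):
        # record_stream keeps the GPU buffers alive until work queued on d2h_stream completes.
        indices.record_stream(d2h_stream)
        scores.record_stream(d2h_stream)
        indices_cpu = indices.to("cpu", non_blocking=True)
        scores_cpu = scores.to("cpu", non_blocking=True)

    with torch.cuda.stream(write_stream):
        write_stream.wait_event(prereq_event)  # order after the prerequisite (e.g. eviction)
        write_stream.wait_stream(d2h_stream)   # order after the D2H copies
        # ... the metadata write would consume indices_cpu / scores_cpu here ...
        write_stream.record_event(done_event)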
Expand Down Expand Up @@ -1747,12 +1799,6 @@ def _prefetch( # noqa C901

self.timestep += 1
self.timesteps_prefetched.append(self.timestep)
if self.backend_type == BackendType.DRAM and weights is not None:
# DRAM backend supports feature score eviction, if there is weights available
# in the prefetch call, we will set metadata for feature score eviction asynchronously
cloned_linear_cache_indices = linear_cache_indices.clone()
else:
cloned_linear_cache_indices = None

# Lookup and virtually insert indices into L1. After this operator,
# we know:
@@ -2114,16 +2160,18 @@ def _prefetch( # noqa C901
name="cache",
is_bwd=False,
)
if self.backend_type == BackendType.DRAM and weights is not None:
# Write feature score metadata to DRAM
self.record_function_via_dummy_profile(
"## ssd_write_feature_score_metadata ##",
self.ssd_db.set_feature_score_metadata_cuda,
cloned_linear_cache_indices.cpu(),
torch.tensor(
[weights.shape[0]], device="cpu", dtype=torch.long
),
weights.cpu(),
if (
self.backend_type == BackendType.DRAM
and weights is not None
and linear_cache_indices.numel() > 0
):
# Reuse ssd_memcpy_stream for the feature score D2H since its critical D2H work is already done
self._update_feature_score_metadata(
linear_cache_indices=linear_cache_indices,
weights=weights,
d2h_stream=self.ssd_memcpy_stream,
write_stream=self.feature_score_stream,
pre_event_for_write=self.ssd_event_cache_evict,
)

# Generate row addresses (pointing to either L1 or the current
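A note on the D2H step: the copies inside _update_feature_score_metadata go through self.to_pinned_cpu, a helper defined elsewhere in this file. Below is a rough sketch of what such a pinned-copy helper typically does (an assumption about its behavior, not the actual implementation); the pinned host buffer is what allows the copy to be issued non-blocking, and it is also why the record_stream calls are needed to keep the source tensors alive.

import torch

def to_pinned_cpu_sketch(t: torch.Tensor) -> torch.Tensor:
    # Allocate a page-locked (pinned) host buffer so the device-to-host copy
    # can be issued asynchronously on the current stream.
    out = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    out.copy_(t, non_blocking=True)
    return out

Reusing ssd_memcpy_stream for this copy, rather than adding yet another D2H stream, works because by this point in _prefetch the critical cache copies already enqueued on that stream have been issued, so the feature score copy simply queues behind them.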