Skip to content

Commit 8a33609

Browse files
committed
onload from dataloader
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9a8dc1f commit 8a33609

File tree

2 files changed

+92
-72
lines changed

2 files changed

+92
-72
lines changed

src/llmcompressor/pipelines/cache.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def from_dataloader(
8383
for key, value in batch.items():
8484
if mask_padding and (key == "input_ids") and "attention_mask" in batch:
8585
value = cls._mask_padding(value, batch["attention_mask"])
86-
values[key] = IntermediateValue(value=value, device=model_device)
86+
values[key] = cls._offload_value(value, offload_device, model_device)
8787

8888
batch_intermediates.append(values)
8989

@@ -114,7 +114,8 @@ def update(self, batch_index: int, values: Dict[str, Any]):
114114
:param batch_index: index of batch whose values will be updated
115115
:param values: dictionary mapping keys to values used for update
116116
"""
117-
intermediates = {k: self._offload_value(v) for k, v in values.items()}
117+
device = self.offload_device
118+
intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
118119
self.batch_intermediates[batch_index].update(intermediates)
119120

120121
def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
@@ -189,59 +190,80 @@ def __iter__(self) -> Generator[Any, None, None]:
189190
def __len__(self) -> int:
190191
return len(self.batch_intermediates)
191192

192-
def _onload_value(self, intermediate: IntermediateValue) -> Any:
193+
@classmethod
194+
def _onload_value(cls, intermediate: IntermediateValue) -> Any:
195+
"""
196+
Onload a value's tensors to the onload device
197+
198+
:param intermediate: intermediates value representation to onload
199+
:return: original value with tensors onloaded to the onload device
200+
"""
193201
value = intermediate.value
194202
device = intermediate.device
195203

196204
match value:
197205
case torch.Tensor():
198206
return value.to(device=device)
199207
case list():
200-
return [self._onload_value(v) for v in value]
208+
return [cls._onload_value(v) for v in value]
201209
case tuple():
202-
return tuple(self._onload_value(v) for v in value)
210+
return tuple(cls._onload_value(v) for v in value)
203211
case dict():
204-
return {k: self._onload_value(v) for k, v in value.items()}
212+
return {k: cls._onload_value(v) for k, v in value.items()}
205213
case _ if is_dataclass(value):
206214
for field in fields(value):
207215
v = getattr(value, field.name)
208-
setattr(value, field.name, self._onload_value(v))
216+
setattr(value, field.name, cls._onload_value(v))
209217
return value
210218
case _:
211219
# handles primitive values that should be returned as is.
212220
# without this, a MatchError would be raised for unhandled types.
213221
return value
214222

215-
def _offload_value(self, value: Any) -> IntermediateValue:
223+
@classmethod
224+
def _offload_value(
225+
cls,
226+
value: Any,
227+
offload_device: torch.device | None,
228+
onload_device: Optional[torch.device] = None,
229+
) -> IntermediateValue:
230+
"""
231+
Offload a value's tensors to the offload device
232+
233+
:param value: value to offload
234+
:param offload_device: device to offload `torch.Tensor` values to
235+
:param onload_device: device used when onloading `torch.Tensor` values.
236+
If None is provided, use the tensor's current device
237+
:return: Instance of IntermediateValue representing the offloaded value
238+
"""
239+
kwargs = {"offload_device": offload_device, "onload_device": onload_device}
216240
match value:
217241
case torch.Tensor():
218242
return IntermediateValue(
219-
value=(
220-
value
221-
if self.offload_device is None
222-
else value.to(device=self.offload_device)
223-
),
224-
device=value.device,
243+
value=value.to(device=offload_device),
244+
device=(onload_device if onload_device else value.device),
225245
)
226246
case list():
227247
return IntermediateValue(
228-
value=[self._offload_value(v) for v in value],
248+
value=[cls._offload_value(v, **kwargs) for v in value],
229249
device=None,
230250
)
231251
case tuple():
232252
return IntermediateValue(
233-
value=tuple(self._offload_value(v) for v in value),
253+
value=tuple(cls._offload_value(v, **kwargs) for v in value),
234254
device=None,
235255
)
236256
case dict():
237257
return IntermediateValue(
238-
value={k: self._offload_value(v) for k, v in value.items()},
258+
value={
259+
k: cls._offload_value(v, **kwargs) for k, v in value.items()
260+
},
239261
device=None,
240262
)
241263
case _ if is_dataclass(value):
242264
for field in fields(value):
243265
v = getattr(value, field.name)
244-
setattr(value, field.name, self._offload_value(v))
266+
setattr(value, field.name, cls._offload_value(v, **kwargs))
245267
return IntermediateValue(value=value, device=None)
246268
case _:
247269
# handles primitive values and provides a warning for unsupported types.

tests/llmcompressor/pipelines/test_cache.py

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, fields, is_dataclass
22

33
import pytest
44
import torch
55
from torch.utils.data import DataLoader, StackDataset
66

7-
from llmcompressor.pipelines.cache import IntermediatesCache, IntermediateValue
7+
from llmcompressor.pipelines.cache import IntermediatesCache
8+
9+
10+
@dataclass
11+
class SampleDataclass:
12+
a: torch.Tensor
13+
b: int
814

915

1016
@pytest.fixture
@@ -28,6 +34,14 @@ def sample_cache(sample_dataloader):
2834
)
2935

3036

37+
values_to_test = [
38+
torch.randn(2, 3).to("cpu"),
39+
SampleDataclass(a=torch.randn(2, 3), b=42),
40+
torch.float32,
41+
[1, 2, 3],
42+
]
43+
44+
3145
@pytest.mark.unit
3246
def test_initialization(sample_dataloader):
3347
cache = IntermediatesCache.from_dataloader(
@@ -95,62 +109,22 @@ def test_mask_padding():
95109

96110

97111
@pytest.mark.unit
98-
def test_offload_and_onload_tensor():
99-
cache = IntermediatesCache([], torch.device("cpu"))
100-
101-
# Test tensor offloading
102-
original_tensor = torch.randn(2, 3).to("cpu")
103-
offloaded = cache._offload_value(original_tensor)
112+
@pytest.mark.parametrize("value", values_to_test)
113+
def test_from_dataloader(value):
114+
dataset = StackDataset(value=[value])
115+
dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
116+
cache = IntermediatesCache.from_dataloader(dataloader)
104117

105-
assert isinstance(offloaded, IntermediateValue)
106-
assert isinstance(offloaded.value, torch.Tensor)
107-
assert offloaded.device == original_tensor.device
108-
109-
# Test tensor onloading
110-
onloaded = cache._onload_value(offloaded)
111-
assert torch.equal(onloaded, original_tensor)
112-
113-
114-
@dataclass
115-
class SampleDataclass:
116-
a: torch.Tensor
117-
b: int
118+
onloaded = cache.fetch(0, ["value"])["value"]
119+
assert deep_equal(onloaded, value)
118120

119121

120122
@pytest.mark.unit
121-
def test_offload_and_onload_dataclass():
122-
cache = IntermediatesCache([], torch.device("cpu"))
123-
124-
# Create a sample dataclass instance
125-
sample_data = SampleDataclass(a=torch.randn(2, 3), b=42)
126-
127-
# Test dataclass offloading
128-
offloaded = cache._offload_value(sample_data)
129-
assert isinstance(offloaded, IntermediateValue)
130-
assert isinstance(offloaded.value, SampleDataclass)
131-
assert isinstance(offloaded.value.a, IntermediateValue)
132-
assert isinstance(offloaded.value.b, IntermediateValue)
133-
134-
# Test dataclass onloading
135-
onloaded = cache._onload_value(offloaded)
136-
assert onloaded == sample_data
137-
138-
139-
@pytest.mark.unit
140-
def test_offload_and_onload_dtype():
141-
cache = IntermediatesCache([], torch.device("cpu"))
142-
143-
# Create a sample dataclass instance
144-
sample_data = torch.float32
145-
146-
# Test dataclass offloading
147-
offloaded = cache._offload_value(sample_data)
148-
assert isinstance(offloaded, IntermediateValue)
149-
assert isinstance(offloaded.value, torch.dtype)
150-
151-
# Test dataclass onloading
152-
onloaded = cache._onload_value(offloaded)
153-
assert onloaded == sample_data
123+
@pytest.mark.parametrize("value", values_to_test)
def test_offload_and_onload(value):
    """
    Round-trip a value through `_offload_value` and `_onload_value` and check
    that the result equals the value as it was before offloading
    """
    # local import keeps this test self-contained
    import copy

    # Dataclass instances are offloaded and onloaded in place (both helpers
    # `setattr` on the instance and return that same instance), so comparing
    # the result against `value` would compare an object with itself.
    # Snapshot the expected value before the round trip instead.
    expected = copy.deepcopy(value)

    offloaded = IntermediatesCache._offload_value(value, torch.device("cpu"))
    onloaded = IntermediatesCache._onload_value(offloaded)
    assert deep_equal(onloaded, expected)
154128

155129

156130
@pytest.mark.unit
@@ -190,3 +164,27 @@ def test_device_handling(sample_dataloader):
190164
# Verify tensors are loaded back to GPU when fetched
191165
fetched = cache.fetch(0, ["hidden_states"])
192166
assert fetched["hidden_states"].device.type == "cuda"
167+
168+
169+
def deep_equal(a, b) -> bool:
170+
if type(a) != type(b):
171+
return False
172+
173+
match a:
174+
case torch.Tensor():
175+
return torch.equal(a, b)
176+
case list() | tuple():
177+
if len(a) != len(b):
178+
return False
179+
return all(deep_equal(_a, _b) for _a, _b in zip(a, b))
180+
case dict():
181+
if a.keys() != b.keys():
182+
return False
183+
return all(deep_equal(a[key], b[key]) for key in a.keys())
184+
case _ if is_dataclass(a):
185+
a_dict = {field: getattr(a, field.name) for field in fields(a)}
186+
b_dict = {field: getattr(b, field.name) for field in fields(b)}
187+
188+
return deep_equal(a_dict, b_dict)
189+
case _:
190+
return a == b

0 commit comments

Comments
 (0)