@@ -83,7 +83,7 @@ def from_dataloader(
8383 for key , value in batch .items ():
8484 if mask_padding and (key == "input_ids" ) and "attention_mask" in batch :
8585 value = cls ._mask_padding (value , batch ["attention_mask" ])
86- values [key ] = IntermediateValue (value = value , device = model_device )
86+ values [key ] = cls . _offload_value (value , offload_device , model_device )
8787
8888 batch_intermediates .append (values )
8989
@@ -114,7 +114,8 @@ def update(self, batch_index: int, values: Dict[str, Any]):
114114 :param batch_index: index of batch whose values will be updated
115115 :param values: dictionary mapping keys to values used for update
116116 """
117- intermediates = {k : self ._offload_value (v ) for k , v in values .items ()}
117+ device = self .offload_device
118+ intermediates = {k : self ._offload_value (v , device ) for k , v in values .items ()}
118119 self .batch_intermediates [batch_index ].update (intermediates )
119120
120121 def delete (self , batch_index : int , consumed_names : Optional [List [str ]] = None ):
@@ -189,59 +190,80 @@ def __iter__(self) -> Generator[Any, None, None]:
189190 def __len__ (self ) -> int :
190191 return len (self .batch_intermediates )
191192
192- def _onload_value (self , intermediate : IntermediateValue ) -> Any :
193+ @classmethod
194+ def _onload_value (cls , intermediate : IntermediateValue ) -> Any :
195+ """
196+ Onload a value's tensors to the onload device
197+
198+ :param intermediate: intermediate value representation to onload
199+ :return: original value with tensors onloaded to the onload device
200+ """
193201 value = intermediate .value
194202 device = intermediate .device
195203
196204 match value :
197205 case torch .Tensor ():
198206 return value .to (device = device )
199207 case list ():
200- return [self ._onload_value (v ) for v in value ]
208+ return [cls ._onload_value (v ) for v in value ]
201209 case tuple ():
202- return tuple (self ._onload_value (v ) for v in value )
210+ return tuple (cls ._onload_value (v ) for v in value )
203211 case dict ():
204- return {k : self ._onload_value (v ) for k , v in value .items ()}
212+ return {k : cls ._onload_value (v ) for k , v in value .items ()}
205213 case _ if is_dataclass (value ):
206214 for field in fields (value ):
207215 v = getattr (value , field .name )
208- setattr (value , field .name , self ._onload_value (v ))
216+ setattr (value , field .name , cls ._onload_value (v ))
209217 return value
210218 case _:
211219 # handles primitive values that should be returned as is.
212220 # without this, unmatched values would fall through and None would be returned.
213221 return value
214222
215- def _offload_value (self , value : Any ) -> IntermediateValue :
223+ @classmethod
224+ def _offload_value (
225+ cls ,
226+ value : Any ,
227+ offload_device : torch .device | None ,
228+ onload_device : Optional [torch .device ] = None ,
229+ ) -> IntermediateValue :
230+ """
231+ Offload a value's tensors to the offload device
232+
233+ :param value: value to offload
234+ :param offload_device: device to offload `torch.Tensor` values to
235+ :param onload_device: device used when onloading `torch.Tensor` values.
236+ If None is provided, use the tensor's current device
237+ :return: Instance of IntermediateValue representing the offloaded value
238+ """
239+ kwargs = {"offload_device" : offload_device , "onload_device" : onload_device }
216240 match value :
217241 case torch .Tensor ():
218242 return IntermediateValue (
219- value = (
220- value
221- if self .offload_device is None
222- else value .to (device = self .offload_device )
223- ),
224- device = value .device ,
243+ value = value .to (device = offload_device ),
244+ device = (onload_device if onload_device else value .device ),
225245 )
226246 case list ():
227247 return IntermediateValue (
228- value = [self ._offload_value (v ) for v in value ],
248+ value = [cls ._offload_value (v , ** kwargs ) for v in value ],
229249 device = None ,
230250 )
231251 case tuple ():
232252 return IntermediateValue (
233- value = tuple (self ._offload_value (v ) for v in value ),
253+ value = tuple (cls ._offload_value (v , ** kwargs ) for v in value ),
234254 device = None ,
235255 )
236256 case dict ():
237257 return IntermediateValue (
238- value = {k : self ._offload_value (v ) for k , v in value .items ()},
258+ value = {
259+ k : cls ._offload_value (v , ** kwargs ) for k , v in value .items ()
260+ },
239261 device = None ,
240262 )
241263 case _ if is_dataclass (value ):
242264 for field in fields (value ):
243265 v = getattr (value , field .name )
244- setattr (value , field .name , self ._offload_value (v ))
266+ setattr (value , field .name , cls ._offload_value (v , ** kwargs ))
245267 return IntermediateValue (value = value , device = None )
246268 case _:
247269 # handles primitive values and provides a warning for unsupported types.
0 commit comments