66
77from abc import ABC , abstractmethod
88from dataclasses import dataclass
9+ from types import ModuleType
910from typing import Sequence
1011
1112from torch import nn
@@ -20,34 +21,41 @@ class DecoderTransform(ABC):
2021 should be both faster and more memory efficient than receiving normally
2122 decoded frames and applying the same kind of transform.
2223
23- Most `DecoderTransform` objects have a complementary transform in TorchVision,
24- specificially in
25- `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended>`_.
26- For such transforms, we ensure that:
24+ Most ``DecoderTransform`` objects have a complementary transform in TorchVision,
25+ specificially in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. For such transforms, we
26+ ensure that:
2727
2828 1. The names are the same.
2929 2. Default behaviors are the same.
30- 3. The parameters for the `DecoderTransform` object are a subset of the
31- TorchVision transform object.
30+ 3. The parameters for the `` DecoderTransform` ` object are a subset of the
31+ TorchVision :class:`~torchvision.transforms.v2.Transform` object.
3232 4. Parameters with the same name control the same behavior and accept a
3333 subset of the same types.
3434 5. The difference between the frames returned by a decoder transform and
35- the complementary TorchVision transform are small.
36-
37- All decoder transforms are applied in the output pixel format and colorspace.
35+ the complementary TorchVision transform are such that a model should
36+ not be able to tell the difference.
3837 """
3938
4039 @abstractmethod
41- def _make_params (self ) -> str :
40+ def _make_transform_spec (self ) -> str :
4241 pass
4342
4443
44+ def import_torchvision_transforms_v2 () -> ModuleType :
45+ try :
46+ from torchvision .transforms import v2
47+ except ImportError as e :
48+ raise RuntimeError (
49+ "Cannot import TorchVision; this should never happen, please report a bug."
50+ ) from e
51+ return v2
52+
53+
4554@dataclass
4655class Resize (DecoderTransform ):
4756 """Resize the decoded frame to a given size.
4857
49- Complementary TorchVision transform:
50- `torchvision.transforms.v2.Resize <https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize>`_.
58+ Complementary TorchVision transform: :class:`~torchvision.transforms.v2.Resize`.
5159 Interpolation is always bilinear. Anti-aliasing is always on.
5260
5361 Args:
@@ -57,13 +65,13 @@ class Resize(DecoderTransform):
5765
5866 size : Sequence [int ]
5967
60- def _make_params (self ) -> str :
68+ def _make_transform_spec (self ) -> str :
6169 assert len (self .size ) == 2
6270 return f"resize, { self .size [0 ]} , { self .size [1 ]} "
6371
6472 @classmethod
6573 def _from_torchvision (cls , resize_tv : nn .Module ):
66- from torchvision . transforms import v2
74+ v2 = import_torchvision_transforms_v2 ()
6775
6876 assert isinstance (resize_tv , v2 .Resize )
6977
0 commit comments