
Commit 62eb585

Type checking
1 parent 705d1ef commit 62eb585

File tree

2 files changed: +27 additions, −7 deletions


src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -513,7 +513,9 @@ def _make_transform_specs(
     # our metadata. For each transform, we always calculate its output
     # dimensions from its input dimensions. We store these with the converted
     # transform, to be all used together when we generate the specs.
-    converted_transforms: list[(DecoderTransform, Tuple[int, int])] = []
+    converted_transforms: list[
+        Tuple[DecoderTransform, Tuple[Optional[int], Optional[int]]]
+    ] = []
     curr_input_dims = input_dims
     for transform in transforms:
         if isinstance(transform, DecoderTransform):
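
For context, the deleted annotation was the actual type error here: `list[(DecoderTransform, Tuple[int, int])]` subscripts `list` with a tuple literal, which is not a valid type expression, so checkers such as mypy reject it. The element type of a list of pairs must be spelled with `Tuple[...]` (or the builtin `tuple[...]`). A minimal standalone illustration:

from typing import Tuple

# Rejected by type checkers: a tuple literal is not a valid type expression.
#     pairs: list[(str, int)] = []

# Accepted: the pair element type is spelled Tuple[str, int].
pairs: list[Tuple[str, int]] = []
pairs.append(("resize", 224))

The commit additionally widens the inner dimensions to `Optional[int]`, matching the new `_make_transform_spec` signature in the second file.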

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 24 additions & 6 deletions
@@ -38,7 +38,9 @@ class DecoderTransform(ABC):
     """
 
     @abstractmethod
-    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
+    def _make_transform_spec(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> str:
         pass
 
     def _calculate_output_dims(
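
Since `DecoderTransform` is an ABC, every concrete transform must now implement the widened signature, where either dimension may be `None` when the video metadata lacks a height or width. A minimal hypothetical subclass, with the transform and its spec string invented purely for illustration:

from abc import ABC, abstractmethod
from typing import Optional, Tuple

class DecoderTransform(ABC):
    @abstractmethod
    def _make_transform_spec(
        self, input_dims: Tuple[Optional[int], Optional[int]]
    ) -> str:
        pass

class HorizontalFlip(DecoderTransform):
    # Invented example transform, not part of this commit.
    def _make_transform_spec(
        self, input_dims: Tuple[Optional[int], Optional[int]]
    ) -> str:
        # A flip needs no frame dimensions, so None components are fine.
        return "hflip"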
@@ -71,7 +73,9 @@ class Resize(DecoderTransform):
 
     size: Sequence[int]
 
-    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
+    def _make_transform_spec(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> str:
         # TODO: establish this invariant in the constructor during refactor
         assert len(self.size) == 2
         return f"resize, {self.size[0]}, {self.size[1]}"
@@ -131,26 +135,40 @@ class RandomCrop(DecoderTransform):
     _top: Optional[int] = None
     _left: Optional[int] = None
 
-    def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
+    def _make_transform_spec(
+        self, input_dims: Tuple[Optional[int], Optional[int]]
+    ) -> str:
         if len(self.size) != 2:
             raise ValueError(
                 f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
                 "This should never happen, please report a bug."
             )
 
+        height, width = input_dims
+        if height is None:
+            raise ValueError(
+                "Video metadata has no height. "
+                "RandomCrop can only be used when input frame dimensions are known."
+            )
+        if width is None:
+            raise ValueError(
+                "Video metadata has no width. "
+                "RandomCrop can only be used when input frame dimensions are known."
+            )
+
         # Note: This logic below must match the logic in
         # torchvision.transforms.v2.RandomCrop.make_params(). Given
         # the same seed, they should get the same result. This is an
         # API guarantee with our users.
-        if input_dims[0] < self.size[0] or input_dims[1] < self.size[1]:
+        if height < self.size[0] or width < self.size[1]:
             raise ValueError(
                 f"Input dimensions {input_dims} are smaller than the crop size {self.size}."
             )
 
-        top = int(torch.randint(0, input_dims[0] - self.size[0] + 1, size=()).item())
+        top = int(torch.randint(0, height - self.size[0] + 1, size=()).item())
         self._top = top
 
-        left = int(torch.randint(0, input_dims[1] - self.size[1] + 1, size=()).item())
+        left = int(torch.randint(0, width - self.size[1] + 1, size=()).item())
         self._left = left
 
         return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"
