Skip to content

Commit 0d2492e

Browse files
committed
Better names, better docs
1 parent 55d362c commit 0d2492e

File tree

4 files changed

+101
-55
lines changed

4 files changed

+101
-55
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
import numbers
1010
from pathlib import Path
11-
from typing import Literal, Optional, Sequence, Tuple, Union
11+
from typing import List, Literal, Optional, Sequence, Tuple, Union
1212

1313
import torch
1414
from torch import device as torch_device, Tensor
@@ -19,7 +19,7 @@
1919
create_decoder,
2020
ERROR_REPORTING_INSTRUCTIONS,
2121
)
22-
from torchcodec.transforms import DecoderNativeTransform, Resize
22+
from torchcodec.transforms import DecoderTransform, Resize
2323

2424

2525
class VideoDecoder:
@@ -67,6 +67,10 @@ class VideoDecoder:
6767
probably is. Default: "exact".
6868
Read more about this parameter in:
6969
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
70+
transforms (sequence of transform objects, optional): Sequence of transforms to be
71+
applied to the decoded frames by the decoder itself, in order. Accepts both
72+
torchcodec.transforms.DecoderTransform and torchvision.transforms.v2.Transform
73+
objects. All transforms are applied in the output pixel format and colorspace.
7074
custom_frame_mappings (str, bytes, or file-like object, optional):
7175
Mapping of frames to their metadata, typically generated via ffprobe.
7276
This enables accurate frame seeking without requiring a full video scan.
@@ -104,8 +108,8 @@ def __init__(
104108
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
105109
num_ffmpeg_threads: int = 1,
106110
device: Optional[Union[str, torch_device]] = "cpu",
107-
transforms: Optional[Sequence[DecoderNativeTransform]] = None,
108111
seek_mode: Literal["exact", "approximate"] = "exact",
112+
transforms: Optional[Sequence[DecoderTransform]] = None,
109113
custom_frame_mappings: Optional[
110114
Union[str, bytes, io.RawIOBase, io.BufferedReader]
111115
] = None,
@@ -437,15 +441,23 @@ def _get_and_validate_stream_metadata(
437441
)
438442

439443

440-
# This function, _make_transform_specs, and the transforms argument to
441-
# VideoDecoder actually accept a union of DecoderNativeTransform and
442-
# TorchVision transforms. We don't put that in our type annotation because
443-
# that would require importing torchvision at module scope which would mean we
444-
# have a hard dependency on torchvision.
445-
# TODO: better explanation of the above.
446444
def _convert_to_decoder_native_transforms(
447-
transforms: Sequence[DecoderNativeTransform],
448-
) -> Sequence[DecoderNativeTransform]:
445+
transforms: Sequence[DecoderTransform],
446+
) -> List[DecoderTransform]:
447+
"""Convert a sequence of transforms that may contain TorchVision transform
448+
objects into a list of only TorchCodec transform objects.
449+
450+
Args:
451+
transforms: Sequence of transform objects. The objects can be one of two
452+
types:
453+
1. torchcodec.transforms.DecoderTransform
454+
2. torchvision.transforms.v2.Transform
455+
Our type annotation only mentions the first type so that we don't
456+
have a hard dependency on TorchVision.
457+
458+
Returns:
459+
List of DecoderTransform objects.
460+
"""
449461
try:
450462
from torchvision.transforms import v2
451463

@@ -455,11 +467,11 @@ def _convert_to_decoder_native_transforms(
455467

456468
converted_transforms = []
457469
for transform in transforms:
458-
if not isinstance(transform, DecoderNativeTransform):
470+
if not isinstance(transform, DecoderTransform):
459471
if not tv_available:
460472
raise ValueError(
461473
f"The supplied transform, {transform}, is not a TorchCodec "
462-
" DecoderNativeTransform. TorchCodec also accept TorchVision "
474+
 " DecoderTransform. TorchCodec also accepts TorchVision "
463475
"v2 transforms, but TorchVision is not installed."
464476
)
465477
if isinstance(transform, v2.Resize):
@@ -472,7 +484,7 @@ def _convert_to_decoder_native_transforms(
472484
else:
473485
raise ValueError(
474486
f"Unsupported transform: {transform}. Transforms must be "
475-
"either a TorchCodec DecoderNativeTransform or a TorchVision "
487+
"either a TorchCodec DecoderTransform or a TorchVision "
476488
"v2 transform."
477489
)
478490
else:
@@ -482,8 +494,23 @@ def _convert_to_decoder_native_transforms(
482494

483495

484496
def _make_transform_specs(
485-
transforms: Optional[Sequence[DecoderNativeTransform]],
497+
transforms: Optional[Sequence[DecoderTransform]],
486498
) -> str:
499+
"""Given a sequence of transforms, turn those into the specification string
500+
the core API expects.
501+
502+
Args:
503+
transforms: Optional sequence of transform objects. The objects can be
504+
one of two types:
505+
1. torchcodec.transforms.DecoderTransform
506+
2. torchvision.transforms.v2.Transform
507+
Our type annotation only mentions the first type so that we don't
508+
have a hard dependency on TorchVision.
509+
510+
Returns:
511+
String of transforms in the format the core API expects: transform
512+
specifications separated by semicolons.
513+
"""
487514
if transforms is None:
488515
return ""
489516

src/torchcodec/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._decoder_native_transforms import DecoderNativeTransform, Resize # noqa
7+
from ._decoder_transforms import DecoderTransform, Resize # noqa

src/torchcodec/transforms/_decoder_native_transforms.py

Lines changed: 0 additions & 39 deletions
This file was deleted.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from abc import ABC, abstractmethod
8+
from dataclasses import dataclass
9+
from typing import Sequence
10+
11+
12+
@dataclass
13+
class DecoderTransform(ABC):
14+
"""Base class for all decoder transforms.
15+
16+
A DecoderTransform is a transform that is applied by the decoder before
17+
returning the decoded frame. The implementation does not live in TorchCodec
18+
itself, but in the underyling decoder. Applying DecoderTransforms to frames
19+
should be both faster and more memory efficient than receiving normally
20+
decoded frames and applying the same kind of transform.
21+
22+
Most DecoderTransforms have a complementary transform in TorchVision,
23+
specificially in torchvision.transforms.v2. For such transforms, we ensure
24+
that:
25+
26+
1. Default behaviors are the same.
27+
2. The parameters for the DecoderTransform are a subset of the
28+
TorchVision transform.
29+
3. Parameters with the same name control the same behavior and accept a
30+
subset of the same types.
31+
4. The difference between the frames returned by a DecoderTransform and
32+
the complementary TorchVision transform are small.
33+
34+
All DecoderTransforms are applied in the output pixel format and colorspace.
35+
"""
36+
37+
@abstractmethod
38+
def make_params(self) -> str:
39+
pass
40+
41+
42+
@dataclass
43+
class Resize(DecoderTransform):
44+
"""Resize the decoded frame to a given size.
45+
46+
Complementary TorchVision transform: torchvision.transforms.v2.Resize.
47+
Interpolation is always bilinear. Anti-aliasing is always on.
48+
49+
Args:
50+
size (sequence of int): Desired output size. Must be a sequence of
51+
the form (height, width).
52+
"""
53+
54+
size: Sequence[int]
55+
56+
def make_params(self) -> str:
57+
assert len(self.size) == 2
58+
return f"resize, {self.size[0]}, {self.size[1]}"

0 commit comments

Comments
 (0)