|
11 | 11 | from typing import List, Literal, Optional, Sequence, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | | -from torch import device as torch_device, Tensor |
| 14 | +from torch import device as torch_device, nn, Tensor |
15 | 15 |
|
16 | 16 | from torchcodec import _core as core, Frame, FrameBatch |
17 | 17 | from torchcodec.decoders._decoder_utils import ( |
@@ -69,8 +69,10 @@ class VideoDecoder: |
69 | 69 | :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` |
70 | 70 | transforms (sequence of transform objects, optional): Sequence of transforms to be |
71 | 71 | applied to the decoded frames by the decoder itself, in order. Accepts both |
72 | | - ``torchcodec.transforms.DecoderTransform`` and ``torchvision.transforms.v2.Transform`` |
73 | | - objects. All transforms are applied in the ouput pixel format and colorspace. |
| 72 | + :class:`torchcodec.transforms.DecoderTransform` and |
| 73 | + :class:`torchvision.transforms.v2.Transform` objects. All transforms are applied |
| 74 | + in the ouput pixel format and colorspace. Read more about this parameter in: |
| 75 | + SCOTT_NEEDS_TO_WRITE_A_TUTORIAL. |
74 | 76 | custom_frame_mappings (str, bytes, or file-like object, optional): |
75 | 77 | Mapping of frames to their metadata, typically generated via ffprobe. |
76 | 78 | This enables accurate frame seeking without requiring a full video scan. |
@@ -109,7 +111,7 @@ def __init__( |
109 | 111 | num_ffmpeg_threads: int = 1, |
110 | 112 | device: Optional[Union[str, torch_device]] = "cpu", |
111 | 113 | seek_mode: Literal["exact", "approximate"] = "exact", |
112 | | - transforms: Optional[Sequence[DecoderTransform]] = None, |
| 114 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]] = None, |
113 | 115 | custom_frame_mappings: Optional[ |
114 | 116 | Union[str, bytes, io.RawIOBase, io.BufferedReader] |
115 | 117 | ] = None, |
@@ -442,7 +444,7 @@ def _get_and_validate_stream_metadata( |
442 | 444 |
|
443 | 445 |
|
444 | 446 | def _convert_to_decoder_native_transforms( |
445 | | - transforms: Sequence[DecoderTransform], |
| 447 | + transforms: Sequence[Union[DecoderTransform, nn.Module]], |
446 | 448 | ) -> List[DecoderTransform]: |
447 | 449 | """Convert a sequence of transforms that may contain TorchVision transform |
448 | 450 | objects into a list of only TorchCodec transform objects. |
@@ -494,7 +496,7 @@ def _convert_to_decoder_native_transforms( |
494 | 496 |
|
495 | 497 |
|
496 | 498 | def _make_transform_specs( |
497 | | - transforms: Optional[Sequence[DecoderTransform]], |
| 499 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
498 | 500 | ) -> str: |
499 | 501 | """Given a sequence of transforms, turn those into the specification string |
500 | 502 | the core API expects. |
|
0 commit comments