88import json
99import numbers
1010from pathlib import Path
11- from typing import Literal , Optional , Sequence , Tuple , Union
11+ from typing import List , Literal , Optional , Sequence , Tuple , Union
1212
1313import torch
1414from torch import device as torch_device , Tensor
1919 create_decoder ,
2020 ERROR_REPORTING_INSTRUCTIONS ,
2121)
22- from torchcodec .transforms import DecoderNativeTransform , Resize
22+ from torchcodec .transforms import DecoderTransform , Resize
2323
2424
2525class VideoDecoder :
@@ -67,6 +67,10 @@ class VideoDecoder:
6767 probably is. Default: "exact".
6868 Read more about this parameter in:
6969 :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
70+ transforms (sequence of transform objects, optional): Sequence of transforms to be
71+ applied to the decoded frames by the decoder itself, in order. Accepts both
72+ torchcodec.transforms.DecoderTransform and torchvision.transforms.v2.Transform
73+         objects. All transforms are applied in the output pixel format and colorspace.
7074 custom_frame_mappings (str, bytes, or file-like object, optional):
7175 Mapping of frames to their metadata, typically generated via ffprobe.
7276 This enables accurate frame seeking without requiring a full video scan.
@@ -104,8 +108,8 @@ def __init__(
104108 dimension_order : Literal ["NCHW" , "NHWC" ] = "NCHW" ,
105109 num_ffmpeg_threads : int = 1 ,
106110 device : Optional [Union [str , torch_device ]] = "cpu" ,
107- transforms : Optional [Sequence [DecoderNativeTransform ]] = None ,
108111 seek_mode : Literal ["exact" , "approximate" ] = "exact" ,
112+ transforms : Optional [Sequence [DecoderTransform ]] = None ,
109113 custom_frame_mappings : Optional [
110114 Union [str , bytes , io .RawIOBase , io .BufferedReader ]
111115 ] = None ,
@@ -437,15 +441,23 @@ def _get_and_validate_stream_metadata(
437441 )
438442
439443
440- # This function, _make_transform_specs, and the transforms argument to
441- # VideoDecoder actually accept a union of DecoderNativeTransform and
442- # TorchVision transforms. We don't put that in our type annotation because
443- # that would require importing torchvision at module scope which would mean we
444- # have a hard dependency on torchvision.
445- # TODO: better explanation of the above.
446444def _convert_to_decoder_native_transforms (
447- transforms : Sequence [DecoderNativeTransform ],
448- ) -> Sequence [DecoderNativeTransform ]:
445+ transforms : Sequence [DecoderTransform ],
446+ ) -> List [DecoderTransform ]:
447+ """Convert a sequence of transforms that may contain TorchVision transform
448+ objects into a list of only TorchCodec transform objects.
449+
450+ Args:
451+        transforms: Sequence of transform objects. The objects can be one of two
452+ types:
453+ 1. torchcodec.transforms.DecoderTransform
454+ 2. torchvision.transforms.v2.Transform
455+ Our type annotation only mentions the first type so that we don't
456+ have a hard dependency on TorchVision.
457+
458+ Returns:
459+ List of DecoderTransform objects.
460+ """
449461 try :
450462 from torchvision .transforms import v2
451463
@@ -455,11 +467,11 @@ def _convert_to_decoder_native_transforms(
455467
456468 converted_transforms = []
457469 for transform in transforms :
458- if not isinstance (transform , DecoderNativeTransform ):
470+ if not isinstance (transform , DecoderTransform ):
459471 if not tv_available :
460472 raise ValueError (
461473 f"The supplied transform, { transform } , is not a TorchCodec "
462- " DecoderNativeTransform . TorchCodec also accept TorchVision "
474+ " DecoderTransform . TorchCodec also accept TorchVision "
463475 "v2 transforms, but TorchVision is not installed."
464476 )
465477 if isinstance (transform , v2 .Resize ):
@@ -472,7 +484,7 @@ def _convert_to_decoder_native_transforms(
472484 else :
473485 raise ValueError (
474486 f"Unsupported transform: { transform } . Transforms must be "
475- "either a TorchCodec DecoderNativeTransform or a TorchVision "
487+ "either a TorchCodec DecoderTransform or a TorchVision "
476488 "v2 transform."
477489 )
478490 else :
@@ -482,8 +494,23 @@ def _convert_to_decoder_native_transforms(
482494
483495
484496def _make_transform_specs (
485- transforms : Optional [Sequence [DecoderNativeTransform ]],
497+ transforms : Optional [Sequence [DecoderTransform ]],
486498) -> str :
499+ """Given a sequence of transforms, turn those into the specification string
500+ the core API expects.
501+
502+ Args:
503+ transforms: Optional sequence of transform objects. The objects can be
504+ one of two types:
505+ 1. torchcodec.transforms.DecoderTransform
506+ 2. torchvision.transforms.v2.Transform
507+ Our type annotation only mentions the first type so that we don't
508+ have a hard dependency on TorchVision.
509+
510+ Returns:
511+ String of transforms in the format the core API expects: transform
512+        specifications separated by semicolons.
513+ """
487514 if transforms is None :
488515 return ""
489516
0 commit comments