
Commit dc16154

Add batch decoding support to CUDA (#319)
1 parent dd9ffb3 commit dc16154

8 files changed: +170 -58 lines
Lines changed: 67 additions & 41 deletions

@@ -1,3 +1,4 @@
+import argparse
 from pathlib import Path
 from time import perf_counter_ns

@@ -45,51 +46,76 @@ def report_stats(times, num_frames, unit="ms"):
     return med, fps


-def sample(sampler, **kwargs):
-    decoder = VideoDecoder(VIDEO_PATH)
+def sample(decoder, sampler, **kwargs):
     return sampler(
         decoder,
         num_frames_per_clip=10,
         **kwargs,
     )


-VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
-NUM_EXP = 30
-
-for num_clips in (1, 50):
-    print("-" * 10)
-    print(f"{num_clips = }")
-
-    print("clips_at_random_indices ", end="")
-    times, num_frames = bench(
-        sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
-    )
-    report_stats(times, num_frames, unit="ms")
-
-    print("clips_at_regular_indices ", end="")
-    times, num_frames = bench(
-        sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
-    )
-    report_stats(times, num_frames, unit="ms")
-
-    print("clips_at_random_timestamps ", end="")
-    times, num_frames = bench(
-        sample,
-        clips_at_random_timestamps,
-        num_clips=num_clips,
-        num_exp=NUM_EXP,
-        warmup=2,
-    )
-    report_stats(times, num_frames, unit="ms")
-
-    print("clips_at_regular_timestamps ", end="")
-    seconds_between_clip_starts = 13 / num_clips  # approximate. video is 13s long
-    times, num_frames = bench(
-        sample,
-        clips_at_regular_timestamps,
-        seconds_between_clip_starts=seconds_between_clip_starts,
-        num_exp=NUM_EXP,
-        warmup=2,
-    )
-    report_stats(times, num_frames, unit="ms")
+def run_sampler_benchmarks(device, video):
+    NUM_EXP = 30
+
+    for num_clips in (1, 50):
+        print("-" * 10)
+        print(f"{num_clips = }")
+
+        print("clips_at_random_indices ", end="")
+        decoder = VideoDecoder(video, device=device)
+        times, num_frames = bench(
+            sample,
+            decoder,
+            clips_at_random_indices,
+            num_clips=num_clips,
+            num_exp=NUM_EXP,
+            warmup=2,
+        )
+        report_stats(times, num_frames, unit="ms")
+
+        print("clips_at_regular_indices ", end="")
+        times, num_frames = bench(
+            sample,
+            decoder,
+            clips_at_regular_indices,
+            num_clips=num_clips,
+            num_exp=NUM_EXP,
+            warmup=2,
+        )
+        report_stats(times, num_frames, unit="ms")
+
+        print("clips_at_random_timestamps ", end="")
+        times, num_frames = bench(
+            sample,
+            decoder,
+            clips_at_random_timestamps,
+            num_clips=num_clips,
+            num_exp=NUM_EXP,
+            warmup=2,
+        )
+        report_stats(times, num_frames, unit="ms")
+
+        print("clips_at_regular_timestamps ", end="")
+        seconds_between_clip_starts = 13 / num_clips  # approximate. video is 13s long
+        times, num_frames = bench(
+            sample,
+            decoder,
+            clips_at_regular_timestamps,
+            seconds_between_clip_starts=seconds_between_clip_starts,
+            num_exp=NUM_EXP,
+            warmup=2,
+        )
+        report_stats(times, num_frames, unit="ms")
+
+
+def main():
+    DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH))
+    args = parser.parse_args()
+    run_sampler_benchmarks(args.device, args.video)
+
+
+if __name__ == "__main__":
+    main()
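The refactor replaces the module-level benchmark loop with run_sampler_benchmarks(device, video) plus an argparse entry point, so the same sampler benchmarks can now be pointed at the GPU. A minimal usage sketch (the benchmark's file name is not shown in this diff, so the names below are placeholders):

# Hypothetical command lines (script name is a placeholder):
#   python benchmark_samplers.py                  # CPU, default NASA test video
#   python benchmark_samplers.py --device cuda    # exercise the CUDA batch-decoding path
# Or, assuming run_sampler_benchmarks is importable, drive it directly:
run_sampler_benchmarks(device="cuda", video="test/resources/nasa_13013.mp4")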

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,8 @@ void convertAVFrameToDecodedOutputOnCuda(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
-    VideoDecoder::DecodedOutput& output) {
+    VideoDecoder::DecodedOutput& output,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   throwUnsupportedDeviceError(device);
 }


src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 17 additions & 2 deletions

@@ -201,7 +201,8 @@ void convertAVFrameToDecodedOutputOnCuda(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
-    VideoDecoder::DecodedOutput& output) {
+    VideoDecoder::DecodedOutput& output,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   AVFrame* src = rawOutput.frame.get();

   TORCH_CHECK(
@@ -213,7 +214,21 @@ void convertAVFrameToDecodedOutputOnCuda(
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
   torch::Tensor& dst = output.frame;
-  dst = allocateDeviceTensor({height, width, 3}, options.device);
+  if (preAllocatedOutputTensor.has_value()) {
+    dst = preAllocatedOutputTensor.value();
+    auto shape = dst.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
+            (shape[2] == 3),
+        "Expected tensor of shape ",
+        height,
+        "x",
+        width,
+        "x3, got ",
+        shape);
+  } else {
+    dst = allocateDeviceTensor({height, width, 3}, options.device);
+  }

   // Use the user-requested GPU for running the NPP kernel.
   c10::cuda::CUDAGuard deviceGuard(device);

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 2 additions & 1 deletion

@@ -37,7 +37,8 @@ void convertAVFrameToDecodedOutputOnCuda(
     const VideoDecoder::VideoStreamDecoderOptions& options,
     AVCodecContext* codecContext,
     VideoDecoder::RawDecodedOutput& rawOutput,
-    VideoDecoder::DecodedOutput& output);
+    VideoDecoder::DecodedOutput& output,
+    std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

 void releaseContextOnCuda(
     const torch::Device& device,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 3 deletions

@@ -196,7 +196,7 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
           options.height.value_or(*metadata.height),
           options.width.value_or(*metadata.width),
           3},
-          {torch::kUInt8})),
+          at::TensorOptions(options.device).dtype(torch::kUInt8))),
       ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
       durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}

@@ -855,17 +855,18 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   output.duration = getDuration(frame);
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
+  // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
   if (streamInfo.options.device.type() == torch::kCPU) {
     convertAVFrameToDecodedOutputOnCPU(
         rawOutput, output, preAllocatedOutputTensor);
   } else if (streamInfo.options.device.type() == torch::kCUDA) {
-    // TODO: handle pre-allocated output tensor
     convertAVFrameToDecodedOutputOnCuda(
         streamInfo.options.device,
         streamInfo.options,
         streamInfo.codecContext.get(),
         rawOutput,
-        output);
+        output,
+        preAllocatedOutputTensor);
   } else {
     TORCH_CHECK(
         false, "Invalid device type: " + streamInfo.options.device.str());

src/torchcodec/decoders/_video_decoder.py

Lines changed: 10 additions & 7 deletions

@@ -8,7 +8,7 @@
 from pathlib import Path
 from typing import Literal, Optional, Tuple, Union

-from torch import Tensor
+from torch import device, Tensor

 from torchcodec import Frame, FrameBatch
 from torchcodec.decoders import _core as core
@@ -36,19 +36,20 @@ class VideoDecoder:
             This can be either "NCHW" (default) or "NHWC", where N is the batch
             size, C is the number of channels, H is the height, and W is the
             width of the frames.
-        num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
-            Use 1 for single-threaded decoding which may be best if you are running multiple
-            instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
-            decoding which is best if you are running a single instance of ``VideoDecoder``.
-            Default: 1.
-
             .. note::

                 Frames are natively decoded in NHWC format by the underlying
                 FFmpeg implementation. Converting those into NCHW format is a
                 cheap no-copy operation that allows these frames to be
                 transformed using the `torchvision transforms
                 <https://pytorch.org/vision/stable/transforms.html>`_.
+        num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
+            Use 1 for single-threaded decoding which may be best if you are running multiple
+            instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
+            decoding which is best if you are running a single instance of ``VideoDecoder``.
+            Default: 1.
+        device (str or torch.device, optional): The device to use for decoding. Default: "cpu".
+

     Attributes:
         metadata (VideoStreamMetadata): Metadata of the video stream.
@@ -64,6 +65,7 @@ def __init__(
         stream_index: Optional[int] = None,
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
+        device: Optional[Union[str, device]] = "cpu",
     ):
         if isinstance(source, str):
             self._decoder = core.create_from_file(source)
@@ -92,6 +94,7 @@
             stream_index=stream_index,
             dimension_order=dimension_order,
             num_threads=num_ffmpeg_threads,
+            device=device,
         )

         self.metadata, self.stream_index = _get_and_validate_stream_metadata(
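With device plumbed from VideoDecoder.__init__ down to core.add_video_stream, GPU decoding becomes opt-in from the public API. A minimal usage sketch (the path is a placeholder, and it assumes a CUDA-capable build of torchcodec and FFmpeg):

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("my_video.mp4", device="cuda")  # "my_video.mp4" is a placeholder
frame = decoder[0]  # indexing returns a decoded frame tensor
assert frame.device.type == "cuda"  # frames now live on the GPU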

test/decoders/test_video_decoder_ops.py

Lines changed: 61 additions & 3 deletions

@@ -36,7 +36,13 @@
     seek_to_pts,
 )

-from ..utils import assert_tensor_equal, NASA_AUDIO, NASA_VIDEO, needs_cuda
+from ..utils import (
+    assert_tensor_close_on_at_least,
+    assert_tensor_equal,
+    NASA_AUDIO,
+    NASA_VIDEO,
+    needs_cuda,
+)

 torch._dynamo.config.capture_dynamic_output_shape_ops = True

@@ -137,6 +143,24 @@ def test_get_frames_at_indices(self):
         assert_tensor_equal(frames0and180[0], reference_frame0)
         assert_tensor_equal(frames0and180[1], reference_frame180)

+    @needs_cuda
+    def test_get_frames_at_indices_with_cuda(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        scan_all_streams_to_update_metadata(decoder)
+        add_video_stream(decoder, device="cuda")
+        frames0and180, *_ = get_frames_at_indices(
+            decoder, stream_index=3, frame_indices=[0, 180]
+        )
+        reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
+        reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
+            INDEX_OF_FRAME_AT_6_SECONDS
+        )
+        assert frames0and180.device.type == "cuda"
+        assert_tensor_close_on_at_least(frames0and180[0].to("cpu"), reference_frame0)
+        assert_tensor_close_on_at_least(
+            frames0and180[1].to("cpu"), reference_frame180, 0.3, 30
+        )
+
     def test_get_frames_at_indices_unsorted_indices(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
         _add_video_stream(decoder)
@@ -198,6 +222,40 @@ def test_get_frames_by_pts(self):
         with pytest.raises(AssertionError):
             assert_tensor_equal(frames[0], frames[-1])

+    # TODO: Figure out how to parameterize this test to run on both CPU and CUDA.
+    # The question is how to have the @needs_cuda decorator with the pytest.mark.parametrize
+    # decorator on the same test.
+    @needs_cuda
+    def test_get_frames_by_pts_with_cuda(self):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        _add_video_stream(decoder, device="cuda")
+        scan_all_streams_to_update_metadata(decoder)
+        stream_index = 3
+
+        # Note: 13.01 should give the last video frame for the NASA video
+        timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3]
+
+        expected_frames = [
+            get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps
+        ]
+
+        frames, *_ = get_frames_by_pts(
+            decoder,
+            stream_index=stream_index,
+            timestamps=timestamps,
+        )
+        for frame, expected_frame in zip(frames, expected_frames):
+            assert_tensor_equal(frame, expected_frame)
+
+        # first and last frame should be equal, at pts=2 [+ eps]. We then modify
+        # the first frame and assert that it's now different from the last
+        # frame. This ensures a copy was properly made during the de-duplication
+        # logic.
+        assert_tensor_equal(frames[0], frames[-1])
+        frames[0] += 20
+        with pytest.raises(AssertionError):
+            assert_tensor_equal(frames[0], frames[-1])
+
     def test_pts_apis_against_index_ref(self):
         # Non-regression test for https://github.com/pytorch/torchcodec/pull/287
         # Get all frames in the video, then query all frames with all time-based
@@ -657,8 +715,8 @@ def test_cuda_decoder(self):
         assert frame0.device.type == "cuda"
         frame0_cpu = frame0.to("cpu")
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
-        # GPU decode is not bit-accurate. In the following assertion we ensure
-        # not more than 0.3% of values have a difference greater than 20.
+        # GPU decode is not bit-accurate. So we allow some tolerance.
+        assert_tensor_close_on_at_least(frame0_cpu, reference_frame0)
         diff = (reference_frame0.float() - frame0_cpu.float()).abs()
         assert (diff > 20).float().mean() <= 0.003
         assert pts == torch.tensor([0])
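On the TODO above about combining @needs_cuda with pytest.mark.parametrize: a common pattern is to attach the CUDA skip marker to the parameter rather than the test function. This is only a sketch, and it assumes needs_cuda is (or wraps) an ordinary pytest mark:

import pytest
import torch

# Assumption: needs_cuda behaves like a plain skipif mark, e.g.
needs_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")


@pytest.mark.parametrize(
    "device",
    ("cpu", pytest.param("cuda", marks=needs_cuda)),
)
def test_frames_on_device(device):
    # Body elided; the real test would pass device=device to _add_video_stream
    # and run the same decode-and-compare logic as above.
    assert device in ("cpu", "cuda")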

test/utils.py

Lines changed: 7 additions & 0 deletions

@@ -33,6 +33,13 @@ def assert_tensor_equal(*args, **kwargs):
     torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0)


+# Asserts that at least `percentage`% of the values are within the absolute tolerance.
+def assert_tensor_close_on_at_least(frame1, frame2, percentage=99.7, abs_tolerance=20):
+    diff = (frame2.float() - frame1.float()).abs()
+    diff_percentage = 100.0 - percentage
+    assert (diff > abs_tolerance).float().mean() <= diff_percentage / 100.0
+
+
 # For use with floating point metadata, or in other instances where we are not confident
 # that reference and test tensors can be exactly equal. This is true for pts and duration
 # in seconds, as the reference values are from ffprobe's JSON output. In that case, it is
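With its defaults, the new helper asserts that no more than 0.3% of values differ from the reference by more than 20, which matches the tolerance the CUDA decoder test previously checked inline. A small usage sketch with synthetic tensors (assuming the helper above is in scope):

import torch

reference = torch.zeros(270, 480, 3, dtype=torch.uint8)
decoded = reference.clone()
decoded[0, 0, 0] = 255  # one large deviation, far below 0.3% of all values

# Passes: at least 99.7% of values are within an absolute tolerance of 20.
assert_tensor_close_on_at_least(decoded, reference)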
