Skip to content

Commit ffbdf4e

Browse files
committed
remove device arg instead use frames device
1 parent 7e5e6d4 commit ffbdf4e

File tree

5 files changed

+13
-38
lines changed

5 files changed

+13
-38
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
#include <ATen/cuda/CUDAEvent.h>
22
#include <c10/cuda/CUDAStream.h>
3-
#include <cuda_runtime.h>
43
#include <torch/types.h>
54
#include <mutex>
65

7-
#include "CUDACommon.h"
86
#include "Cache.h"
97
#include "CudaDeviceInterface.h"
108
#include "FFMPEGCommon.h"

src/torchcodec/_core/custom_ops.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
40+
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -640,14 +640,13 @@ void encode_video_to_file(
640640
const at::Tensor& frames,
641641
double frame_rate,
642642
std::string_view file_name,
643-
std::string_view device = "cpu",
644643
std::optional<std::string_view> codec = std::nullopt,
645644
std::optional<std::string_view> pixel_format = std::nullopt,
646645
std::optional<double> crf = std::nullopt,
647646
std::optional<std::string_view> preset = std::nullopt,
648647
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
649648
VideoStreamOptions videoStreamOptions;
650-
videoStreamOptions.device = torch::Device(std::string(device));
649+
videoStreamOptions.device = frames.device();
651650
videoStreamOptions.codec = std::move(codec);
652651
videoStreamOptions.pixelFormat = std::move(pixel_format);
653652
videoStreamOptions.crf = crf;
@@ -665,15 +664,14 @@ at::Tensor encode_video_to_tensor(
665664
const at::Tensor& frames,
666665
double frame_rate,
667666
std::string_view format,
668-
std::string_view device = "cpu",
669667
std::optional<std::string_view> codec = std::nullopt,
670668
std::optional<std::string_view> pixel_format = std::nullopt,
671669
std::optional<double> crf = std::nullopt,
672670
std::optional<std::string_view> preset = std::nullopt,
673671
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
674672
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
675673
VideoStreamOptions videoStreamOptions;
676-
videoStreamOptions.device = torch::Device(std::string(device));
674+
videoStreamOptions.device = frames.device();
677675
videoStreamOptions.codec = std::move(codec);
678676
videoStreamOptions.pixelFormat = std::move(pixel_format);
679677
videoStreamOptions.crf = crf;
@@ -698,7 +696,6 @@ void _encode_video_to_file_like(
698696
double frame_rate,
699697
std::string_view format,
700698
int64_t file_like_context,
701-
std::string_view device = "cpu",
702699
std::optional<std::string_view> codec = std::nullopt,
703700
std::optional<std::string_view> pixel_format = std::nullopt,
704701
std::optional<double> crf = std::nullopt,
@@ -711,7 +708,7 @@ void _encode_video_to_file_like(
711708
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
712709

713710
VideoStreamOptions videoStreamOptions;
714-
videoStreamOptions.device = torch::Device(std::string(device));
711+
videoStreamOptions.device = frames.device();
715712
videoStreamOptions.codec = std::move(codec);
716713
videoStreamOptions.pixelFormat = std::move(pixel_format);
717714
videoStreamOptions.crf = crf;

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def encode_video_to_file_like(
213213
frame_rate: float,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216-
device: Optional[str] = "cpu",
217216
codec: Optional[str] = None,
218217
pixel_format: Optional[str] = None,
219218
crf: Optional[Union[int, float]] = None,
@@ -223,11 +222,10 @@ def encode_video_to_file_like(
223222
"""Encode video frames to a file-like object.
224223
225224
Args:
226-
frames: Video frames tensor
225+
frames: Video frames tensor. The device of the frames tensor will be used for encoding.
227226
frame_rate: Frame rate in frames per second
228227
format: Video format (e.g., "mp4", "mov", "mkv")
229228
file_like: File-like object that supports write() and seek() methods
230-
device: Device to use for encoding (default: "cpu")
231229
codec: Optional codec name (e.g., "libx264", "h264")
232230
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
233231
crf: Optional constant rate factor for encoding quality
@@ -241,7 +239,6 @@ def encode_video_to_file_like(
241239
frame_rate,
242240
format,
243241
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
244-
device,
245242
codec,
246243
pixel_format,
247244
crf,
@@ -334,7 +331,6 @@ def encode_video_to_file_abstract(
334331
frames: torch.Tensor,
335332
frame_rate: float,
336333
filename: str,
337-
device: str = "cpu",
338334
codec: Optional[str] = None,
339335
pixel_format: Optional[str] = None,
340336
preset: Optional[str] = None,
@@ -349,7 +345,6 @@ def encode_video_to_tensor_abstract(
349345
frames: torch.Tensor,
350346
frame_rate: float,
351347
format: str,
352-
device: str = "cpu",
353348
codec: Optional[str] = None,
354349
pixel_format: Optional[str] = None,
355350
preset: Optional[str] = None,
@@ -365,7 +360,6 @@ def _encode_video_to_file_like_abstract(
365360
frame_rate: float,
366361
format: str,
367362
file_like_context: int,
368-
device: str = "cpu",
369363
codec: Optional[str] = None,
370364
pixel_format: Optional[str] = None,
371365
preset: Optional[str] = None,

src/torchcodec/encoders/_video_encoder.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, Dict, Optional, Union
33

44
import torch
5-
from torch import device as torch_device, Tensor
5+
from torch import Tensor
66

77
from torchcodec import _core
88

@@ -15,17 +15,15 @@ class VideoEncoder:
1515
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
1616
C is 3 channels (RGB), H is height, and W is width.
1717
Values must be uint8 in the range ``[0, 255]``.
18+
The device of the frames tensor will be used for encoding.
1819
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
19-
device (str or torch.device, optional): The device to use for encoding. Default: "cpu".
20-
If you pass a CUDA device, frames will be encoded on GPU.
2120
"""
2221

2322
def __init__(
2423
self,
2524
frames: Tensor,
2625
*,
2726
frame_rate: float,
28-
device: Optional[Union[str, torch_device]] = "cpu",
2927
):
3028
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
3129
if not isinstance(frames, Tensor):
@@ -37,13 +35,8 @@ def __init__(
3735
if frame_rate <= 0:
3836
raise ValueError(f"{frame_rate = } must be > 0.")
3937

40-
# Validate and store device
41-
if isinstance(device, torch_device):
42-
device = str(device)
43-
4438
self._frames = frames
4539
self._frame_rate = frame_rate
46-
self._device = device
4740

4841
def to_file(
4942
self,
@@ -86,7 +79,6 @@ def to_file(
8679
frames=self._frames,
8780
frame_rate=self._frame_rate,
8881
filename=str(dest),
89-
device=self._device,
9082
codec=codec,
9183
pixel_format=pixel_format,
9284
crf=crf,
@@ -139,7 +131,6 @@ def to_tensor(
139131
frames=self._frames,
140132
frame_rate=self._frame_rate,
141133
format=format,
142-
device=self._device,
143134
codec=codec,
144135
pixel_format=pixel_format,
145136
crf=crf,
@@ -196,7 +187,6 @@ def to_file_like(
196187
frame_rate=self._frame_rate,
197188
format=format,
198189
file_like=file_like,
199-
device=self._device,
200190
codec=codec,
201191
pixel_format=pixel_format,
202192
crf=crf,

test/test_encoders.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -829,18 +829,16 @@ def encode_to_tensor(frames):
829829
common_params = dict(crf=0, pixel_format="yuv444p")
830830
if method == "to_file":
831831
dest = str(tmp_path / "output.mp4")
832-
VideoEncoder(frames, frame_rate=30, device=device).to_file(
833-
dest=dest, **common_params
834-
)
832+
VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params)
835833
with open(dest, "rb") as f:
836834
return torch.frombuffer(f.read(), dtype=torch.uint8).clone()
837835
elif method == "to_tensor":
838-
return VideoEncoder(frames, frame_rate=30, device=device).to_tensor(
836+
return VideoEncoder(frames, frame_rate=30).to_tensor(
839837
format="mp4", **common_params
840838
)
841839
elif method == "to_file_like":
842840
file_like = io.BytesIO()
843-
VideoEncoder(frames, frame_rate=30, device=device).to_file_like(
841+
VideoEncoder(frames, frame_rate=30).to_file_like(
844842
file_like, format="mp4", **common_params
845843
)
846844
return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
@@ -1331,9 +1329,7 @@ def test_nvenc_against_ffmpeg_cli(
13311329
else:
13321330
raise
13331331

1334-
encoder = VideoEncoder(
1335-
frames=source_frames, frame_rate=frame_rate, device=device
1336-
)
1332+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
13371333

13381334
encoder_extra_options = {"qp": qp}
13391335
if codec == "av1_nvenc":

0 commit comments

Comments
 (0)