|
9 | 9 |
|
10 | 10 | import pytest |
11 | 11 | import torch |
12 | | -from torchcodec.decoders import AudioDecoder |
| 12 | +from torchcodec.decoders import AudioDecoder, VideoDecoder |
13 | 13 |
|
14 | 14 | from torchcodec.encoders import AudioEncoder, VideoEncoder |
15 | 15 |
|
|
20 | 20 | in_fbcode, |
21 | 21 | IS_WINDOWS, |
22 | 22 | NASA_AUDIO_MP3, |
| 23 | + psnr, |
23 | 24 | SINE_MONO_S32, |
| 25 | + TEST_SRC_2_720P, |
24 | 26 | TestContainerFile, |
25 | 27 | ) |
26 | 28 |
|
@@ -567,6 +569,9 @@ def write(self, data): |
567 | 569 |
|
568 | 570 |
|
569 | 571 | class TestVideoEncoder: |
| 572 | + def decode(self, source=None) -> torch.Tensor: |
| 573 | + return VideoDecoder(source).get_frames_in_range(start=0, stop=60) |
| 574 | + |
570 | 575 | @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) |
571 | 576 | def test_bad_input_parameterized(self, tmp_path, method): |
572 | 577 | if method == "to_file": |
@@ -676,3 +681,241 @@ def encode_to_tensor(frames): |
676 | 681 | torch.testing.assert_close( |
677 | 682 | encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 |
678 | 683 | ) |
| 684 | + |
| 685 | + @pytest.mark.parametrize( |
| 686 | + "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) |
| 687 | + ) |
| 688 | + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) |
| 689 | + def test_round_trip(self, tmp_path, format, method): |
| 690 | + # Test that decode(encode(decode(frames))) == decode(frames) |
| 691 | + ffmpeg_version = get_ffmpeg_major_version() |
| 692 | + # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. |
| 693 | + # As a result, we skip the round trip test. |
| 694 | + if ffmpeg_version == 6 and format != "webm": |
| 695 | + pytest.skip( |
| 696 | + f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test." |
| 697 | + ) |
| 698 | + if format == "webm" and ( |
| 699 | + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) |
| 700 | + ): |
| 701 | + pytest.skip("Codec for webm is not available in this FFmpeg installation.") |
| 702 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 703 | + |
| 704 | + # Frame rate is fixed with num frames decoded |
| 705 | + encoder = VideoEncoder(frames=source_frames, frame_rate=30) |
| 706 | + |
| 707 | + if method == "to_file": |
| 708 | + encoded_path = str(tmp_path / f"encoder_output.{format}") |
| 709 | + encoder.to_file(dest=encoded_path, crf=0) |
| 710 | + round_trip_frames = self.decode(encoded_path).data |
| 711 | + elif method == "to_tensor": |
| 712 | + encoded_tensor = encoder.to_tensor(format=format, crf=0) |
| 713 | + round_trip_frames = self.decode(encoded_tensor).data |
| 714 | + elif method == "to_file_like": |
| 715 | + file_like = io.BytesIO() |
| 716 | + encoder.to_file_like(file_like=file_like, format=format, crf=0) |
| 717 | + round_trip_frames = self.decode(file_like.getvalue()).data |
| 718 | + else: |
| 719 | + raise ValueError(f"Unknown method: {method}") |
| 720 | + |
| 721 | + assert source_frames.shape == round_trip_frames.shape |
| 722 | + assert source_frames.dtype == round_trip_frames.dtype |
| 723 | + |
| 724 | + # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels |
| 725 | + # are within a higher tolerance. |
| 726 | + if ffmpeg_version == 6: |
| 727 | + assert_close = partial(assert_tensor_close_on_at_least, percentage=99) |
| 728 | + atol = 15 |
| 729 | + else: |
| 730 | + assert_close = torch.testing.assert_close |
| 731 | + atol = 2 |
| 732 | + for s_frame, rt_frame in zip(source_frames, round_trip_frames): |
| 733 | + assert psnr(s_frame, rt_frame) > 30 |
| 734 | + assert_close(s_frame, rt_frame, atol=atol, rtol=0) |
| 735 | + |
| 736 | + @pytest.mark.parametrize( |
| 737 | + "format", |
| 738 | + ( |
| 739 | + "mov", |
| 740 | + "mp4", |
| 741 | + "avi", |
| 742 | + "mkv", |
| 743 | + "flv", |
| 744 | + "gif", |
| 745 | + pytest.param("webm", marks=pytest.mark.slow), |
| 746 | + ), |
| 747 | + ) |
| 748 | + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) |
| 749 | + def test_against_to_file(self, tmp_path, format, method): |
| 750 | + # Test that to_file, to_tensor, and to_file_like produce the same results |
| 751 | + ffmpeg_version = get_ffmpeg_major_version() |
| 752 | + if format == "webm" and ( |
| 753 | + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) |
| 754 | + ): |
| 755 | + pytest.skip("Codec for webm is not available in this FFmpeg installation.") |
| 756 | + |
| 757 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 758 | + encoder = VideoEncoder(frames=source_frames, frame_rate=30) |
| 759 | + |
| 760 | + encoded_file = tmp_path / f"output.{format}" |
| 761 | + encoder.to_file(dest=encoded_file, crf=0) |
| 762 | + |
| 763 | + if method == "to_tensor": |
| 764 | + encoded_output = encoder.to_tensor(format=format, crf=0) |
| 765 | + else: # to_file_like |
| 766 | + file_like = io.BytesIO() |
| 767 | + encoder.to_file_like(file_like=file_like, format=format, crf=0) |
| 768 | + encoded_output = file_like.getvalue() |
| 769 | + |
| 770 | + torch.testing.assert_close( |
| 771 | + self.decode(encoded_file).data, |
| 772 | + self.decode(encoded_output).data, |
| 773 | + atol=0, |
| 774 | + rtol=0, |
| 775 | + ) |
| 776 | + |
| 777 | + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") |
| 778 | + @pytest.mark.parametrize( |
| 779 | + "format", |
| 780 | + ( |
| 781 | + "mov", |
| 782 | + "mp4", |
| 783 | + "avi", |
| 784 | + "mkv", |
| 785 | + "flv", |
| 786 | + "gif", |
| 787 | + pytest.param("webm", marks=pytest.mark.slow), |
| 788 | + ), |
| 789 | + ) |
| 790 | + def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): |
| 791 | + # Encode samples with our encoder and with the FFmpeg CLI, and check |
| 792 | + # that both decoded outputs are similar |
| 793 | + ffmpeg_version = get_ffmpeg_major_version() |
| 794 | + if format == "webm" and ( |
| 795 | + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) |
| 796 | + ): |
| 797 | + pytest.skip("Codec for webm is not available in this FFmpeg installation.") |
| 798 | + |
| 799 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 800 | + |
| 801 | + # Encode with FFmpeg CLI |
| 802 | + temp_raw_path = str(tmp_path / "temp_input.raw") |
| 803 | + with open(temp_raw_path, "wb") as f: |
| 804 | + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) |
| 805 | + |
| 806 | + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") |
| 807 | + frame_rate = 30 |
| 808 | + crf = 0 |
| 809 | + # Some codecs (ex. MPEG4) do not support CRF. |
| 810 | + # Flags not supported by the selected codec will be ignored. |
| 811 | + ffmpeg_cmd = [ |
| 812 | + "ffmpeg", |
| 813 | + "-y", |
| 814 | + "-f", |
| 815 | + "rawvideo", |
| 816 | + "-pix_fmt", |
| 817 | + "rgb24", |
| 818 | + "-s", |
| 819 | + f"{source_frames.shape[3]}x{source_frames.shape[2]}", |
| 820 | + "-r", |
| 821 | + str(frame_rate), |
| 822 | + "-i", |
| 823 | + temp_raw_path, |
| 824 | + "-crf", |
| 825 | + str(crf), |
| 826 | + ffmpeg_encoded_path, |
| 827 | + ] |
| 828 | + subprocess.run(ffmpeg_cmd, check=True) |
| 829 | + |
| 830 | + # Encode with our video encoder |
| 831 | + encoder_output_path = str(tmp_path / f"encoder_output.{format}") |
| 832 | + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) |
| 833 | + encoder.to_file(dest=encoder_output_path, crf=crf) |
| 834 | + |
| 835 | + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data |
| 836 | + encoder_frames = self.decode(encoder_output_path).data |
| 837 | + |
| 838 | + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] |
| 839 | + |
| 840 | + # If FFmpeg selects a codec or pixel format that uses qscale (not crf), |
| 841 | + # the VideoEncoder outputs *slightly* different frames. |
| 842 | + # There may be additional subtle differences in the encoder. |
| 843 | + percentage = 94 if ffmpeg_version == 6 or format == "avi" else 99 |
| 844 | + |
| 845 | + # Check that PSNR between both encoded versions is high |
| 846 | + for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): |
| 847 | + res = psnr(ff_frame, enc_frame) |
| 848 | + assert res > 30 |
| 849 | + assert_tensor_close_on_at_least( |
| 850 | + ff_frame, enc_frame, percentage=percentage, atol=2 |
| 851 | + ) |
| 852 | + |
| 853 | + def test_to_file_like_custom_file_object(self): |
| 854 | + """Test to_file_like with a custom file-like object that implements write and seek.""" |
| 855 | + |
| 856 | + class CustomFileObject: |
| 857 | + def __init__(self): |
| 858 | + self._file = io.BytesIO() |
| 859 | + |
| 860 | + def write(self, data): |
| 861 | + return self._file.write(data) |
| 862 | + |
| 863 | + def seek(self, offset, whence=0): |
| 864 | + return self._file.seek(offset, whence) |
| 865 | + |
| 866 | + def get_encoded_data(self): |
| 867 | + return self._file.getvalue() |
| 868 | + |
| 869 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 870 | + encoder = VideoEncoder(frames=source_frames, frame_rate=30) |
| 871 | + |
| 872 | + file_like = CustomFileObject() |
| 873 | + encoder.to_file_like(file_like, format="mp4", crf=0) |
| 874 | + decoded_frames = self.decode(file_like.get_encoded_data()) |
| 875 | + |
| 876 | + torch.testing.assert_close( |
| 877 | + decoded_frames.data, |
| 878 | + source_frames, |
| 879 | + atol=2, |
| 880 | + rtol=0, |
| 881 | + ) |
| 882 | + |
| 883 | + def test_to_file_like_real_file(self, tmp_path): |
| 884 | + """Test to_file_like with a real file opened in binary write mode.""" |
| 885 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 886 | + encoder = VideoEncoder(frames=source_frames, frame_rate=30) |
| 887 | + |
| 888 | + file_path = tmp_path / "test_file_like.mp4" |
| 889 | + |
| 890 | + with open(file_path, "wb") as file_like: |
| 891 | + encoder.to_file_like(file_like, format="mp4", crf=0) |
| 892 | + decoded_frames = self.decode(str(file_path)) |
| 893 | + |
| 894 | + torch.testing.assert_close( |
| 895 | + decoded_frames.data, |
| 896 | + source_frames, |
| 897 | + atol=2, |
| 898 | + rtol=0, |
| 899 | + ) |
| 900 | + |
| 901 | + def test_to_file_like_bad_methods(self): |
| 902 | + source_frames = self.decode(TEST_SRC_2_720P.path).data |
| 903 | + encoder = VideoEncoder(frames=source_frames, frame_rate=30) |
| 904 | + |
| 905 | + class NoWriteMethod: |
| 906 | + def seek(self, offset, whence=0): |
| 907 | + return 0 |
| 908 | + |
| 909 | + with pytest.raises( |
| 910 | + RuntimeError, match="File like object must implement a write method" |
| 911 | + ): |
| 912 | + encoder.to_file_like(NoWriteMethod(), format="mp4") |
| 913 | + |
| 914 | + class NoSeekMethod: |
| 915 | + def write(self, data): |
| 916 | + return len(data) |
| 917 | + |
| 918 | + with pytest.raises( |
| 919 | + RuntimeError, match="File like object must implement a seek method" |
| 920 | + ): |
| 921 | + encoder.to_file_like(NoSeekMethod(), format="mp4") |
0 commit comments