|
|
|
|
|
import contextlib |
|
import os |
|
import random |
|
import tempfile |
|
import unittest |
|
import torch |
|
import torchvision.io as io |
|
|
|
from densepose.data.transform import ImageResizeTransform |
|
from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset |
|
|
|
try: |
|
import av |
|
except ImportError: |
|
av = None |
|
|
|
|
|
|
|
def _create_video_frames(num_frames, height, width):
    """Build a stack of synthetic frames containing a drifting Gaussian blob.

    The pixel grid spans [-2, 2] along both axes. For frame ``i`` the blob is
    centred at ``x = i / num_frames`` and ``y = 1 - i / (2 * num_frames)``.

    Args:
        num_frames: number of frames to generate.
        height: number of rows in each frame.
        width: number of columns in each frame.

    Returns:
        A uint8 tensor of shape (num_frames, height, width, 3) in which the
        three channels of every frame are identical.
    """
    rows = torch.linspace(-2, 2, height)
    cols = torch.linspace(-2, 2, width)
    y, x = torch.meshgrid(rows, cols)

    def _render(idx):
        # Blob centre for this frame.
        center_x = float(idx) / num_frames
        center_y = 1 - float(idx) / (2 * num_frames)
        intensity = torch.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / 2) * 255
        # Replicate the single intensity plane into three channels (H, W, 3).
        return intensity.unsqueeze(2).repeat(1, 1, 3).byte()

    frames = [_render(idx) for idx in range(num_frames)]
    return torch.stack(frames, 0)
|
|
|
|
|
|
|
@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic video to a temporary ``.mp4`` file for the duration of a ``with`` block.

    Args:
        num_frames: number of frames to generate.
        height: frame height in pixels.
        width: frame width in pixels.
        fps: frame rate passed to the video writer.
        lossless: if true, encode with codec ``"libx264rgb"`` and option
            ``crf=0``; must not be combined with ``video_codec`` or ``options``.
        video_codec: codec name passed to the video writer; defaults to
            ``"libx264"`` when not given.
        options: dict of codec options passed to the video writer; defaults
            to an empty dict when not given.

    Yields:
        A ``(path, data)`` tuple: the path of the written video file and the
        frame tensor produced by ``_create_video_frames``.

    Raises:
        ValueError: if ``lossless`` is set together with ``video_codec`` or
            ``options``.
    """
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = "libx264rgb"
        options = {"crf": "0"}
    if video_codec is None:
        video_codec = "libx264"
    if options is None:
        options = {}
    data = _create_video_frames(num_frames, height, width)
    # Write into a private temporary directory instead of reserving a file
    # name, closing it and reusing the name: the directory (and the video in
    # it) is removed on every exit path, including when the body of the
    # caller's ``with`` block raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        video_path = os.path.join(tmp_dir, "video.mp4")
        io.write_video(video_path, data, fps=fps, video_codec=video_codec, options=options)
        yield video_path, data
|
|
|
|
|
@unittest.skipIf(av is None, "PyAV unavailable")
class TestVideoKeyframeDataset(unittest.TestCase):
    """Checks for VideoKeyframeDataset on a synthetic 60-frame, 300x300 mpeg4 video."""

    def test_read_keyframes_all(self):
        """Without a frame selector the sample is expected to hold 5 frames."""
        finished = False
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (video_path, _frames):
            dataset = VideoKeyframeDataset([video_path], [None])
            self.assertEqual(len(dataset), 1)
            images = dataset[0]["images"]
            categories = dataset[0]["categories"]
            self.assertEqual(images.shape, torch.Size((5, 3, 300, 300)))
            self.assertEqual(images.dtype, torch.float32)
            self.assertIsNone(categories[0])
            finished = True
        # Fails only if the with-body did not run to completion.
        self.assertTrue(finished)

    def test_read_keyframes_with_selector(self):
        """A RandomKFramesSelector(3) is expected to yield 3 frames."""
        finished = False
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (video_path, _frames):
            # Seed before building the selector so frame selection is repeatable.
            random.seed(0)
            selector = RandomKFramesSelector(3)
            dataset = VideoKeyframeDataset([video_path], [None], selector)
            self.assertEqual(len(dataset), 1)
            images = dataset[0]["images"]
            categories = dataset[0]["categories"]
            self.assertEqual(images.shape, torch.Size((3, 3, 300, 300)))
            self.assertEqual(images.dtype, torch.float32)
            self.assertIsNone(categories[0])
            finished = True
        # Fails only if the with-body did not run to completion.
        self.assertTrue(finished)

    def test_read_keyframes_with_selector_with_transform(self):
        """With one selected frame and a default ImageResizeTransform, 1x3x800x800 is expected."""
        finished = False
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (video_path, _frames):
            # Seed before building the selector so frame selection is repeatable.
            random.seed(0)
            selector = RandomKFramesSelector(1)
            resize = ImageResizeTransform()
            dataset = VideoKeyframeDataset([video_path], [None], selector, resize)
            images = dataset[0]["images"]
            categories = dataset[0]["categories"]
            self.assertEqual(len(dataset), 1)
            self.assertEqual(images.shape, torch.Size((1, 3, 800, 800)))
            self.assertEqual(images.dtype, torch.float32)
            self.assertIsNone(categories[0])
            finished = True
        # Fails only if the with-body did not run to completion.
        self.assertTrue(finished)
|
|