Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import random | |
| from dataclasses import dataclass | |
| from typing import List | |
| from training.dataset.vos_segment_loader import LazySegments | |
| MAX_RETRIES = 1000 | |
@dataclass
class SampledFramesAndObjects:
    """Result of a sampler: the chosen frames and the object ids to track.

    Instances are created with keyword arguments (``frames=...``,
    ``object_ids=...``), so the ``@dataclass``-generated ``__init__`` is
    required; without the decorator the class only carries bare annotations
    and cannot be constructed that way.
    """

    # NOTE(review): annotated as List[int], but the samplers in this file store
    # elements of ``video.frames`` here (objects exposing ``.frame_idx``) --
    # confirm the intended element type.
    frames: List[int]
    # Ids of the objects taken from the first sampled frame.
    object_ids: List[int]
class VOSSampler:
    """Base class for samplers that select frames and objects from a video.

    Subclasses override :meth:`sample`.
    """

    def __init__(self, sort_frames=True):
        # frames are ordered by frame id when sort_frames is True
        self.sort_frames = sort_frames

    def sample(self, video, segment_loader=None, epoch=None):
        """Select frames and object ids from ``video``.

        The signature mirrors the one used by the concrete samplers in this
        file (``video, segment_loader, epoch=None``); the extra parameters
        default to ``None`` so the older single-argument call keeps working.

        Raises:
            NotImplementedError: always; subclasses must provide the logic.
        """
        raise NotImplementedError()
class RandomUniformSampler(VOSSampler):
    """Sample a random run of consecutive annotated frames from a video and a
    random subset of the objects visible on the first sampled frame."""

    def __init__(
        self,
        num_frames,
        max_num_objects,
        reverse_time_prob=0.0,
    ):
        # Run the base initialiser so `sort_frames` exists on every sampler
        # instance; this sampler itself never reads it.
        super().__init__()
        self.num_frames = num_frames
        self.max_num_objects = max_num_objects
        self.reverse_time_prob = reverse_time_prob

    def sample(self, video, segment_loader, epoch=None):
        """Pick ``num_frames`` consecutive entries of ``video.frames`` and up
        to ``max_num_objects`` object ids.

        A start offset is drawn uniformly at random; with probability
        ``reverse_time_prob`` the selected frames are reversed. The draw is
        repeated (at most ``MAX_RETRIES`` times) until the first frame of the
        selection, taken after any reversal, has at least one visible object.

        Args:
            video: object exposing ``frames`` (indexable, each entry having
                ``frame_idx``) and ``video_name``.
            segment_loader: object whose ``load(frame_idx)`` returns either a
                ``LazySegments`` or a mapping of object id to segment.
            epoch: unused by this sampler.

        Returns:
            SampledFramesAndObjects with the selected frames and the object
            ids drawn without replacement from the visible ones.

        Raises:
            Exception: if the video has fewer than ``num_frames`` frames, or
                if no attempt yields a first frame with a visible object.
        """
        # The frame count does not change between attempts, so validate once.
        num_available = len(video.frames)
        if num_available < self.num_frames:
            raise Exception(
                f"Cannot sample {self.num_frames} frames from video {video.video_name} as it only has {num_available} annotated frames."
            )

        for _ in range(MAX_RETRIES):
            start = random.randrange(0, num_available - self.num_frames + 1)
            frames = [video.frames[start + step] for step in range(self.num_frames)]
            if random.uniform(0, 1) < self.reverse_time_prob:
                # Reverse time
                frames = frames[::-1]

            # Get first frame object ids. The segments are loaded once and
            # reused for both branches below.
            loaded_segms = segment_loader.load(frames[0].frame_idx)
            if isinstance(loaded_segms, LazySegments):
                # LazySegments for SA1BRawDataset: every key counts as visible
                # without inspecting the masks.
                visible_object_ids = list(loaded_segms.keys())
            else:
                # An object is visible when its segment has a non-zero sum.
                visible_object_ids = [
                    object_id
                    for object_id, segment in loaded_segms.items()
                    if segment.sum()
                ]

            # First frame needs to have at least a target to track
            if visible_object_ids:
                break
        else:
            # Every attempt produced a first frame without visible objects.
            raise Exception("No visible objects")

        object_ids = random.sample(
            visible_object_ids,
            min(len(visible_object_ids), self.max_num_objects),
        )
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)
class EvalSampler(VOSSampler):
    """
    VOS Sampler for evaluation: sampling all the frames and all the objects in a video
    """

    def __init__(
        self,
    ):
        super().__init__()

    def sample(self, video, segment_loader, epoch=None):
        """
        Sampling all the frames and all the objects

        Args:
            video: object exposing ``frames``, each entry having ``frame_idx``.
            segment_loader: object whose ``load(frame_idx)`` returns an object
                with a ``keys()`` method listing the object ids.
            epoch: unused by this sampler.

        Returns:
            SampledFramesAndObjects holding every frame of the video and the
            object ids found on the first of those frames.

        Raises:
            Exception: if the first frame has no objects.
        """
        if self.sort_frames:
            # ordered by frame id
            frames = sorted(video.frames, key=lambda x: x.frame_idx)
        else:
            # use the original order
            frames = video.frames

        # Materialise the keys into a list: this matches the declared
        # `object_ids` list field and what RandomUniformSampler returns, and
        # avoids handing out a live dict view (which cannot be pickled).
        object_ids = list(segment_loader.load(frames[0].frame_idx).keys())
        if not object_ids:
            raise Exception("First frame of the video has no objects")
        return SampledFramesAndObjects(frames=frames, object_ids=object_ids)