Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy | |
import io | |
import os | |
import torch | |
import numpy as np | |
import cv2 | |
import imageio | |
from PIL import Image | |
import pycocotools.mask as mask_utils | |
def single_mask_to_rle(mask):
    """Encode a single 2-D mask as a run-length dict with a text ``counts`` field.

    The mask is given a trailing axis and converted to a Fortran-ordered
    uint8 array before being handed to ``mask_utils.encode``; the first
    entry of the encoder's result is returned with its ``counts`` bytes
    decoded as UTF-8.
    """
    stacked = np.array(mask[:, :, None], order="F", dtype="uint8")
    encoded = mask_utils.encode(stacked)[0]
    encoded["counts"] = encoded["counts"].decode("utf-8")
    return encoded
def single_rle_to_mask(rle):
    """Decode a run-length dict with ``mask_utils.decode`` into a uint8 array."""
    decoded = mask_utils.decode(rle)
    return np.array(decoded).astype(np.uint8)
def single_mask_to_xyxy(mask):
    """Return the inclusive ``[x_min, y_min, x_max, y_max]`` box of a 2-D mask.

    Every non-zero element counts as foreground. A mask with no foreground
    yields ``[0, 0, 0, 0]``. The result is a plain Python list of ints.
    """
    box = np.zeros((4), dtype=int)
    ys, xs = np.where(np.array(mask))
    if len(ys) == 0 or len(xs) == 0:
        # Nothing set: keep the all-zero box.
        return box.tolist()
    box[:] = [xs.min(), ys.min(), xs.max(), ys.max()]
    return box.tolist()
def get_mask_box(mask, threshold=255):
    """Return ``[left, top, right, bottom]`` of the pixels at or above ``threshold``.

    The coordinates are inclusive and come straight from NumPy reductions,
    so they are NumPy integers. Returns ``None`` when no pixel qualifies.
    """
    hits = np.where(mask >= threshold)
    found = len(hits) >= 1 and hits[0].shape[0] >= 1 and hits[1].shape[0] >= 1
    if not found:
        return None
    row_idx, col_idx = hits[0], hits[1]
    return [col_idx.min(), row_idx.min(), col_idx.max(), row_idx.max()]
def convert_to_numpy(image):
    """Convert a supported image container to a NumPy array.

    Args:
        image: A ``np.ndarray`` (copied), a ``torch.Tensor`` (detached and
            moved to CPU) or a Pillow ``Image`` (converted with ``np.array``).

    Returns:
        The image as a ``np.ndarray``.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    # The three supported types are disjoint, so the order of the checks
    # does not affect the result.
    if isinstance(image, np.ndarray):
        return image.copy()
    if isinstance(image, torch.Tensor):
        return image.detach().cpu().numpy()
    if isinstance(image, Image.Image):
        return np.array(image)
    # The original raised a bare string, which Python rejects with an
    # unrelated "exceptions must derive from BaseException" TypeError.
    raise TypeError(
        f'Unsupported data type {type(image)}, only supports np.ndarray, torch.Tensor, Pillow Image.'
    )
def convert_to_pil(image):
    """Convert a supported image container to a Pillow ``Image``.

    Pillow images are copied. Tensors are detached, moved to CPU and
    converted to NumPy first. Array data is cast to ``uint8`` before
    ``Image.fromarray``. Any other input type raises ``TypeError``.
    """
    if isinstance(image, Image.Image):
        return image.copy()
    if isinstance(image, torch.Tensor):
        as_array = image.detach().cpu().numpy()
        return Image.fromarray(as_array.astype('uint8'))
    if isinstance(image, np.ndarray):
        return Image.fromarray(image.astype('uint8'))
    raise TypeError(f'Unsupported data type {type(image)}, only supports np.ndarray, torch.Tensor, Pillow Image.')
def convert_to_torch(image):
    """Convert a supported image container to a ``torch.Tensor``.

    Args:
        image: A ``torch.Tensor`` (cloned, dtype unchanged), a ``np.ndarray``
            or a Pillow ``Image`` (both converted to a float tensor).

    Returns:
        A tensor that does not share storage with the input.

    Raises:
        TypeError: If ``image`` is none of the supported types.
    """
    # The three supported types are disjoint, so the order of the checks
    # does not affect the result.
    if isinstance(image, torch.Tensor):
        return image.clone()
    if isinstance(image, np.ndarray):
        # Copy first so the tensor is built from a fresh array rather than
        # the caller's buffer.
        return torch.from_numpy(image.copy()).float()
    if isinstance(image, Image.Image):
        return torch.from_numpy(np.array(image)).float()
    # The original raised a bare string, which Python rejects with an
    # unrelated "exceptions must derive from BaseException" TypeError.
    raise TypeError(
        f'Unsupported data type {type(image)}, only supports np.ndarray, torch.Tensor, Pillow Image.'
    )
def resize_image(input_image, resolution):
    """Resize an ``(H, W, C)`` image so its shorter side is about ``resolution``.

    The scale factor ``k`` maps the shorter side to ``resolution``. Both
    target sides are then rounded to the nearest multiple of 64. Lanczos
    interpolation is used when upscaling (``k > 1``), area interpolation
    otherwise.

    Returns:
        ``(resized_image, k)``.
    """
    src_h, src_w, _ = input_image.shape
    src_h, src_w = float(src_h), float(src_w)
    k = float(resolution) / min(src_h, src_w)
    dst_h = int(np.round(src_h * k / 64.0)) * 64
    dst_w = int(np.round(src_w * k / 64.0)) * 64
    if k > 1:
        method = cv2.INTER_LANCZOS4
    else:
        method = cv2.INTER_AREA
    resized = cv2.resize(input_image, (dst_w, dst_h), interpolation=method)
    return resized, k
def resize_image_ori(h, w, image, k):
    """Resize ``image`` to ``w`` x ``h`` pixels.

    ``k`` only selects the interpolation: Lanczos when ``k > 1``, area
    interpolation otherwise.
    """
    if k > 1:
        method = cv2.INTER_LANCZOS4
    else:
        method = cv2.INTER_AREA
    return cv2.resize(image, (w, h), interpolation=method)
def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
    """Write an iterable of frames to ``file_path`` as an H.264 video.

    Args:
        file_path: Destination path handed to ``imageio.get_writer``.
        videos: Iterable of frames; each is passed to ``append_data``.
        fps: Frame rate forwarded to the writer.
        quality: Quality setting forwarded to the writer.
        macro_block_size: Forwarded to the writer unchanged.

    Returns:
        ``True`` on success. ``False`` if any step raised; the error is
        printed rather than propagated.
    """
    try:
        video_writer = imageio.get_writer(
            file_path, fps=fps, codec='libx264', quality=quality,
            macro_block_size=macro_block_size)
        try:
            for frame in videos:
                video_writer.append_data(frame)
        finally:
            # Close the writer even when a frame fails, so the handle is
            # not leaked.
            video_writer.close()
        return True
    except Exception as e:
        print(f"Video save error: {e}")
        return False
def save_one_image(file_path, image, use_type='cv2'):
    """Save ``image`` to ``file_path`` with OpenCV or Pillow.

    Args:
        file_path: Destination path.
        image: Image data; written with ``cv2.imwrite`` for ``'cv2'`` or
            wrapped with ``Image.fromarray`` for ``'pil'``.
        use_type: ``'cv2'`` or ``'pil'``.

    Returns:
        ``True`` if the image was written. ``False`` if writing failed or
        ``use_type`` is unknown; the error is printed rather than propagated.
    """
    try:
        if use_type == 'cv2':
            # cv2.imwrite reports some failures through its boolean result
            # rather than an exception, so turn a falsy result into an error.
            if not cv2.imwrite(file_path, image):
                raise IOError(f"cv2.imwrite could not write '{file_path}'")
        elif use_type == 'pil':
            Image.fromarray(image).save(file_path)
        else:
            raise ValueError(f"Unknown image write type '{use_type}'")
        return True
    except Exception as e:
        print(f"Image save error: {e}")
        return False
def read_image(image_path, use_type='cv2', is_rgb=True, info=False):
    """Read an image from disk as a NumPy array.

    Args:
        image_path: Path of the image file.
        use_type: ``'cv2'`` or ``'pil'``.
        is_rgb: With ``'cv2'``, reverse the last axis of the decoded array.
            With ``'pil'``, convert the image to ``'RGB'`` mode.
        info: When true, also return the width and height.

    Returns:
        The image array, or ``(image, width, height)`` when ``info`` is true.
        ``None`` if reading failed; the error is printed.

    Raises:
        ValueError: If ``use_type`` is not ``'cv2'`` or ``'pil'``.
    """
    image = None
    width, height = None, None
    if use_type == 'cv2':
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise Exception("Image not found or path is incorrect.")
            if is_rgb:
                image = image[..., ::-1]
            height, width = image.shape[:2]
        except Exception as e:
            print(f"OpenCV read error: {e}")
            return None
    elif use_type == 'pil':
        try:
            # The context manager closes the file handle deterministically,
            # including when conversion fails.
            with Image.open(image_path) as opened:
                pil_image = opened.convert('RGB') if is_rgb else opened
                width, height = pil_image.size
                image = np.array(pil_image)
        except Exception as e:
            print(f"PIL read error: {e}")
            return None
    else:
        raise ValueError(f"Unknown image read type '{use_type}'")
    if info:
        return image, width, height
    return image
def read_mask(mask_path, use_type='cv2', info=False):
    """Read a single-channel mask from disk as a NumPy array.

    Args:
        mask_path: Path of the mask file.
        use_type: ``'cv2'`` (read with ``cv2.IMREAD_GRAYSCALE``) or ``'pil'``
            (converted to mode ``'L'``).
        info: When true, also return the width and height.

    Returns:
        The mask array, or ``(mask, width, height)`` when ``info`` is true.
        ``None`` if reading failed; the error is printed.

    Raises:
        ValueError: If ``use_type`` is not ``'cv2'`` or ``'pil'``.
    """
    mask = None
    width, height = None, None
    if use_type == 'cv2':
        try:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise Exception("Mask not found or path is incorrect.")
            height, width = mask.shape
        except Exception as e:
            print(f"OpenCV read error: {e}")
            return None
    elif use_type == 'pil':
        try:
            # The context manager closes the file handle deterministically,
            # including when conversion fails.
            with Image.open(mask_path) as opened:
                gray = opened.convert('L')
                width, height = gray.size
                mask = np.array(gray)
        except Exception as e:
            print(f"PIL read error: {e}")
            return None
    else:
        raise ValueError(f"Unknown mask read type '{use_type}'")
    if info:
        return mask, width, height
    return mask
def read_video_frames(video_path, use_type='cv2', is_rgb=True, info=False):
    """Decode every frame of a video into a list of arrays.

    Args:
        video_path: Path of the video file.
        use_type: ``'cv2'`` or ``'decord'``.
        is_rgb: Only used by the ``'cv2'`` backend, where it reverses the
            last axis of each decoded frame.
        info: When true, also return metadata.

    Returns:
        The list of frames, or ``(frames, fps, width, height, total_frames)``
        when ``info`` is true. ``None`` if decoding raised; the error is
        printed.

    Raises:
        ValueError: If ``use_type`` is not ``'cv2'`` or ``'decord'``.
    """
    frames = []
    if use_type == "decord":
        import decord
        decord.bridge.set_bridge("native")
        try:
            reader = decord.VideoReader(video_path)
            total_frames = len(reader)
            fps = reader.get_avg_fps()
            height, width, _ = reader[0].shape
            frames = [reader[i].asnumpy() for i in range(total_frames)]
        except Exception as e:
            print(f"Decord read error: {e}")
            return None
    elif use_type == "cv2":
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame[..., ::-1] if is_rgb else frame)
            total_frames = len(frames)
        except Exception as e:
            print(f"OpenCV read error: {e}")
            return None
        finally:
            # Release the capture on the error path too, so the handle is
            # not leaked.
            if cap is not None:
                cap.release()
    else:
        raise ValueError(f"Unknown video type {use_type}")
    if info:
        return frames, fps, width, height, total_frames
    return frames
def read_video_one_frame(video_path, use_type='cv2', is_rgb=True):
    """Return the first frame of a video as an array.

    Args:
        video_path: Path of the video file.
        use_type: ``'cv2'`` or ``'decord'``.
        is_rgb: Only used by the ``'cv2'`` backend, where it reverses the
            last axis of the decoded frame.

    Returns:
        The first frame, or ``None`` when it could not be read.

    Raises:
        ValueError: If ``use_type`` is not ``'cv2'`` or ``'decord'``.
    """
    image_first = None
    if use_type == "decord":
        import decord
        decord.bridge.set_bridge("native")
        try:
            reader = decord.VideoReader(video_path)
            image_first = reader[0].asnumpy()
        except Exception as e:
            print(f"Decord read error: {e}")
            return None
    elif use_type == "cv2":
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            ret, frame = cap.read()
            # Only touch the frame when the read succeeded; a failed read
            # leaves the result as None instead of relying on an exception
            # from indexing None.
            if ret:
                image_first = frame[..., ::-1] if is_rgb else frame
        except Exception as e:
            print(f"OpenCV read error: {e}")
            return None
        finally:
            # Release the capture on every path, including errors.
            if cap is not None:
                cap.release()
    else:
        raise ValueError(f"Unknown video type {use_type}")
    return image_first
def read_video_last_frame(video_path, use_type='cv2', is_rgb=True):
    """Return the last frame of a video as an array.

    Args:
        video_path: Path of the video file.
        use_type: ``'cv2'`` or ``'decord'``.
        is_rgb: Only used by the ``'cv2'`` backend, where it reverses the
            last axis of the decoded frame.

    Returns:
        The last frame, or ``None`` when the video has no frames or the
        frame could not be read.

    Raises:
        ValueError: If ``use_type`` is not ``'cv2'`` or ``'decord'``.
    """
    image_last = None
    if use_type == "decord":
        import decord
        decord.bridge.set_bridge("native")
        try:
            reader = decord.VideoReader(video_path)
            if len(reader) > 0:  # Guard against an empty video
                image_last = reader[-1].asnumpy()
        except Exception as e:
            print(f"Decord read error: {e}")
            return None
    elif use_type == "cv2":
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames > 0:
                # Seek to the final frame index reported by the container.
                cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
                ret, frame = cap.read()
                if ret:
                    image_last = frame[..., ::-1] if is_rgb else frame
        except Exception as e:
            print(f"OpenCV read error: {e}")
            return None
        finally:
            # Release the capture on every path, including errors.
            if cap is not None:
                cap.release()
    else:
        raise ValueError(f"Unknown video type {use_type}")
    return image_last
def align_frames(first_frame, last_frame):
    """Fit ``last_frame`` onto a canvas the size of ``first_frame``.

    When both frames already share height and width, ``last_frame`` is
    returned unchanged. Otherwise it is scaled, keeping its aspect ratio,
    to fit inside the target size and pasted centred on a white
    3-channel uint8 canvas.
    """
    target_h, target_w = first_frame.shape[:2]
    src_h, src_w = last_frame.shape[:2]
    if (target_h, target_w) == (src_h, src_w):
        return last_frame

    # The largest scale at which the source still fits inside the target.
    scale = min(target_w / src_w, target_h / src_h)
    fit_w = int(src_w * scale)
    fit_h = int(src_h * scale)
    scaled = cv2.resize(last_frame, (fit_w, fit_h), interpolation=cv2.INTER_AREA)

    canvas = np.full((target_h, target_w, 3), 255, dtype=np.uint8)
    left = (target_w - fit_w) // 2
    top = (target_h - fit_h) // 2
    canvas[top:top + fit_h, left:left + fit_w] = scaled
    return canvas
def save_sam2_video(video_path, video_segments, output_video_path):
    """Write one side-by-side visualisation video per segmented object.

    All frames of ``video_path`` are read into memory. ``video_segments``
    is iterated as ``{frame_key: {obj_id: info}}``, where ``info['mask']``
    is a run-length encoding decoded with ``single_rle_to_mask``. For each
    object, ``<output_video_path>/<obj_id>.mp4`` is written at double
    width: the left half shows the frame with masked pixels set to 0, the
    right half shows the frame with unmasked pixels set to 255.

    NOTE(review): masks are paired with frames by position (``zip``), not by
    the frame key, so this presumably expects every object to have a mask for
    each frame starting at the first one. Confirm against the caller.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frames = []
    ok, frame = cap.read()
    while ok:
        frames.append(frame)
        ok, frame = cap.read()
    cap.release()

    # Collect each object's masks in the iteration order of video_segments.
    masks_by_obj = {}
    for segments in video_segments.values():
        for obj_id, info in segments.items():
            seg = single_rle_to_mask(info['mask']).astype(bool)
            masks_by_obj.setdefault(obj_id, []).append(seg)

    for obj_id, segs in masks_by_obj.items():
        obj_path = os.path.join(output_video_path, f"{obj_id}.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # codec for saving the video
        writer = cv2.VideoWriter(obj_path, fourcc, fps, (width * 2, height))
        for i, (frame, seg) in enumerate(zip(frames, segs)):
            print(obj_id, i, np.sum(seg), seg.shape)
            masked_out = frame.copy()
            masked_out[seg] = 0
            cut_out = frame.copy()
            cut_out[~seg] = 255
            writer.write(np.concatenate([masked_out, cut_out], axis=1))
        writer.release()
def get_annotator_instance(anno_cfg):
    """Instantiate the annotator described by ``anno_cfg``.

    The config is deep-copied so the caller's object is not modified. Its
    ``"NAME"``, ``"INPUTS"`` and ``"OUTPUTS"`` entries are removed from the
    copy; the rest is passed as ``cfg`` to the class of that name looked up
    in ``vace.annotators``.

    Returns:
        A dict with keys ``"inputs"``, ``"outputs"`` and ``"anno_ins"``.
    """
    import vace.annotators as annotators
    cfg = copy.deepcopy(anno_cfg)
    class_name = cfg.pop("NAME")
    inputs = cfg.pop("INPUTS")
    outputs = cfg.pop("OUTPUTS")
    annotator_cls = getattr(annotators, class_name)
    instance = annotator_cls(cfg=cfg)
    return {"inputs": inputs, "outputs": outputs, "anno_ins": instance}
def get_annotator(config_type='', config_task='', return_dict=True):
    """Build the annotator configured for ``config_task``.

    Args:
        config_type: Key into ``VACE_CONFIGS``. When it is not a known key,
            every config group is searched for ``config_task`` and the first
            match is used.
        config_task: Task name to look up within the chosen group.
        return_dict: When true, return the dict produced by
            ``get_annotator_instance``; otherwise return only its
            ``'anno_ins'`` entry.

    Returns:
        The annotator dict or the annotator instance, per ``return_dict``.

    Raises:
        ValueError: If ``config_type`` is known but ``config_task`` is not in
            it, or if ``config_type`` is unknown and no group contains
            ``config_task``.
    """
    from vace.configs import VACE_CONFIGS
    if config_type in VACE_CONFIGS:
        task_configs = VACE_CONFIGS[config_type]
        if config_task not in task_configs:
            raise ValueError(f"Unknown config task {config_task}")
        anno_dict = get_annotator_instance(task_configs[config_task])
    else:
        # Fallback: search every group for the task. The original loop never
        # issued `break`, so its `else` clause always raised. It also iterated
        # over the items of the task config instead of passing the task config
        # itself, unlike the known-type branch above.
        for cfg_dict in VACE_CONFIGS.values():
            if config_task in cfg_dict:
                anno_dict = get_annotator_instance(cfg_dict[config_task])
                break
        else:
            raise ValueError(f"Unknown config type {config_type}")
    if return_dict:
        return anno_dict
    return anno_dict['anno_ins']