RelightVid / misc_utils /image_utils.py
aleafy's picture
Start fresh
0a63786
raw
history blame
8.8 kB
import os
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
import imageio
from PIL import Image
import textwrap
def find_nearest_Nx(size, N=32):
    """Round ``size`` up to the next multiple of ``N``.

    Args:
        size: Value to round (int or float).
        N: Multiple to round up to (default 32).

    Returns:
        For positive ``N``, the smallest multiple of ``N`` that is >= ``size``,
        as an ``int``.
    """
    num_multiples = np.ceil(size / N)
    return int(num_multiples * N)
def load_image_as_tensor(image_path, image_size):
    """Load an image file as a floating-point CHW RGB tensor scaled to [0, 1].

    Args:
        image_path: Path to an image readable by ``cv2.imread``.
        image_size: Target size. An ``int`` gives a square image; a tuple is
            passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns:
        A ``torch.Tensor`` of shape ``(C, H, W)`` (C is 3 with ``cv2.imread``'s
        default color mode) with values in [0, 1].

    Raises:
        OSError: If ``cv2.imread`` cannot read/decode the file.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    bgr = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # loudly here rather than with an opaque "'NoneType' is not subscriptable".
    if bgr is None:
        raise OSError(f"Could not read image: {image_path}")
    image = bgr[..., ::-1]  # OpenCV loads BGR; reverse the channel axis to get RGB
    try:
        image = cv2.resize(image, image_size)
    except Exception as e:
        # Best-effort: report the failure and keep the original resolution.
        print(e)
        print(image_path)
    # np.array copies, so even the negatively strided view (if resizing failed)
    # is acceptable to torch.from_numpy; HWC -> CHW, then scale to [0, 1].
    image = torch.from_numpy(np.array(image).transpose(2, 0, 1)) / 255.
    return image
def show_image(image):
    """Display a channel-first image tensor with matplotlib.

    A 4-D input is treated as batched and only its first item is shown.
    The tensor is moved to CPU and permuted from (C, H, W) to (H, W, C)
    before plotting.
    """
    single = image[0] if image.ndim == 4 else image
    hwc = single.permute(1, 2, 0).detach().cpu().numpy()
    plt.imshow(hwc)
    plt.show()
def extract_video(video_path, save_dir, sampling_fps, skip_frames=0):
    """Dump frames of a video to JPEG files at roughly ``sampling_fps``.

    After ignoring the first ``skip_frames`` frames, every ``k``-th frame is
    written to ``save_dir/frameXXXX.jpg``, where
    ``k = max(1, int(video_fps / sampling_fps))``. A file that already exists
    is not overwritten but still consumes its index, so numbering is stable
    across reruns.

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture``.
        save_dir: Output directory (created if missing).
        sampling_fps: Desired sampling rate in frames per second.
        skip_frames: Number of leading frames to ignore.
    """
    os.makedirs(save_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        # Clamp to 1: int() rounds down to 0 when sampling_fps exceeds the
        # video's FPS (or when the FPS property comes back as 0), which would
        # otherwise raise ZeroDivisionError in the modulo below.
        frame_skip = max(1, int(cap.get(cv2.CAP_PROP_FPS) / sampling_fps))
        frame_count = 0
        save_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            past_skip = frame_count >= skip_frames
            if past_skip and (frame_count - skip_frames) % frame_skip == 0:
                save_path = os.path.join(save_dir, f"frame{save_count:04d}.jpg")
                save_count += 1
                if not os.path.exists(save_path):
                    cv2.imwrite(save_path, frame)
            frame_count += 1
    finally:
        # Always free the capture handle, even if reading/writing raises.
        cap.release()
def concatenate_frames_to_video(frame_dir, video_path, fps):
    """Encode the ``frame*`` images in ``frame_dir`` into an ``mp4v`` video.

    Frames are ordered with a natural sort, so ``frame2`` precedes ``frame10``
    even without zero padding. The zero-padded names written by
    ``extract_video`` keep the same order as before. The output resolution is
    taken from the first frame.

    Args:
        frame_dir: Directory containing files whose names start with ``"frame"``.
        video_path: Output video path; its parent directory is created if needed.
        fps: Frame rate of the output video.

    Raises:
        FileNotFoundError: If ``frame_dir`` contains no ``frame*`` files.
    """
    import re

    out_dir = os.path.dirname(video_path)
    if out_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(out_dir, exist_ok=True)

    def _natural_key(name):
        # Split out digit runs so numeric parts compare as integers.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

    frame_files = sorted(
        (f for f in os.listdir(frame_dir) if f.startswith("frame")),
        key=_natural_key,
    )
    if not frame_files:
        raise FileNotFoundError(f"No frame files found in {frame_dir}")

    # The first frame determines the video size.
    first_frame = cv2.imread(os.path.join(frame_dir, frame_files[0]))
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    try:
        for frame_file in frame_files:
            out.write(cv2.imread(os.path.join(frame_dir, frame_file)))
    finally:
        # Release even on failure so the container is finalized and the handle freed.
        out.release()
def cumulative_histogram(hist):
    """Return the running sum of a histogram along its first axis.

    ``out[i] = hist[0] + hist[1] + ... + hist[i]``.

    Args:
        hist: Sequence or array of bin counts (typically 1-D).

    Returns:
        ``np.ndarray`` with the same shape as ``hist``.
    """
    # np.cumsum replaces the hand-written Python prefix-sum loop. Integer
    # dtypes narrower than the platform integer are accumulated in the
    # platform integer, so e.g. uint8 counts no longer wrap around.
    return np.cumsum(np.asarray(hist), axis=0)
def histogram_matching(src_img, ref_img):
    """Match the per-channel YUV histogram of ``src_img`` to that of ``ref_img``.

    Both inputs are RGB float images expected in [0, 1]; they may have
    different sizes. Each YUV channel of the source is remapped through a
    256-entry lookup table so that its cumulative distribution follows the
    reference's. The result is converted back to RGB.

    Args:
        src_img: Source RGB image, float array with values in [0, 1].
        ref_img: Reference RGB image, float array with values in [0, 1].

    Returns:
        The matched RGB image as a float32 array in [0, 1].
    """
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    src_u8 = (np.clip(src_img, 0, 1) * 255).astype(np.uint8)
    ref_u8 = (np.clip(ref_img, 0, 1) * 255).astype(np.uint8)
    src_yuv = cv2.cvtColor(src_u8, cv2.COLOR_RGB2YUV)
    ref_yuv = cv2.cvtColor(ref_u8, cv2.COLOR_RGB2YUV)
    matched = np.zeros_like(src_yuv)
    for channel in range(src_yuv.shape[2]):
        src_hist, _ = np.histogram(src_yuv[:, :, channel].ravel(), 256, (0, 256))
        ref_hist, _ = np.histogram(ref_yuv[:, :, channel].ravel(), 256, (0, 256))
        # Normalize to CDFs in [0, 1]. Raw cumulative counts are not comparable
        # when the two images have different pixel counts.
        src_cdf = np.cumsum(src_hist) / src_hist.sum()
        ref_cdf = np.cumsum(ref_hist) / ref_hist.sum()
        # For each source level, take the first reference level whose CDF is
        # >= the source CDF, capped at 255. This is the vectorized form of a
        # monotone two-pointer scan.
        lut = np.minimum(np.searchsorted(ref_cdf, src_cdf, side="left"), 255).astype(np.uint8)
        matched[:, :, channel] = cv2.LUT(src_yuv[:, :, channel], lut)
    matched = cv2.cvtColor(matched, cv2.COLOR_YUV2RGB)
    return matched.astype(np.float32) / 255
def canny_image_batch(image_batch, low_threshold=100, high_threshold=200):
    """Run Canny edge detection on every image of a batch.

    Two input forms are accepted:

    * ``torch.Tensor`` of shape ``(B, C, H, W)`` with values in [-1, 1]. The
      result is a float tensor ``(B, 3, H, W)`` on the same device, with value
      1.0 on edges and 0.0 elsewhere.
    * A NumPy array whose items are 8-bit images ``cv2.Canny`` accepts, e.g.
      ``(B, H, W)`` or ``(B, H, W, C)`` uint8. The result is a uint8 array
      ``(B, H, W, 3)`` of the raw Canny output.

    Args:
        image_batch: Batch of images as described above.
        low_threshold: First hysteresis threshold for ``cv2.Canny``.
        high_threshold: Second hysteresis threshold for ``cv2.Canny``.

    Returns:
        Edge maps with the single Canny channel replicated three times.
    """
    # Set unconditionally; previously it was only defined on the torch path,
    # so NumPy input crashed with NameError below.
    is_torch = isinstance(image_batch, torch.Tensor)
    if is_torch:
        device = image_batch.device
        # [-1, 1] tensor -> [0, 255] uint8 numpy array in NHWC layout. Clamp
        # first so slightly out-of-range values do not wrap in the uint8 cast.
        image_batch = ((image_batch + 1) * 127.5).clamp(0, 255)
        image_batch = image_batch.permute(0, 2, 3, 1).detach().cpu().numpy()
        image_batch = image_batch.astype(np.uint8)
    edges = np.array([cv2.Canny(image, low_threshold, high_threshold) for image in image_batch])
    # Replicate the single-channel edge map into 3 identical channels.
    edges = np.repeat(edges[:, :, :, None], 3, axis=3)
    if is_torch:
        # [0, 255] numpy array -> [0, 1] float tensor (NCHW) on the input's device.
        edges = torch.from_numpy(edges).permute(0, 3, 1, 2).float() / 255.
        edges = edges.to(device)
    return edges
def images_to_gif(images, filename, fps):
    """Write a sequence of BGR frames to an animated GIF.

    Args:
        images: Iterable of 3-channel image arrays in BGR order. Floating-point
            frames (any float dtype) are treated as [0, 1], clipped, and scaled
            to uint8. Other dtypes are passed through unchanged.
        filename: Output path; its parent directory is created if needed.
        fps: Frames per second; forwarded to imageio as ``duration=1 / fps``.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(out_dir, exist_ok=True)
    frames = []
    for img in images:
        if np.issubdtype(img.dtype, np.floating):
            # Clip so overshoot doesn't wrap in the uint8 cast. Converting all
            # float dtypes also covers float64, which cv2.cvtColor rejects.
            img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
        frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
    # NOTE(review): older imageio GIF writers read `duration` in seconds, while
    # the pillow-based writer in newer releases uses milliseconds. Confirm
    # against the installed imageio version.
    imageio.mimsave(filename, frames, duration=1 / fps)
def load_gif(image_path):
    """Read every frame of an animated image into one NumPy array.

    Only the first three channels of each frame are kept, which drops an
    alpha channel if present. This assumes frames are laid out as
    (H, W, C). TODO confirm for grayscale GIFs.

    Args:
        image_path: Path readable by ``imageio.get_reader``.

    Returns:
        ``np.ndarray`` stacking the frames along a new leading axis.
    """
    gif = imageio.get_reader(image_path)
    try:
        return np.array([frame[..., :3] for frame in gif])
    finally:
        # The reader holds an open file handle; close it once frames are read.
        gif.close()
def add_text_to_frame(frame, text, font_scale=1, thickness=2, color=(0, 0, 0), bg_color=(255, 255, 255), max_width=30):
    """
    Append a caption band below a frame.

    The text is wrapped to at most ``max_width`` characters per line. Each line
    is centered horizontally on a ``bg_color`` band, and the band is stacked
    under the frame. The input frame is not modified.

    Args:
        frame: Image array with 3 channels on its last axis (presumably uint8;
            TODO confirm for float frames).
        text: Caption text.
        font_scale: Font scale passed to OpenCV's text functions.
        thickness: Stroke thickness passed to OpenCV's text functions.
        color: Text color.
        bg_color: Fill color of the caption band.
        max_width: Maximum characters per line for ``textwrap.wrap``.

    Returns:
        A new array: ``frame`` with the caption band appended along axis 0.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    lines = textwrap.wrap(text, width=max_width)
    # Measure each line once: ((width, height), baseline) per line.
    sizes = [cv2.getTextSize(line, font, font_scale, thickness) for line in lines]

    # Original heuristic band height. It is kept as a lower bound so captions
    # that already fit render exactly as before.
    estimated_height = len(lines) * (thickness * font_scale + 10) + 60 * font_scale
    # Height actually consumed by the layout below. Each line advances y by its
    # text height + 10, and the last line's descenders extend `baseline` px
    # below its origin. The heuristic underestimates this for multi-line
    # captions, which used to clip the bottom lines.
    required_height = sum(text_h + 10 for (_, text_h), _ in sizes)
    if sizes:
        required_height += sizes[-1][1] + 10
    band_height = int(max(estimated_height, required_height))

    text_bg_img = np.full((band_height, frame.shape[1], 3), bg_color, dtype=np.uint8)
    y = 0
    for line, ((text_w, text_h), _) in zip(lines, sizes):
        text_x = (text_bg_img.shape[1] - text_w) // 2
        y += text_h + 10
        cv2.putText(text_bg_img, line, (text_x, y), font, font_scale, color, thickness)
    # np.vstack allocates a new array, so the caller's frame is untouched.
    return np.vstack((frame, text_bg_img))
def add_text_to_gif(numpy_images, text, **kwargs):
    """
    Caption every frame of a GIF with the same text.

    Args:
        numpy_images: Iterable of frames.
        text: Caption text applied to each frame.
        **kwargs: Forwarded unchanged to ``add_text_to_frame``.

    Returns:
        ``np.ndarray`` built from the captioned frames, in input order.
    """
    captioned = [add_text_to_frame(frame, text, **kwargs) for frame in numpy_images]
    return np.array(captioned)
def pad_images_to_same_height(images):
    """
    Bottom-pad images with a constant 255 border so all match the tallest one.

    Args:
        images: Sequence of image arrays whose height is axis 0.

    Returns:
        A new list of padded images in input order; widths are unchanged.
    """
    target_h = max(im.shape[0] for im in images)
    fill_value = [255, 255, 255]
    return [
        cv2.copyMakeBorder(im, 0, target_h - im.shape[0], 0, 0, cv2.BORDER_CONSTANT, value=fill_value)
        for im in images
    ]
def concatenate_gifs(gifs):
    """
    Place several GIFs side by side, frame by frame.

    Only the first ``min(len)`` frames of each GIF are used. For every frame
    index, the frames are padded to a common height with
    ``pad_images_to_same_height`` and concatenated along axis 1 (width).

    Args:
        gifs: Sequence of frame arrays, with frames on axis 0.

    Returns:
        ``np.ndarray`` of the combined frames.
    """
    n_frames = min(g.shape[0] for g in gifs)
    out_frames = []
    for idx in range(n_frames):
        row = pad_images_to_same_height([g[idx] for g in gifs])
        out_frames.append(np.concatenate(row, axis=1))
    return np.array(out_frames)
def stack_gifs(gifs):
    """Vertically stack GIFs frame by frame.

    Only the first ``min(len)`` frames of each GIF are used. Output frame
    ``i`` is frame ``i`` of every GIF concatenated along axis 0 (height), so
    all frames must share their remaining dimensions.

    Args:
        gifs: Sequence of frame arrays, with frames on axis 0.

    Returns:
        ``np.ndarray`` of the stacked frames.
    """
    n_frames = min(g.shape[0] for g in gifs)
    rows = [np.concatenate([g[idx] for g in gifs], axis=0) for idx in range(n_frames)]
    return np.array(rows)
def save_tensor_to_gif(images, filename, fps):
    """Save a video tensor with values in [-1, 1] as a GIF.

    The tensor is squeezed at dim 0, which drops a size-1 batch dimension. It
    is then permuted with ``(0, 2, 3, 1)``, presumably ``(T, C, H, W)`` to
    ``(T, H, W, C)``, rescaled to [0, 1], and handed to ``images_to_gif``.

    NOTE(review): ``images_to_gif`` treats frames as BGR. If this tensor is
    RGB, red and blue end up swapped in the GIF; confirm the channel order.
    """
    frames = images.squeeze(0).detach().cpu().numpy()
    frames = frames.transpose(0, 2, 3, 1)
    frames = frames / 2 + 0.5
    images_to_gif(frames, filename, fps)
def save_tensor_to_images(images, output_dir):
    """Save each frame of a [-1, 1] video tensor as a numbered JPEG.

    Files are written as ``{output_dir}/000.jpg``, ``001.jpg``, and so on.

    Args:
        images: Tensor that is 4-D after ``squeeze(0)``. Its axes are permuted
            with ``(0, 2, 3, 1)``, presumably ``(T, C, H, W)`` to ``(T, H, W, C)``.
        output_dir: Destination directory (created if missing).
    """
    frames = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5
    # plt.imsave raises ValueError for float RGB data outside [0, 1]; clip so
    # small overshoot in model outputs doesn't abort the whole save.
    frames = np.clip(frames, 0, 1)
    os.makedirs(output_dir, exist_ok=True)
    for i, frame in enumerate(frames):
        plt.imsave(f'{output_dir}/{i:03d}.jpg', frame)