# Spaces: Running on Zero
import cv2 | |
import torch | |
import argparse | |
import numpy as np | |
import os | |
from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel | |
from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel | |
from pipeline_cogvideox_controlnet_5b_i2v_instruction2 import ControlCogVideoXPipeline | |
from diffusers.utils import export_to_video | |
from diffusers import AutoencoderKLCogVideoX | |
from transformers import T5EncoderModel, T5Tokenizer | |
from diffusers.schedulers import CogVideoXDDIMScheduler | |
from safetensors.torch import load_file | |
from omegaconf import OmegaConf | |
from transformers import T5EncoderModel | |
from einops import rearrange | |
from decord import VideoReader | |
import transformers | |
from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer | |
from PIL import Image | |
import torch.nn.functional as F | |
from dataset_demo_videos import VideoDataset | |
def unwarp_model(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    PyTorch's (Distributed)DataParallel wrappers save parameters under a
    ``module.`` prefix; stripping it lets the weights load into the bare model.

    Only a *leading* ``module.`` is removed, so nested names such as
    ``enc.module.w`` or ``x.submodule.w`` are kept intact. Keys that carry no
    prefix are passed through unchanged rather than raising; a strict
    ``load_state_dict`` will still report any genuine mismatch.

    Args:
        state_dict: Mapping from parameter name to tensor (or any value).

    Returns:
        A new dict with the same values and insertion order.
    """
    ddp_prefix = "module."
    return {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in state_dict.items()
    }
""" | |
def transform_tensor_to_images(images): | |
images = images.cpu().detach().numpy() | |
images = np.uint8(images) | |
images2 = [] | |
for image in images: | |
image = Image.fromarray(image) | |
images2.append(image) | |
return images2 | |
""" | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--pos_prompt", type=str, default="") | |
parser.add_argument("--neg_prompt", type=str, default="") | |
parser.add_argument("--training_steps", type=int, default=30001) | |
parser.add_argument("--root_path", type=str, default="./models_half") | |
parser.add_argument("--i2v", action="store_true",default=True) | |
parser.add_argument("--guidance_scale", type=float, default=4.0) | |
parser.add_argument("--random_seed", type=int, default=0) | |
args = parser.parse_args() | |
# Filesystem-friendly tag made from the checkpoint root and the positive prompt.
# NOTE(review): inference() takes its own `prefix` argument and the main loop
# rebinds this name for every sample, so this value does not reach output paths.
root_tag = args.root_path.replace("/", "_").replace(".", "_")
prompt_tag = args.pos_prompt.replace(" ", "_").replace(".", "_")
prefix = f"{root_tag}_{prompt_tag}"
# Which local pretrained folder to read: ./cogvideox-5b-i2v or ./cogvideox-5b-t2v.
key = "i2v" if args.i2v else "t2v"
# Every pretrained component below is read from the local ./cogvideox-5b-{key}/ folder.
noise_scheduler = CogVideoXDDIMScheduler(**OmegaConf.to_container(OmegaConf.load(f"./cogvideox-5b-{key}/scheduler/scheduler_config.json")))
# The T5 text encoder is left on the CPU here; the VAE goes to cuda:0 straight away.
text_encoder = T5EncoderModel.from_pretrained(
    f"./cogvideox-5b-{key}/",
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKLCogVideoX.from_pretrained(
    f"./cogvideox-5b-{key}/",
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda:0")
tokenizer = T5Tokenizer.from_pretrained(
    f"./cogvideox-5b-{key}/tokenizer",
    torch_dtype=torch.float16,
)
# The base transformer is built from its config only; no pretrained weights are
# loaded into it at this point.
config = OmegaConf.to_container(OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json"))
# i2v input is 32 channels wide (presumably noisy latents concatenated with the
# image-condition latents); t2v uses 16.
config["in_channels"] = 32 if args.i2v else 16
transformer = CogVideoXTransformer3DModel(**config)
# The ControlNet branch starts from the same transformer config file, then
# shrinks to 6 layers and takes a 16-channel control input.
control_config = OmegaConf.to_container(OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json"))
control_config["in_channels"] = 32 if args.i2v else 16
control_config['num_layers'] = 6
control_config['control_in_channels'] = 16
controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)
# Load the fine-tuned weights for both networks from a single checkpoint under
# --root_path. The path must be an f-string; previously the literal text
# "{args.root_path}/..." was used, so the file could never be found.
# weights_only=True restricts unpickling to tensors and plain containers.
all_state_dicts = torch.load(
    f"{args.root_path}/ff_controlnet_half.pth",
    map_location="cpu",
    weights_only=True,
)
# Both sub-dicts are saved with a wrapper "module." prefix that must be removed.
transformer_state_dict = unwarp_model(all_state_dicts["transformer_state_dict"])
controlnet_transformer_state_dict = unwarp_model(all_state_dicts["controlnet_transformer_state_dict"])
transformer.load_state_dict(transformer_state_dict, strict=True)
controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
# load_state_dict copies into the models' own parameters, so the checkpoint
# tensors are no longer needed; drop them instead of keeping a second full
# copy of the weights alive in CPU memory for the whole run.
del all_state_dicts, transformer_state_dict, controlnet_transformer_state_dict

# Cast to float16 and move to the GPU, then switch every model to eval mode.
transformer = transformer.half().to("cuda:0")
controlnet_transformer = controlnet_transformer.half().to("cuda:0")
vae = vae.eval()
text_encoder = text_encoder.eval()
transformer = transformer.eval()
controlnet_transformer = controlnet_transformer.eval()
# Assemble the project-local ControlNet pipeline. Components are passed
# positionally in the order its constructor expects.
pipe = ControlCogVideoXPipeline(
    tokenizer, text_encoder, vae, transformer, noise_scheduler, controlnet_transformer
)
# diffusers memory-saving options: VAE slicing/tiling plus per-model CPU offload.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
def _encode_to_latents(vae, pixel_values):
    """Encode pixels with ``vae`` and return scaled float16 latents as (b, f, c, h, w).

    ``pixel_values`` must already be channels-first, i.e. (b, c, f, h, w). A
    latent is sampled from the VAE posterior, cast to float16 and multiplied by
    ``vae.config.scaling_factor``.
    """
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents.to(torch.float16) * vae.config.scaling_factor
    return rearrange(latents, "b c f h w -> b f c h w")


def _tensor_to_pil_frames(images):
    """Convert a stack of frames (0-255 values) into a list of PIL images.

    Values are cast with ``np.uint8``; each frame is assumed to be an
    (h, w, c) array that ``Image.fromarray`` accepts (TODO confirm layout).
    """
    frames = np.uint8(images.cpu().detach().numpy())
    return [Image.fromarray(frame) for frame in frames]


def inference(prefix, source_images,
              target_images,
              text_prompt, negative_prompt,
              pipe, vae,
              step, guidance_scale,
              target_path, video_dir,
              h, w, random_seed):
    """Generate one video conditioned on a control video and write it to disk.

    Args:
        prefix: Text embedded in the output directory name.
        source_images: Control video tensor with 0-255 values, presumably
            (b, f, h, w, c) (TODO confirm against VideoDataset). It is mapped
            to [-1, 1] float16 on cuda:0, VAE-encoded and fed to the
            ControlNet as ``video_condition``.
        target_images: Reference tensor in the same layout, or None. Only its
            first frame is encoded; that latent is zero-padded to the source
            latent length and passed as ``video_condition2``.
        text_prompt, negative_prompt: Prompts forwarded to ``pipe``.
        pipe: The ControlNet pipeline.
        vae: VAE used to encode the conditions.
        step, guidance_scale: Used in the output path; ``guidance_scale`` is
            also forwarded to ``pipe``.
        target_path: Output root directory.
        video_dir: Unused; kept for call compatibility.
        h, w: Ignored. Height/width are taken from ``source_images`` instead,
            matching the original behaviour.
        random_seed: Seed for the cuda:0 generator; also part of the file names.

    Side effects:
        Creates ``./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}``
        and writes ``output_{random_seed}.mp4`` (generated) and
        ``output_{random_seed}_org.mp4`` (source frames), both at 8 fps.
    """
    source_pixel_values = (source_images / 127.5 - 1.0).to(torch.float16).to("cuda:0")
    # The generation size comes from the frames themselves; the h/w arguments are not used.
    _, num_frames, height, width, _ = source_pixel_values.shape

    # Everything below is pure inference, so gradient tracking is disabled for
    # the pipeline call too, not just for the VAE encodes.
    with torch.no_grad():
        source_latents = _encode_to_latents(
            vae, rearrange(source_pixel_values, "b f h w c -> b c f h w")
        )
        if target_images is not None:
            target_pixel_values = (target_images / 127.5 - 1.0).to(torch.float16).to("cuda:0")
            first_frame = rearrange(target_pixel_values, "b f h w c -> b c f h w")[:, :, :1]
            image_latents = _encode_to_latents(vae, first_frame)
            # Pad with zero latents so the condition spans as many latent frames as the source.
            image_latents = torch.cat(
                [image_latents, torch.zeros_like(source_latents[:, 1:])], dim=1
            )
        else:
            image_latents = None

        video = pipe(
            prompt=text_prompt,
            negative_prompt=negative_prompt,
            video_condition=source_latents,  # input to controlnet
            video_condition2=image_latents,  # concat with latents
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=50,
            interval=6,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda:0").manual_seed(random_seed),
        ).frames[0]

    out_dir = f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}"
    os.makedirs(out_dir, exist_ok=True)
    export_to_video(video, f"{out_dir}/output_{random_seed}.mp4", fps=8)
    export_to_video(
        _tensor_to_pil_frames(source_images[0]),
        f"{out_dir}/output_{random_seed}_org.mp4",
        fps=8,
    )
def read_video(video_path, h, w, max_frames=33):
    """Decode the leading frames of a video and resize each to height ``h``, width ``w``.

    Args:
        video_path: Path readable by decord's ``VideoReader``.
        h: Target frame height in pixels.
        w: Target frame width in pixels.
        max_frames: Maximum number of leading frames to read. Defaults to 33,
            the previously hard-coded cap.

    Returns:
        A torch tensor of shape (frames, h, w, channels), keeping the dtype
        decord returned (presumably uint8 — confirm).
    """
    vr = VideoReader(video_path)
    frames = vr.get_batch(list(range(min(max_frames, len(vr))))).asnumpy()
    # cv2.resize takes dsize as (width, height). The previous (h, w) argument
    # produced frames that were w rows tall and h columns wide.
    resized = np.array([cv2.resize(frame, (w, h)) for frame in frames])
    del vr
    return torch.from_numpy(resized)
def resize(images, h, w):
    """Bilinearly resize a stack of channels-last frames and prepend a batch axis.

    Args:
        images: 4-D tensor laid out (frames, dim1, dim2, channels), in a dtype
            ``F.interpolate`` supports for bilinear mode (presumably floating
            point — confirm for integer inputs).
        h: Target size of dim1.
        w: Target size of dim2.

    Returns:
        Tensor of shape (1, frames, h, w, channels).
    """
    channels_first = images.permute(0, 3, 1, 2)
    scaled = F.interpolate(channels_first, (h, w), mode="bilinear")
    return scaled.permute(0, 2, 3, 1).unsqueeze(0)
# Nominal output size handed to inference() (which derives height/width from
# the frame tensor itself).
h, w = 448, 768
root_dir = 'additional_videos8'
dataset = VideoDataset(root_dir)
print(len(dataset))

for step, sample in enumerate(dataset):
    pos_prompt, neg_prompt = sample['pos_prompt'], sample['neg_prompt']
    image_path = sample['image_path']
    # Add a batch axis to the control frames, and batch + frame axes to the
    # single reference image.
    source_images = sample['frames'][None, ...]
    target_images = sample['image'][None, None, ...]
    # Each sample gets its own output folder named after its image path.
    prefix = image_path.replace("/", "_")
    target_path = f"demo_first_frame_controlnet_33_stride_2_new_videos_8/{prefix}/"

    print(pos_prompt, neg_prompt)
    for batch in (source_images, target_images):
        print(batch.shape, torch.min(batch), torch.max(batch))

    for random_seed in [args.random_seed]:
        inference(
            "", source_images, target_images, pos_prompt, neg_prompt,
            pipe, vae, args.training_steps, args.guidance_scale,
            target_path, "", h, w, random_seed,
        )