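# Gradio demo app for StyleCrafter: stylized image and video generation that
# combines the VideoCrafter text-to-video base model with a style adapter
# conditioned on a reference image.
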
import gradio as gr
import os
import sys
import argparse
import random
import time
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from einops import repeat
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from utils.utils import instantiate_from_config
from collections import OrderedDict

sys.path.insert(0, "scripts/evaluation")
from lvdm.models.samplers.ddim import DDIMSampler, DDIMStyleSampler
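

# Load the VideoCrafter base weights into the model. Handles both standard
# Lightning checkpoints (weights under 'state_dict') and DeepSpeed checkpoints
# (weights stored under 'module' with a prefixed key).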
def load_model_checkpoint(model, ckpt):
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in list(state_dict.keys()):
        state_dict = state_dict["state_dict"]
    else:
        # deepspeed checkpoint: collect weights into a fresh dict,
        # stripping the 16-character module prefix from each key
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        state_dict = new_pl_sd
    model.load_state_dict(state_dict, strict=False)
    print('>>> model checkpoint loaded.')
    return model
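

# Download the VideoCrafter Text2Video-512 base checkpoint and the StyleCrafter
# adapter / temporal weights from the Hugging Face Hub if they are not already
# cached locally.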
def download_model():
    REPO_ID = 'VideoCrafter/Text2Video-512'
    filename_list = ['model.ckpt']
    os.makedirs('./checkpoints/videocrafter_t2v_320_512/', exist_ok=True)
    for filename in filename_list:
        local_file = os.path.join('./checkpoints/videocrafter_t2v_320_512/', filename)
        if not os.path.exists(local_file):
            hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/videocrafter_t2v_320_512/', force_download=True)

    REPO_ID = 'liuhuohuo/StyleCrafter'
    filename_list = ['adapter_v1.pth', 'temporal_v1.pth']
    os.makedirs('./checkpoints/stylecrafter', exist_ok=True)
    for filename in filename_list:
        local_file = os.path.join('./checkpoints/stylecrafter', filename)
        if not os.path.exists(local_file):
            hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/stylecrafter', force_download=True)
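

# Main inference entry point wired to the Gradio UI: builds the model for the
# requested mode ('image' or 'video'), applies the style reference, runs DDIM
# sampling, and writes the result to ./output.png or ./output.mp4.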
def infer(image, prompt, infer_type='image', seed=123, style_strength=1.0, steps=50):
    download_model()
    ckpt_path = 'checkpoints/videocrafter_t2v_320_512/model.ckpt'
    adapter_ckpt_path = 'checkpoints/stylecrafter/adapter_v1.pth'
    temporal_ckpt_path = 'checkpoints/stylecrafter/temporal_v1.pth'

    if infer_type == 'image':
        config_file = 'configs/inference_image_512_512.yaml'
        h, w = 512 // 8, 512 // 8
        unconditional_guidance_scale = 7.5
        unconditional_guidance_scale_style = None
    else:
        config_file = 'configs/inference_video_320_512.yaml'
        h, w = 320 // 8, 512 // 8
        unconditional_guidance_scale = 15.0
        unconditional_guidance_scale_style = 7.5
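
    # Build the model from the mode-specific config; the UI's style strength is
    # written into the adapter config's scale parameter before instantiation.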
    config = OmegaConf.load(config_file)
    model_config = config.pop("model", OmegaConf.create())
    model_config['params']['adapter_config']['params']['scale'] = style_strength
    model = instantiate_from_config(model_config)
    model = model.cuda()

    # load ckpt
    assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
    assert os.path.exists(adapter_ckpt_path), "Error: adapter checkpoint Not Found!"
    assert os.path.exists(temporal_ckpt_path), "Error: temporal checkpoint Not Found!"

    model = load_model_checkpoint(model, ckpt_path)
    model.load_pretrained_adapter(adapter_ckpt_path)
    if infer_type == 'video':
        model.load_pretrained_temporal(temporal_ckpt_path)
    model.eval()

    seed_everything(seed)

    batch_size = 1
    channels = model.channels
    frames = model.temporal_length if infer_type == 'video' else 1
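    # Latent noise shape: [batch, latent channels, frames, H/8, W/8];
    # a single frame is sampled in image mode.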
    noise_shape = [batch_size, channels, frames, h, w]

    # text cond
    cond = model.get_learned_conditioning([prompt])
    neg_prompt = batch_size * [""]
    uc = model.get_learned_conditioning(neg_prompt)

    # style cond
    style_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(512),
        torchvision.transforms.CenterCrop(512),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x * 2. - 1.),
    ])
    style_img = style_transforms(image).unsqueeze(0).cuda()
    style_cond = model.get_batch_style(style_img)
    append_to_context = model.adapter(style_cond)
    scale_scalar = model.adapter.scale_predictor(torch.concat([append_to_context, cond], dim=1))
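
    # The adapter turns the encoded style reference into extra context tokens
    # appended to the text conditioning; scale_predictor estimates how strongly
    # the style tokens should be fused with the text context. Image mode uses
    # the plain DDIM sampler; video mode uses the style-aware sampler with a
    # separate style guidance scale.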
    ddim_sampler = DDIMSampler(model) if infer_type == 'image' else DDIMStyleSampler(model)
    samples, _ = ddim_sampler.sample(S=steps,
                                     conditioning=cond,
                                     batch_size=noise_shape[0],
                                     shape=noise_shape[1:],
                                     verbose=False,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_guidance_scale_style=unconditional_guidance_scale_style,
                                     unconditional_conditioning=uc,
                                     eta=1.0,
                                     temporal_length=noise_shape[2],
                                     append_to_context=append_to_context,
                                     scale_scalar=scale_scalar
                                     )
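
    # Decode the sampled latents back to pixel space with the first-stage autoencoder.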
    samples = model.decode_first_stage(samples)

    if infer_type == 'image':
        samples = samples[:, :, 0, :, :].detach().cpu()
        out_path = "./output.png"
        torchvision.utils.save_image(samples, out_path, nrow=1, normalize=True, range=(-1, 1))
    elif infer_type == 'video':
        samples = samples.detach().cpu()
        out_path = "./output.mp4"
        video = torch.clamp(samples, -1, 1)
        video = video.permute(2, 0, 1, 3, 4)  # [T, B, C, H, W]
        frame_grids = [torchvision.utils.make_grid(video[t], nrow=1) for t in range(video.shape[0])]
        grid = torch.stack(frame_grids, dim=0)
        grid = (grid + 1.0) / 2.0
        grid = (grid * 255).permute(0, 2, 3, 1).numpy().astype('uint8')
        torchvision.io.write_video(out_path, grid, fps=8, video_codec='h264', options={'crf': '10'})

    return out_path
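

# A hypothetical direct call, bypassing the Gradio UI (assumes a PIL image,
# which the torchvision transforms above accept):
#   from PIL import Image
#   out = infer(Image.open('eval_data/anime_1.jpg'),
#               'A field of sunflowers on a sunny day.',
#               infer_type='video', seed=123, style_strength=1.0, steps=50)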
def read_content(file_path: str) -> str:
    """Read the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
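

# Example rows for the UI: [style reference image, text prompt, generation type,
# random seed, style strength, sampling steps].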
demo_examples = [
    ['eval_data/3d_1.png', 'A bouquet of flowers in a vase.', 'image', 123, 1.0, 50],
    ['eval_data/craft_1.png', 'A modern cityscape with towering skyscrapers.', 'image', 124, 1.0, 50],
    ['eval_data/digital_art_2.jpeg', 'A lighthouse standing tall on a rocky coast.', 'image', 123, 1.0, 50],
    ['eval_data/oil_paint_2.jpg', 'A man playing the guitar on a city street.', 'image', 123, 1.0, 50],
    ['eval_data/craft_2.jpg', 'City street at night with bright lights and busy traffic.', 'video', 123, 1.0, 50],
    ['eval_data/anime_1.jpg', 'A field of sunflowers on a sunny day.', 'video', 123, 1.0, 50],
    ['eval_data/ink_2.jpeg', 'A knight riding a horse through a field.', 'video', 123, 1.0, 50],
    ['eval_data/oil_paint_2.jpg', 'A street performer playing the guitar.', 'video', 121, 1.0, 50],
    ['eval_data/icon_1.png', 'A campfire surrounded by tents.', 'video', 123, 1.0, 50],
]

css = """
#input_img {max-height: 512px}
#output_vid {max-width: 512px;}
"""
with gr.Blocks(analytics_enabled=False, css=css) as demo_iface:
    gr.HTML(read_content("header.html"))

    with gr.Tab(label='Stylized Generation'):
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_style_ref = gr.Image(label="Style Reference", elem_id="input_img")
                    with gr.Row():
                        input_prompt = gr.Text(label='Prompts')
                    with gr.Row():
                        input_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
                        input_style_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, label='Style Strength', value=1.0)
                    with gr.Row():
                        input_step = gr.Slider(minimum=1, maximum=75, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
                        input_type = gr.Radio(choices=["image", "video"], label="Generation Type", value="image")
                    input_end_btn = gr.Button("Generate")
                # with gr.Tab(label='Result'):
                with gr.Row():
                    output_result = gr.Video(label="Generated Results", elem_id="output_vid", autoplay=True, show_share_button=True)

            gr.Examples(examples=demo_examples,
                        inputs=[input_style_ref, input_prompt, input_type, input_seed, input_style_strength, input_step],
                        outputs=[output_result],
                        fn=infer,
                        )

    input_end_btn.click(inputs=[input_style_ref, input_prompt, input_type, input_seed, input_style_strength, input_step],
                        outputs=[output_result],
                        fn=infer,
                        )
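
# Queue incoming requests (at most 12 pending) and launch the demo with the
# auto-generated API page enabled.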
demo_iface.queue(max_size=12).launch(show_api=True)