# looop / app.py
import spaces
import gradio as gr
import torch
from diffusers import LTXPipeline
import uuid
import time
# Load the LTX-Video weights in bfloat16 to halve memory versus fp32.
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
# Sequential CPU offload streams each submodule to the GPU only while it runs;
# do not also call pipe.to("cuda") afterwards, since diffusers rejects .to()
# on a pipeline with sequential offload enabled.
pipe.enable_sequential_cpu_offload()
# LTX-Video's VAE compresses 32x spatially and 8x temporally, so the latent
# grid is HEIGHT//32 x WIDTH//32 with (N_FRAME - 1)//8 + 1 frames (161 -> 21).
HEIGHT = 480
WIDTH = 640
N_FRAME = 161
N_AVG_FRAME = 2  # trailing latent frames to blend in the disabled smoothing hack below

negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

prepare_latents_original = pipe.prepare_latents  # kept for the disabled prepare_latents_loop override below
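# Worked shape check (a reader-facing sketch; nothing here executes): with
# patch sizes of 1, packing flattens (B, C, F, H, W) into (B, F*H*W, C), so
# the packed sequence length here works out to
#   ((161 - 1)//8 + 1) * (480//32) * (640//32) = 21 * 15 * 20 = 6300 tokens.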
# Disabled experiment: blend the last N_AVG_FRAME latent frames with their
# cyclic successors at initialization to soften the seam between the clip's
# end and start. Unpacked latents have shape (B, C, F, H, W), where F, H, W
# are in latent units.
# def prepare_latents_loop(*args, **kwargs):
#     packed_latents = prepare_latents_original(*args, **kwargs)
#     unpacked_latents = pipe._unpack_latents(packed_latents, (N_FRAME - 1) // 8 + 1, HEIGHT // 32, WIDTH // 32, 1, 1)
#     # Average each of the last N_AVG_FRAME frames with the frame that follows
#     # it cyclically. With frames 0..4 and N_AVG_FRAME = 2: frame 3 becomes
#     # 0.75*3 + 0.25*4, and frame 4 becomes 0.75*4 + 0.25*0.
#     last_n = unpacked_latents[:, :, -N_AVG_FRAME:, :, :]
#     last_next_n = torch.roll(unpacked_latents, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
#     avg_n = last_n * 0.75 + last_next_n * 0.25
#     unpacked_latents[:, :, -N_AVG_FRAME:, :, :] = avg_n
#     # Pack the latents back into sequence form.
#     packed_latents = pipe._pack_latents(unpacked_latents)
#     return packed_latents
# pipe.prepare_latents = prepare_latents_loop
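# Toy illustration (not part of the app) of the cyclic roll the callback below
# applies, using a 1-D stand-in for the temporal axis:
#   x = torch.arange(4)        # tensor([0, 1, 2, 3])
#   torch.roll(x, shifts=1)    # tensor([3, 0, 1, 2])
#   torch.roll(x, shifts=2)    # tensor([2, 3, 0, 1])
# Rolling by 1 per step returns to the original order after len(x) steps, so
# with 21 latent frames the ordering repeats every 21 denoising steps.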
# Rolling the latents by one frame along the temporal axis at every denoising
# step cycles the frame order (0,1,2,3 -> 3,0,1,2 -> 2,3,0,1 -> ...), so over
# (N_FRAME - 1)//8 + 1 = 21 steps every latent frame is denoised at every
# temporal position. No frame is privileged as the "start", which is what
# makes the decoded clip loop seamlessly; rolls beyond one full cycle merely
# rotate the (already seamless) loop.
def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
    print("Rolling latents on step", step)
    latents = callback_kwargs.get("latents")
    # Unpack from (B, seq, dim) to (B, C, F, H, W), roll the F axis, repack.
    unpacked_latents = pipeline._unpack_latents(latents, (N_FRAME - 1) // 8 + 1, HEIGHT // 32, WIDTH // 32, 1, 1)
    modified_latents = torch.roll(unpacked_latents, shifts=1, dims=2)
    modified_latents = pipeline._pack_latents(modified_latents)
    return {"latents": modified_latents}
@spaces.GPU(duration=120)
def generate_gif(prompt, use_fixed_seed):
    # A fixed seed makes results reproducible; otherwise draw a fresh one.
    seed = 0 if use_fixed_seed else torch.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=WIDTH,
        height=HEIGHT,
        num_frames=N_FRAME,
        num_inference_steps=50,
        decode_timestep=0.03,
        decode_noise_scale=0.025,
        generator=generator,
        callback_on_step_end=modify_latents_callback,
    ).frames[0]
    # Save as an animated WebP; loop=0 repeats forever, duration is ms per frame (~24 fps).
    out_path = f"/tmp/{uuid.uuid4().hex}.webp"
    bef = time.time()
    output[0].save(out_path, format="WebP", save_all=True, append_images=output[1:], duration=1000 // 24, loop=0)
    print("Animation save time:", time.time() - bef)
    return out_path
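# Alternative save path kept from an earlier revision's commented-out line
# (a sketch; assumes imageio and numpy are installed). Inside generate_gif one
# could write a real GIF instead of WebP:
# import imageio
# import numpy as np
# imageio.mimsave(out_path, [np.asarray(frame) for frame in output], format="GIF", fps=24, loop=0)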
with gr.Blocks() as demo:
    gr.Markdown("## LTX Video → Looping GIF Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=4)
            use_fixed_seed = gr.Checkbox(label="Use Fixed Seed", value=True)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            gif_output = gr.Image(label="Looping GIF Result", type="filepath")
    generate_btn.click(
        fn=generate_gif,
        inputs=[prompt_input, use_fixed_seed],
        outputs=gif_output,
        concurrency_limit=1,
    )

demo.queue(max_size=5)
demo.launch(share=True)