|
import spaces |
|
import gradio as gr |
|
import torch |
|
from diffusers import LTXPipeline |
|
import uuid |
|
import time |
|
import types |
|
from typing import Optional |
|
|
|
# Load the LTX-Video pipeline in bfloat16 to halve weight memory vs. fp32.
pipe = LTXPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16
)

# Sequential CPU offload keeps only the submodule currently executing on the
# GPU; the accelerate hooks manage device placement themselves.
# BUGFIX: the original also called `pipe.to("cuda")` afterwards — diffusers
# raises a ValueError when moving a pipeline to GPU after sequential offload
# has been enabled, so that call must be removed, not reordered.
pipe.enable_sequential_cpu_offload()

# Output geometry. LTX latents are downsampled 32x spatially and 8x
# temporally, so HEIGHT/WIDTH should be multiples of 32 and N_FRAME of the
# form 8k + 1 (161 -> 21 latent frames).
HEIGHT = 480
WIDTH = 640
N_FRAME = 161
N_AVG_FRAME = 2  # NOTE(review): unused in this file — confirm before removing

# Shared negative prompt applied to every generation.
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

# Handle to the original prepare_latents, kept in case it gets monkey-patched
# later; not referenced elsewhere in this file.
prepare_latents_original = pipe.prepare_latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
    """Step-end callback: circularly shift the latent frames by one.

    Unpacks the flat latent sequence back to video shape, rolls the frame
    axis by one position, and re-packs so the next denoising step sees the
    rotated latents.  Latent dims come from the module constants: frame
    count is (N_FRAME - 1) // 8 + 1 and the spatial dims are HEIGHT // 32
    and WIDTH // 32 (VAE compression factors).
    """
    print("Rolling latents on step", step)
    packed = callback_kwargs.get("latents")
    latent_frames = (N_FRAME - 1) // 8 + 1
    video_latents = pipeline._unpack_latents(
        packed, latent_frames, HEIGHT // 32, WIDTH // 32, 1, 1
    )
    # dims=2 is the frame axis; shifts=1 rotates the sequence by one frame.
    rolled = torch.roll(video_latents, shifts=1, dims=2)
    return {"latents": pipeline._pack_latents(rolled)}
|
|
|
@spaces.GPU(duration=120)
def generate_gif(prompt, use_fixed_seed):
    """Generate a looping animation from a text prompt.

    Args:
        prompt: text prompt forwarded to the LTX-Video pipeline.
        use_fixed_seed: when True, seed 0 is used for reproducible output;
            otherwise a fresh random seed is drawn via torch.seed().

    Returns:
        Filesystem path of the saved animation.  NOTE: despite the "gif"
        naming, the file is encoded as animated WebP.
    """
    seed = 0 if use_fixed_seed else torch.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # .frames[0] is the list of PIL images for the first (only) video.
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=WIDTH,
        height=HEIGHT,
        num_frames=N_FRAME,
        num_inference_steps=50,
        decode_timestep=0.03,
        decode_noise_scale=0.025,
        generator=generator,
        callback_on_step_end=modify_latents_callback,
    ).frames[0]

    bef = time.time()

    # BUGFIX: the original first built a ".gif" path and immediately
    # overwrote it with this ".webp" path — the dead assignment is removed.
    # Unique temp path per request so concurrent runs don't clash.
    gif_path = f"/tmp/{uuid.uuid4().hex}.webp"
    # duration is per-frame display time in ms (~24 fps); loop=0 loops forever.
    output[0].save(
        gif_path,
        format="WebP",
        save_all=True,
        append_images=output[1:],
        duration=1000 / 24,
        loop=0,
    )
    print("GIF creation time:", time.time() - bef)
    return gif_path
|
|
|
# --- Gradio front end --------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## LTX Video → Looping GIF Generator")

    with gr.Row():
        with gr.Column():
            # Left column: user inputs.
            prompt_box = gr.Textbox(label="Prompt", lines=4)
            fixed_seed_toggle = gr.Checkbox(label="Use Fixed Seed", value=True)
            run_button = gr.Button("Generate")
        with gr.Column():
            # Right column: resulting animation, served from a temp filepath.
            result_image = gr.Image(label="Looping GIF Result", type="filepath")

    # One generation at a time: each run holds the GPU for up to 120 s.
    run_button.click(
        fn=generate_gif,
        inputs=[prompt_box, fixed_seed_toggle],
        outputs=result_image,
        concurrency_limit=1,
    )

# Small queue so waiting users hold a position instead of timing out.
demo.queue(max_size=5)
demo.launch(share=True)
|
|