File size: 3,586 Bytes
15c34d2
88f09e9
15c34d2
 
 
 
 
 
88f09e9
15c34d2
f4d0dd9
15c34d2
88f09e9
15c34d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import spaces
import gradio as gr
import torch
from diffusers import LTXPipeline
import uuid
import time
import types
from typing import Optional

# Load the LTX-Video 0.9.1 diffusers checkpoint with bfloat16 weights.
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
# Sequential CPU offload manages device placement itself (weights are moved to
# the GPU on demand during the forward pass). The pipeline must therefore NOT
# also be moved with `pipe.to("cuda")`: diffusers rejects moving a sequentially
# offloaded pipeline to CUDA and recommends dropping the manual move.
pipe.enable_sequential_cpu_offload()


# pipe.vae.decode = vae_decode

# Output resolution passed to the pipeline as `height` / `width`; the latent
# callback also derives the latent grid from these as HEIGHT//32 and WIDTH//32.
HEIGHT = 480
WIDTH = 640
# Number of frames requested from the pipeline (`num_frames`); the latent
# callback uses (N_FRAME-1)//8+1 as the latent frame count.
N_FRAME = 161
# Number of trailing latent frames blended by the commented-out
# `prepare_latents_loop` below; not read by any active code in this file.
N_AVG_FRAME = 2

# Negative prompt sent with every generation request.
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

# Reference to the pipeline's own `prepare_latents`, kept for the
# commented-out wrapper below; not called by any active code in this file.
prepare_latents_original = pipe.prepare_latents

# The unpacked latents have shape (B, C, F, H, W), where F, H, and W are latent-space dimensions.

# def prepare_latents_loop(*args, **kwargs):
#     packed_latents = prepare_latents_original(*args, **kwargs)
#     unpacked_latents = pipe._unpack_latents(packed_latents, (N_FRAME-1)//8+1, HEIGHT//32, WIDTH//32, 1, 1)
#     # now average the first n and last n frames
#     last_n = unpacked_latents[:, :, -N_AVG_FRAME:, :, :]
#     # 0,1,2,3,4, roll -1 => 1,2,3,4,0
#     # last n: [3, 4]
#     # last_next_n: [4, 0]
#     # then 3 will be 0.75*3 + 0.25*4, and 4 will be 0.75*4+0.25*0
#     last_next_n = torch.roll(unpacked_latents, shifts=-1, dims=2)[:, :, -N_AVG_FRAME:, :, :]
#     avg_n = last_n * 0.75 + last_next_n * 0.25
#     unpacked_latents[:, :, -N_AVG_FRAME:, :, :] = avg_n
#     # pack the latents back
#     packed_latents = pipe._pack_latents(unpacked_latents)
#     return packed_latents

# pipe.prepare_latents = prepare_latents_loop

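# torch.roll with shifts=1 moves each latent frame one slot later and wraps the last one to the front:
# step=0 0,1,2,3 -> step=1 3,0,1,2 -> step=2 2,3,0,1 -> step=3 1,2,3,0 -> step=4 0,1,2,3
# so the original order comes back after (N_FRAME-1)//8+1 shifts (the number of latent frames).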

def modify_latents_callback(pipeline, step, timestep, callback_kwargs):
    """Step-end callback that cyclically shifts the latents by one frame.

    The packed latents are unpacked via the pipeline's own helper, rolled by
    one position along dim 2, and packed again.

    Args:
        pipeline: Object providing ``_unpack_latents`` and ``_pack_latents``.
        step: Index of the step that just finished (only used for logging).
        timestep: Unused.
        callback_kwargs: Mapping expected to hold the packed ``"latents"``.

    Returns:
        A dict with the re-packed, rolled tensor under ``"latents"``.
    """
    print("Rolling latents on step", step)

    # Latent grid sizes derived from the module-level video settings.
    # NOTE(review): the divisors 8 and 32 presumably mirror the VAE's temporal
    # and spatial compression, and the trailing 1, 1 the patch sizes — confirm
    # against the pipeline.
    latent_frames = (N_FRAME - 1) // 8 + 1
    latent_height = HEIGHT // 32
    latent_width = WIDTH // 32

    packed = callback_kwargs.get("latents")
    grid = pipeline._unpack_latents(packed, latent_frames, latent_height, latent_width, 1, 1)
    # dim 2 is the frame axis according to the shape comment earlier in the file.
    shifted = torch.roll(grid, shifts=1, dims=2)
    return {"latents": pipeline._pack_latents(shifted)}

@spaces.GPU(duration=120)
def generate_gif(prompt: str, use_fixed_seed: bool) -> str:
    """Generate a video for ``prompt`` and save it as a looping animated WebP.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        use_fixed_seed: When truthy the generator is seeded with 0; otherwise
            a fresh seed is taken from ``torch.seed()``.

    Returns:
        Path of the ``.webp`` file written under ``/tmp``.
    """
    seed = 0 if use_fixed_seed else torch.seed()
    generator = torch.Generator(device="cuda").manual_seed(seed)

    frames = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=WIDTH,
        height=HEIGHT,
        num_frames=N_FRAME,
        num_inference_steps=50,
        decode_timestep=0.03,
        decode_noise_scale=0.025,
        generator=generator,
        # Rolls the latents after every denoising step.
        callback_on_step_end=modify_latents_callback,
    ).frames[0]

    # Despite the "gif" naming, the animation is written as WebP. The earlier
    # ".gif" path was a dead assignment (overwritten before use) and is gone.
    webp_path = f"/tmp/{uuid.uuid4().hex}.webp"

    # perf_counter is the monotonic clock intended for measuring elapsed time.
    started = time.perf_counter()
    # Assumes `frames` is a list of PIL images (the first frame's `.save` is
    # called). duration is per-frame milliseconds (24 fps); loop=0 is passed
    # to request endless looping.
    frames[0].save(
        webp_path,
        format="WebP",
        save_all=True,
        append_images=frames[1:],
        duration=1000 / 24,
        loop=0,
    )
    print("GIF creation time:", time.perf_counter() - started)
    return webp_path

# Build the UI: one row with an input column (prompt, seed toggle, button)
# and an output column showing the generated animation.
with gr.Blocks() as demo:
    gr.Markdown("## LTX Video → Looping GIF Generator")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=4)
            # Checked by default, so generate_gif uses seed 0 unless unticked.
            use_fixed_seed = gr.Checkbox(label="Use Fixed Seed", value=True)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            # generate_gif returns a file path, hence type="filepath".
            gif_output = gr.Image(label="Looping GIF Result", type="filepath")

    # Wire the button to generate_gif; concurrency_limit=1 asks Gradio to run
    # at most one invocation of this event at a time.
    generate_btn.click(
        fn=generate_gif,
        inputs=[prompt_input, use_fixed_seed],
        outputs=gif_output,
        concurrency_limit=1
    )

# Enable request queuing with max_size=5, then start the server.
# Both calls run at import time (there is no __main__ guard).
demo.queue(max_size=5)
demo.launch(share=True)