import gradio as gr
import re
import subprocess
import select
from huggingface_hub import snapshot_download

# Download the model weights from the Hugging Face Hub (adjust repo_id/local_dir for other checkpoints)
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./Wan2.1-T2V-1.3B"
)
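# Note: local_dir above must match the --ckpt_dir passed to the generate command below.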

# Generate a video from a text prompt by running the Wan2.1 generate script in a subprocess
def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    # Build the CLI command; progress is parsed from the subprocess output below
    command = [
        "python", "-u", "-m", "generate",  # Using unbuffered output
        "--task", "t2v-1.3B",
        "--size", "832*480",  # You can try reducing resolution further for CPU
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", "generated_video.mp4"
    ]

    # Run the model inference in a subprocess
    process = subprocess.Popen(command, 
                               stdout=subprocess.PIPE, 
                               stderr=subprocess.PIPE,  # Capture stderr for error messages
                               text=True, 
                               bufsize=1)
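    # Note: stderr is only read after the process exits; if the subprocess writes a
    # large amount to stderr, the pipe buffer could fill and stall it, in which case
    # merging stderr into stdout (stderr=subprocess.STDOUT) would be a safer choice.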

    # Parse tqdm-style progress lines from the subprocess and forward them to Gradio
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")

    while True:
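        # Poll stdout with a short timeout so the loop can stream progress lines
        # and still notice promptly when the subprocess exits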
        rlist, _, _ = select.select([process.stdout], [], [], 0.04)
        if rlist:
            line = process.stdout.readline()
            if not line:
                break
            stripped_line = line.strip()
            if not stripped_line:
                continue

            # Check for tqdm-style progress lines ("NN%|...| current/total")
            progress_match = progress_pattern.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                # Report via the progress tracker injected by Gradio; a bare
                # gr.Progress() created here would not be bound to this event
                progress(current / total, desc=f"Sampling step {current}/{total}")
                continue

            # Suppress routine INFO log lines; echo anything else for debugging
            if "INFO:" in stripped_line:
                continue
            print(stripped_line)

        if process.poll() is not None:
            break

    # Drain any remaining output from stderr
    stderr_output = process.stderr.read().strip()
    if stderr_output:
        print(f"Error output:\n{stderr_output}")

    # Wait for the subprocess to terminate
    process.wait()

    # Surface failures in the UI; on success, return the path to the rendered video
    if process.returncode == 0:
        return "generated_video.mp4"

    print(f"Process failed with return code {process.returncode}")
    raise gr.Error(f"Error executing command: {stderr_output}")

# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1 1.3B Video Generation")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Generate Video")
        video_res = gr.Video(label="Generated Video")

    submit_btn.click(
        fn=infer,
        inputs=[prompt],
        outputs=[video_res]
    )
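
# Queuing is enabled so gr.Progress updates can stream to the UI during long generations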

demo.queue().launch(show_error=True, show_api=False)