import gradio as gr
from gradio import update as gr_update
import subprocess
import threading
import time
import re
import os
import random
import tiktoken
import sys
import ffmpeg
from typing import List, Tuple, Optional, Generator, Dict
import json
from gradio import themes
from gradio.themes.utils import colors
from PIL import Image
import math
import cv2
import glob
import shutil
from pathlib import Path
import logging
from datetime import datetime
from tqdm import tqdm

# Global stop event used to cancel in-flight generations
stop_event = threading.Event()
logger = logging.getLogger(__name__)
def process_hunyuani2v_video( | |
prompt: str, | |
width: int, | |
height: int, | |
batch_size: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
dit_folder: str, | |
model: str, | |
vae: str, | |
te1: str, | |
te2: str, | |
save_path: str, | |
flow_shift: float, | |
cfg_scale: float, | |
output_type: str, | |
attn_mode: str, | |
block_swap: int, | |
exclude_single_blocks: bool, | |
use_split_attn: bool, | |
lora_folder: str, | |
lora1: str = "", | |
lora2: str = "", | |
lora3: str = "", | |
lora4: str = "", | |
lora1_multiplier: float = 1.0, | |
lora2_multiplier: float = 1.0, | |
lora3_multiplier: float = 1.0, | |
lora4_multiplier: float = 1.0, | |
video_path: Optional[str] = None, | |
image_path: Optional[str] = None, | |
strength: Optional[float] = None, | |
negative_prompt: Optional[str] = None, | |
embedded_cfg_scale: Optional[float] = None, | |
split_uncond: Optional[bool] = None, | |
guidance_scale: Optional[float] = None, | |
use_fp8: bool = True, | |
clip_vision_path: Optional[str] = None, | |
i2v_stability: bool = False, | |
fp8_fast: bool = False, | |
compile_model: bool = False, | |
compile_backend: str = "inductor", | |
compile_mode: str = "max-autotune-no-cudagraphs", | |
compile_dynamic: bool = False, | |
compile_fullgraph: bool = False | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Generate a single video with the hunyuani2v script with updated parameters""" | |
global stop_event | |
if stop_event.is_set(): | |
yield [], "", "" | |
return | |
# Determine if this is a SkyReels model and what type | |
is_skyreels = "skyreels" in model.lower() | |
is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
is_skyreels_t2v = is_skyreels and "t2v" in model.lower() | |
# Set defaults for hunyuani2v specific parameters | |
if is_skyreels: | |
# Force certain parameters for SkyReels | |
if negative_prompt is None: | |
negative_prompt = "" | |
if embedded_cfg_scale is None: | |
embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels | |
if split_uncond is None: | |
split_uncond = True | |
if guidance_scale is None: | |
guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided | |
else: | |
embedded_cfg_scale = cfg_scale | |
if os.path.isabs(model): | |
model_path = model | |
else: | |
model_path = os.path.normpath(os.path.join(dit_folder, model)) | |
env = os.environ.copy() | |
env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
env["PYTHONIOENCODING"] = "utf-8" | |
env["BATCH_RUN_ID"] = f"{time.time()}" | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
else: | |
batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) | |
if batch_size > 1: # Only modify seed for batch generation | |
current_seed = (seed + batch_id * 100003) % (2**32) | |
else: | |
current_seed = seed | |
clear_cuda_cache() | |
    # Build the command for hv_generate_video_with_hunyuani2v.py
command = [ | |
sys.executable, | |
"hv_generate_video_with_hunyuani2v.py", | |
"--dit", model_path, | |
"--vae", vae, | |
"--text_encoder1", te1, | |
"--text_encoder2", te2, | |
"--prompt", prompt, | |
"--video_size", str(height), str(width), | |
"--video_length", str(video_length), | |
"--fps", str(fps), | |
"--infer_steps", str(infer_steps), | |
"--save_path", save_path, | |
"--seed", str(current_seed), | |
"--flow_shift", str(flow_shift), | |
"--embedded_cfg_scale", str(cfg_scale), | |
"--output_type", output_type, | |
"--attn_mode", attn_mode, | |
"--blocks_to_swap", str(block_swap), | |
"--fp8_llm", | |
"--vae_chunk_size", "32", | |
"--vae_spatial_tile_sample_min_size", "128" | |
] | |
if use_fp8: | |
command.append("--fp8") | |
# Add new parameters specific to hunyuani2v script | |
if clip_vision_path: | |
command.extend(["--clip_vision_path", clip_vision_path]) | |
if i2v_stability: | |
command.append("--i2v_stability") | |
if fp8_fast: | |
command.append("--fp8_fast") | |
if compile_model: | |
command.append("--compile") | |
command.extend([ | |
"--compile_args", | |
compile_backend, | |
compile_mode, | |
str(compile_dynamic).lower(), | |
str(compile_fullgraph).lower() | |
]) | |
    # Add guidance scale, negative prompt, and split-uncond options
command.extend(["--guidance_scale", str(guidance_scale)]) | |
if negative_prompt: | |
command.extend(["--negative_prompt", negative_prompt]) | |
if split_uncond: | |
command.append("--split_uncond") | |
# Add LoRA weights and multipliers if provided | |
valid_loras = [] | |
for weight, mult in zip([lora1, lora2, lora3, lora4], | |
[lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
if valid_loras: | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
command.extend(["--lora_weight"] + weights) | |
command.extend(["--lora_multiplier"] + multipliers) | |
if exclude_single_blocks: | |
command.append("--exclude_single_blocks") | |
if use_split_attn: | |
command.append("--split_attn") | |
# Handle input paths | |
if video_path: | |
command.extend(["--video_path", video_path]) | |
if strength is not None: | |
command.extend(["--strength", str(strength)]) | |
elif image_path: | |
command.extend(["--image_path", image_path]) | |
# Only add strength parameter for non-SkyReels I2V models | |
# SkyReels I2V doesn't use strength parameter for image-to-video generation | |
if strength is not None and not is_skyreels_i2v: | |
command.extend(["--strength", str(strength)]) | |
print(f"{command}") | |
p = subprocess.Popen( | |
command, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
env=env, | |
text=True, | |
encoding='utf-8', | |
errors='replace', | |
bufsize=1 | |
) | |
videos = [] | |
while True: | |
if stop_event.is_set(): | |
p.terminate() | |
p.wait() | |
yield [], "", "Generation stopped by user." | |
return | |
line = p.stdout.readline() | |
if not line: | |
if p.poll() is not None: | |
break | |
continue | |
print(line, end='') | |
if '|' in line and '%' in line and '[' in line and ']' in line: | |
yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
p.stdout.close() | |
p.wait() | |
clear_cuda_cache() | |
time.sleep(0.5) | |
# Collect generated video | |
save_path_abs = os.path.abspath(save_path) | |
if os.path.exists(save_path_abs): | |
all_videos = sorted( | |
[f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
reverse=True | |
) | |
        matching_videos = [v for v in all_videos if f"_{current_seed}" in v]
        if matching_videos:
            # Keep the generated file path separate from the input video path so the
            # metadata below records the true input video
            generated_video_path = os.path.join(save_path_abs, matching_videos[0])
            # Collect parameters for metadata
            parameters = {
                "prompt": prompt,
                "width": width,
                "height": height,
                "video_length": video_length,
                "fps": fps,
                "infer_steps": infer_steps,
                "seed": current_seed,
                "model": model,
                "vae": vae,
                "te1": te1,
                "te2": te2,
                "save_path": save_path,
                "flow_shift": flow_shift,
                "cfg_scale": cfg_scale,
                "output_type": output_type,
                "attn_mode": attn_mode,
                "block_swap": block_swap,
                "lora_weights": [lora1, lora2, lora3, lora4],
                "lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier],
                "input_video": video_path if video_path else None,
                "input_image": image_path if image_path else None,
                "strength": strength,
                "negative_prompt": negative_prompt,
                "embedded_cfg_scale": embedded_cfg_scale,
                "clip_vision_path": clip_vision_path,
                "i2v_stability": i2v_stability,
                "fp8_fast": fp8_fast,
                "compile_model": compile_model
            }
            add_metadata_to_video(generated_video_path, parameters)
            videos.append((str(generated_video_path), f"Seed: {current_seed}"))
    yield videos, f"Completed (seed: {current_seed})", ""
# Batch-processing wrapper that drives process_hunyuani2v_video
def process_hunyuani2v_batch( | |
prompt: str, | |
width: int, | |
height: int, | |
batch_size: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
dit_folder: str, | |
model: str, | |
vae: str, | |
te1: str, | |
te2: str, | |
save_path: str, | |
flow_shift: float, | |
cfg_scale: float, | |
output_type: str, | |
attn_mode: str, | |
block_swap: int, | |
exclude_single_blocks: bool, | |
use_split_attn: bool, | |
lora_folder: str, | |
*args | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Process a batch of videos using the hunyuani2v script""" | |
global stop_event | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Extract additional arguments | |
num_lora_weights = 4 | |
lora_weights = args[:num_lora_weights] | |
lora_multipliers = args[num_lora_weights:num_lora_weights*2] | |
# New parameters for hunyuani2v | |
# Base parameter list index after lora weights and multipliers | |
base_idx = num_lora_weights*2 | |
# Extract parameters | |
input_path = args[base_idx] if len(args) > base_idx else None | |
strength = float(args[base_idx+1]) if len(args) > base_idx+1 and args[base_idx+1] is not None else None | |
negative_prompt = str(args[base_idx+2]) if len(args) > base_idx+2 and args[base_idx+2] is not None else None | |
guidance_scale = float(args[base_idx+3]) if len(args) > base_idx+3 and args[base_idx+3] is not None else cfg_scale | |
split_uncond = bool(args[base_idx+4]) if len(args) > base_idx+4 else None | |
use_fp8 = bool(args[base_idx+5]) if len(args) > base_idx+5 else True | |
# New hunyuani2v parameters | |
clip_vision_path = str(args[base_idx+6]) if len(args) > base_idx+6 and args[base_idx+6] is not None else None | |
i2v_stability = bool(args[base_idx+7]) if len(args) > base_idx+7 else False | |
fp8_fast = bool(args[base_idx+8]) if len(args) > base_idx+8 else False | |
compile_model = bool(args[base_idx+9]) if len(args) > base_idx+9 else False | |
compile_backend = str(args[base_idx+10]) if len(args) > base_idx+10 and args[base_idx+10] is not None else "inductor" | |
compile_mode = str(args[base_idx+11]) if len(args) > base_idx+11 and args[base_idx+11] is not None else "max-autotune-no-cudagraphs" | |
compile_dynamic = bool(args[base_idx+12]) if len(args) > base_idx+12 else False | |
compile_fullgraph = bool(args[base_idx+13]) if len(args) > base_idx+13 else False | |
embedded_cfg_scale = cfg_scale | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
break | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Handle different input types | |
video_path = None | |
image_path = None | |
if input_path: | |
is_image = False | |
lower_path = input_path.lower() | |
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') | |
is_image = any(lower_path.endswith(ext) for ext in image_extensions) | |
if is_image: | |
image_path = input_path | |
else: | |
video_path = input_path | |
# Prepare arguments for process_hunyuani2v_video | |
        # Offset the seed per batch item; seed == -1 is passed through so every item stays random
        current_seed = seed + i if (seed != -1 and batch_size > 1) else seed
hunyuani2v_args = [ | |
prompt, width, height, batch_size, video_length, fps, infer_steps, | |
current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder | |
] | |
hunyuani2v_args.extend(lora_weights) | |
hunyuani2v_args.extend(lora_multipliers) | |
hunyuani2v_args.extend([ | |
video_path, image_path, strength, negative_prompt, embedded_cfg_scale, | |
split_uncond, guidance_scale, use_fp8, clip_vision_path, i2v_stability, | |
fp8_fast, compile_model, compile_backend, compile_mode, compile_dynamic, compile_fullgraph | |
]) | |
for videos, status, progress in process_hunyuani2v_video(*hunyuani2v_args): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
yield all_videos, "Batch complete", "" | |
def variance_of_laplacian(image): | |
""" | |
Compute the variance of the Laplacian of the image. | |
Higher variance indicates a sharper image. | |
""" | |
return cv2.Laplacian(image, cv2.CV_64F).var() | |
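
# Example (hypothetical file names): comparing the sharpness of two grayscale frames.
# A larger variance of the Laplacian means more high-frequency detail, i.e. a sharper image.
#
#   frame_a = cv2.imread("frame_a.png", cv2.IMREAD_GRAYSCALE)
#   frame_b = cv2.imread("frame_b.png", cv2.IMREAD_GRAYSCALE)
#   sharper = frame_a if variance_of_laplacian(frame_a) > variance_of_laplacian(frame_b) else frame_b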
def extract_sharpest_frame(video_path, frames_to_check=30): | |
""" | |
Extract the sharpest frame from the last N frames of the video. | |
Args: | |
video_path (str): Path to the video file | |
frames_to_check (int): Number of frames from the end to check | |
Returns: | |
tuple: (temp_image_path, frame_number, sharpness_score) | |
""" | |
print(f"\n=== Extracting sharpest frame from the last {frames_to_check} frames ===") | |
print(f"Input video path: {video_path}") | |
if not video_path or not os.path.exists(video_path): | |
print("❌ Error: Video file does not exist") | |
return None, None, None | |
try: | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
print("❌ Error: Failed to open video file") | |
return None, None, None | |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
print(f"Total frames detected: {total_frames}, FPS: {fps:.2f}") | |
if total_frames < 1: | |
print("❌ Error: Video contains 0 frames") | |
return None, None, None | |
# Determine how many frames to check (the last N frames) | |
if frames_to_check > total_frames: | |
frames_to_check = total_frames | |
start_frame = 0 | |
else: | |
start_frame = total_frames - frames_to_check | |
print(f"Checking frames {start_frame} to {total_frames-1}") | |
# Find the sharpest frame | |
sharpest_frame = None | |
max_sharpness = -1 | |
sharpest_frame_number = -1 | |
# Set starting position | |
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) | |
# Process frames with a progress bar | |
with tqdm(total=frames_to_check, desc="Finding sharpest frame") as pbar: | |
frame_idx = start_frame | |
while frame_idx < total_frames: | |
ret, frame = cap.read() | |
if not ret: | |
break | |
# Convert to grayscale and calculate sharpness | |
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | |
sharpness = variance_of_laplacian(gray) | |
# Update if this is the sharpest frame so far | |
if sharpness > max_sharpness: | |
max_sharpness = sharpness | |
sharpest_frame = frame.copy() | |
sharpest_frame_number = frame_idx | |
frame_idx += 1 | |
pbar.update(1) | |
cap.release() | |
if sharpest_frame is None: | |
print("❌ Error: Failed to find a sharp frame") | |
return None, None, None | |
# Prepare output path | |
temp_dir = os.path.abspath("temp_frames") | |
os.makedirs(temp_dir, exist_ok=True) | |
temp_path = os.path.join(temp_dir, f"sharpest_frame_{os.path.basename(video_path)}.png") | |
print(f"Saving frame to: {temp_path}") | |
# Write and verify | |
if not cv2.imwrite(temp_path, sharpest_frame): | |
print("❌ Error: Failed to write frame to file") | |
return None, None, None | |
if not os.path.exists(temp_path): | |
print("❌ Error: Output file not created") | |
return None, None, None | |
# Calculate frame time in seconds | |
frame_time = sharpest_frame_number / fps | |
print(f"✅ Extracted sharpest frame: {sharpest_frame_number} (at {frame_time:.2f}s) with sharpness {max_sharpness:.2f}") | |
return temp_path, sharpest_frame_number, max_sharpness | |
except Exception as e: | |
print(f"❌ Unexpected error: {str(e)}") | |
return None, None, None | |
finally: | |
if 'cap' in locals(): | |
cap.release() | |
def trim_video_to_frame(video_path, frame_number, output_dir="outputs"): | |
""" | |
Trim video up to the specified frame and save as a new video. | |
Args: | |
video_path (str): Path to the video file | |
frame_number (int): Frame number to trim to | |
output_dir (str): Directory to save the trimmed video | |
Returns: | |
str: Path to the trimmed video file | |
""" | |
print(f"\n=== Trimming video to frame {frame_number} ===") | |
if not video_path or not os.path.exists(video_path): | |
print("❌ Error: Video file does not exist") | |
return None | |
try: | |
# Get video information | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
print("❌ Error: Failed to open video file") | |
return None | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
cap.release() | |
# Calculate time in seconds | |
time_seconds = frame_number / fps | |
# Create output directory if it doesn't exist | |
os.makedirs(output_dir, exist_ok=True) | |
# Generate output filename | |
timestamp = f"{int(time_seconds)}s" | |
base_name = Path(video_path).stem | |
output_file = os.path.join(output_dir, f"{base_name}_trimmed_to_{timestamp}.mp4") | |
# Use ffmpeg to trim the video | |
( | |
ffmpeg | |
.input(video_path) | |
.output(output_file, to=time_seconds, c="copy") | |
.global_args('-y') # Overwrite output files | |
.run(quiet=True) | |
) | |
if not os.path.exists(output_file): | |
print("❌ Error: Failed to create trimmed video") | |
return None | |
print(f"✅ Successfully trimmed video to {time_seconds:.2f}s: {output_file}") | |
return output_file | |
except Exception as e: | |
print(f"❌ Error trimming video: {str(e)}") | |
return None | |
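
# Note: trimming with stream copy (c="copy") avoids re-encoding, so it is fast, but the cut is
# not guaranteed to be frame-exact; a frame-accurate trim would require re-encoding,
# e.g. (sketch, not used here) replacing c="copy" with vcodec="libx264", acodec="aac".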
def send_sharpest_frame_handler(gallery, selected_idx, frames_to_check=30): | |
""" | |
Extract the sharpest frame from the last N frames of the selected video | |
Args: | |
gallery: Gradio gallery component with videos | |
selected_idx: Index of the selected video | |
frames_to_check: Number of frames from the end to check | |
Returns: | |
        tuple: (image_path, video_path, frame_number, status_message)
""" | |
if gallery is None or not gallery: | |
return None, None, None, "No videos in gallery" | |
if selected_idx is None and len(gallery) == 1: | |
selected_idx = 0 | |
if selected_idx is None or selected_idx >= len(gallery): | |
return None, None, None, "No video selected" | |
# Get the video path | |
item = gallery[selected_idx] | |
if isinstance(item, tuple): | |
video_path = item[0] | |
elif isinstance(item, dict): | |
video_path = item.get('name') or item.get('data') | |
else: | |
video_path = str(item) | |
# Extract the sharpest frame | |
image_path, frame_number, sharpness = extract_sharpest_frame(video_path, frames_to_check) | |
if image_path is None: | |
return None, None, None, "Failed to extract sharpest frame" | |
return image_path, video_path, frame_number, f"Extracted frame {frame_number} with sharpness {sharpness:.2f}" | |
def trim_and_prepare_for_extension(video_path, frame_number, save_path="outputs"): | |
""" | |
Trim the video to the specified frame and prepare for extension. | |
Args: | |
video_path: Path to the video file | |
frame_number: Frame number to trim to | |
save_path: Directory to save the trimmed video | |
Returns: | |
tuple: (trimmed_video_path, status_message) | |
""" | |
if not video_path or not os.path.exists(video_path): | |
return None, "No video selected or video file does not exist" | |
if frame_number is None: | |
return None, "No frame number provided, please extract sharpest frame first" | |
# Trim the video | |
trimmed_video = trim_video_to_frame(video_path, frame_number, save_path) | |
if trimmed_video is None: | |
return None, "Failed to trim video" | |
return trimmed_video, f"Video trimmed to frame {frame_number} and ready for extension" | |
def send_last_frame_handler(gallery, selected_idx): | |
"""Handle sending last frame to input with better error handling""" | |
if gallery is None or not gallery: | |
return None, None | |
if selected_idx is None and len(gallery) == 1: | |
selected_idx = 0 | |
if selected_idx is None or selected_idx >= len(gallery): | |
return None, None | |
# Get the frame and video path | |
frame = handle_last_frame_transfer(gallery, selected_idx) | |
video_path = None | |
if selected_idx < len(gallery): | |
item = gallery[selected_idx] | |
video_path = parse_video_path(item) | |
return frame, video_path | |
def extract_last_frame(video_path: str) -> Optional[str]: | |
"""Extract last frame from video and return temporary image path with error handling""" | |
print(f"\n=== Starting frame extraction ===") | |
print(f"Input video path: {video_path}") | |
if not video_path or not os.path.exists(video_path): | |
print("❌ Error: Video file does not exist") | |
return None | |
try: | |
cap = cv2.VideoCapture(video_path) | |
if not cap.isOpened(): | |
print("❌ Error: Failed to open video file") | |
return None | |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
print(f"Total frames detected: {total_frames}") | |
if total_frames < 1: | |
print("❌ Error: Video contains 0 frames") | |
return None | |
# Extract last frame | |
cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) | |
success, frame = cap.read() | |
if not success or frame is None: | |
print("❌ Error: Failed to read last frame") | |
return None | |
# Prepare output path | |
temp_dir = os.path.abspath("temp_frames") | |
os.makedirs(temp_dir, exist_ok=True) | |
temp_path = os.path.join(temp_dir, f"last_frame_{os.path.basename(video_path)}.png") | |
print(f"Saving frame to: {temp_path}") | |
# Write and verify | |
if not cv2.imwrite(temp_path, frame): | |
print("❌ Error: Failed to write frame to file") | |
return None | |
if not os.path.exists(temp_path): | |
print("❌ Error: Output file not created") | |
return None | |
print("✅ Frame extraction successful") | |
return temp_path | |
except Exception as e: | |
print(f"❌ Unexpected error: {str(e)}") | |
return None | |
finally: | |
if 'cap' in locals(): | |
cap.release() | |
def handle_last_frame_transfer(gallery: list, selected_idx: int) -> Optional[str]: | |
"""Improved frame transfer with video input validation""" | |
try: | |
if gallery is None or not gallery: | |
raise ValueError("No videos generated yet") | |
if selected_idx is None: | |
# Auto-select last generated video if batch_size=1 | |
if len(gallery) == 1: | |
selected_idx = 0 | |
else: | |
raise ValueError("Please select a video first") | |
if selected_idx >= len(gallery): | |
raise ValueError("Invalid selection index") | |
item = gallery[selected_idx] | |
# Video file existence check | |
video_path = parse_video_path(item) | |
if not os.path.exists(video_path): | |
raise FileNotFoundError(f"Video file missing: {video_path}") | |
return extract_last_frame(video_path) | |
except Exception as e: | |
print(f"Frame transfer failed: {str(e)}") | |
return None | |
def parse_video_path(item) -> str: | |
"""Parse different gallery item formats""" | |
if isinstance(item, tuple): | |
return item[0] | |
elif isinstance(item, dict): | |
return item.get('name') or item.get('data') | |
return str(item) | |
def get_random_image_from_folder(folder_path): | |
"""Get a random image from the specified folder""" | |
if not os.path.isdir(folder_path): | |
return None, f"Error: {folder_path} is not a valid directory" | |
# Get all image files in the folder | |
image_files = [] | |
for ext in ('*.jpg', '*.jpeg', '*.png', '*.bmp', '*.webp'): | |
image_files.extend(glob.glob(os.path.join(folder_path, ext))) | |
for ext in ('*.JPG', '*.JPEG', '*.PNG', '*.BMP', '*.WEBP'): | |
image_files.extend(glob.glob(os.path.join(folder_path, ext))) | |
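    # Note: on case-insensitive filesystems (e.g. Windows) the uppercase patterns above can
    # match the same files as the lowercase ones, so the list may contain duplicates; this is
    # harmless here because only a single file is picked at random.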
if not image_files: | |
return None, f"Error: No image files found in {folder_path}" | |
# Select a random image | |
random_image = random.choice(image_files) | |
return random_image, f"Selected: {os.path.basename(random_image)}" | |
def resize_image_keeping_aspect_ratio(image_path, max_width, max_height): | |
"""Resize image keeping aspect ratio and ensuring dimensions are divisible by 16""" | |
try: | |
img = Image.open(image_path) | |
width, height = img.size | |
# Calculate aspect ratio | |
aspect_ratio = width / height | |
# Calculate new dimensions while maintaining aspect ratio | |
if width > height: | |
new_width = min(max_width, width) | |
new_height = int(new_width / aspect_ratio) | |
else: | |
new_height = min(max_height, height) | |
new_width = int(new_height * aspect_ratio) | |
# Make dimensions divisible by 16 | |
new_width = math.floor(new_width / 16) * 16 | |
new_height = math.floor(new_height / 16) * 16 | |
# Ensure minimum size | |
new_width = max(16, new_width) | |
new_height = max(16, new_height) | |
# Resize image | |
resized_img = img.resize((new_width, new_height), Image.LANCZOS) | |
# Save to temporary file | |
temp_path = f"temp_resized_{os.path.basename(image_path)}" | |
resized_img.save(temp_path) | |
return temp_path, (new_width, new_height) | |
except Exception as e: | |
return None, f"Error: {str(e)}" | |
# Function to process a batch of images from a folder | |
def batch_handler( | |
use_random, | |
prompt, negative_prompt, | |
width, height, | |
video_length, fps, infer_steps, | |
seed, flow_shift, guidance_scale, embedded_cfg_scale, | |
batch_size, input_folder_path, | |
dit_folder, model, vae, te1, te2, save_path, output_type, attn_mode, | |
block_swap, exclude_single_blocks, use_split_attn, use_fp8, split_uncond, | |
lora_folder, *lora_params | |
): | |
"""Handle both folder-based batch processing and regular batch processing""" | |
global stop_event | |
# Check if this is a SkyReels model that needs special handling | |
is_skyreels = "skyreels" in model.lower() | |
is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
if use_random: | |
# Random image from folder mode | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
break | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Get random image from folder | |
random_image, status = get_random_image_from_folder(input_folder_path) | |
if random_image is None: | |
yield all_videos, f"Error in batch {i+1}: {status}", "" | |
continue | |
# Resize image | |
resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) | |
if resized_image is None: | |
yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" | |
continue | |
# If we have dimensions, update them | |
local_width, local_height = width, height | |
if isinstance(size_info, tuple): | |
local_width, local_height = size_info | |
progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height}" | |
else: | |
progress_text = f"Using image: {os.path.basename(random_image)}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Calculate seed for this batch item | |
current_seed = seed | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
elif batch_size > 1: | |
current_seed = seed + i | |
# Process the image | |
# For SkyReels models, we need to create a command with dit_in_channels=32 | |
if is_skyreels_i2v: | |
env = os.environ.copy() | |
env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
env["PYTHONIOENCODING"] = "utf-8" | |
model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model | |
# Extract parameters from lora_params | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
cmd = [ | |
sys.executable, | |
"hv_generate_video.py", | |
"--dit", model_path, | |
"--vae", vae, | |
"--text_encoder1", te1, | |
"--text_encoder2", te2, | |
"--prompt", prompt, | |
"--video_size", str(local_height), str(local_width), | |
"--video_length", str(video_length), | |
"--fps", str(fps), | |
"--infer_steps", str(infer_steps), | |
"--save_path", save_path, | |
"--seed", str(current_seed), | |
"--flow_shift", str(flow_shift), | |
"--embedded_cfg_scale", str(embedded_cfg_scale), | |
"--output_type", output_type, | |
"--attn_mode", attn_mode, | |
"--blocks_to_swap", str(block_swap), | |
"--fp8_llm", | |
"--vae_chunk_size", "32", | |
"--vae_spatial_tile_sample_min_size", "128", | |
"--dit_in_channels", "32", # This is crucial for SkyReels i2v | |
"--image_path", resized_image # Pass the image directly | |
] | |
if use_fp8: | |
cmd.append("--fp8") | |
if split_uncond: | |
cmd.append("--split_uncond") | |
if use_split_attn: | |
cmd.append("--split_attn") | |
if exclude_single_blocks: | |
cmd.append("--exclude_single_blocks") | |
if negative_prompt: | |
cmd.extend(["--negative_prompt", negative_prompt]) | |
if guidance_scale is not None: | |
cmd.extend(["--guidance_scale", str(guidance_scale)]) | |
# Add LoRA weights and multipliers if provided | |
valid_loras = [] | |
for weight, mult in zip(lora_weights, lora_multipliers): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
if valid_loras: | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
cmd.extend(["--lora_weight"] + weights) | |
cmd.extend(["--lora_multiplier"] + multipliers) | |
print(f"Running command: {' '.join(cmd)}") | |
# Run the process | |
p = subprocess.Popen( | |
cmd, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
env=env, | |
text=True, | |
encoding='utf-8', | |
errors='replace', | |
bufsize=1 | |
) | |
while True: | |
if stop_event.is_set(): | |
p.terminate() | |
p.wait() | |
yield all_videos, "Generation stopped by user.", "" | |
return | |
line = p.stdout.readline() | |
if not line: | |
if p.poll() is not None: | |
break | |
continue | |
print(line, end='') | |
if '|' in line and '%' in line and '[' in line and ']' in line: | |
yield all_videos.copy(), f"Processing video {i+1} (seed: {current_seed})", line.strip() | |
p.stdout.close() | |
p.wait() | |
# Collect generated video | |
save_path_abs = os.path.abspath(save_path) | |
if os.path.exists(save_path_abs): | |
all_videos_files = sorted( | |
[f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
reverse=True | |
) | |
matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] | |
if matching_videos: | |
video_path = os.path.join(save_path_abs, matching_videos[0]) | |
all_videos.append((str(video_path), f"Seed: {current_seed}")) | |
else: | |
# For non-SkyReels models, use the regular process_single_video function | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
single_video_args = [ | |
prompt, local_width, local_height, 1, video_length, fps, infer_steps, | |
current_seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, embedded_cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder | |
] | |
single_video_args.extend(lora_weights) | |
single_video_args.extend(lora_multipliers) | |
single_video_args.extend([None, resized_image, None, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) | |
for videos, status, progress in process_single_video(*single_video_args): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
# Clean up temporary file | |
try: | |
if os.path.exists(resized_image): | |
os.remove(resized_image) | |
            except OSError:
                pass  # Best-effort cleanup of the temporary resized image
# Clear CUDA cache between generations | |
clear_cuda_cache() | |
time.sleep(0.5) | |
yield all_videos, "Batch complete", "" | |
else: | |
        # Regular image input. SkyReels I2V models need the direct command approach with
        # dit_in_channels=32 explicitly specified, mirroring the folder-processing branch above.
if is_skyreels_i2v: | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Extract lora parameters | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
extra_args = list(lora_params[num_lora_weights*2:]) if len(lora_params) > num_lora_weights*2 else [] | |
# Print extra_args for debugging | |
print(f"Extra args: {extra_args}") | |
            # The input image path arrives as the first extra argument:
            # skyreels_generate_btn.click passes skyreels_input, which should be the image path
image_path = None | |
if len(extra_args) > 0 and extra_args[0] is not None: | |
image_path = extra_args[0] | |
print(f"Image path found in extra_args[0]: {image_path}") | |
            # Abort early: SkyReels I2V requires an input image
            if not image_path:
                # Debug output to help diagnose missing-input cases
                print("No image path found in extra_args[0]")
print(f"Full lora_params: {lora_params}") | |
yield [], "Error: No input image provided", "An input image is required for SkyReels I2V models" | |
return | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
yield all_videos, "Generation stopped by user", "" | |
return | |
# Calculate seed for this batch item | |
current_seed = seed | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
elif batch_size > 1: | |
current_seed = seed + i | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Set up environment | |
env = os.environ.copy() | |
env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
env["PYTHONIOENCODING"] = "utf-8" | |
model_path = os.path.join(dit_folder, model) if not os.path.isabs(model) else model | |
# Build the command with dit_in_channels=32 | |
cmd = [ | |
sys.executable, | |
"hv_generate_video.py", | |
"--dit", model_path, | |
"--vae", vae, | |
"--text_encoder1", te1, | |
"--text_encoder2", te2, | |
"--prompt", prompt, | |
"--video_size", str(height), str(width), | |
"--video_length", str(video_length), | |
"--fps", str(fps), | |
"--infer_steps", str(infer_steps), | |
"--save_path", save_path, | |
"--seed", str(current_seed), | |
"--flow_shift", str(flow_shift), | |
"--embedded_cfg_scale", str(embedded_cfg_scale), | |
"--output_type", output_type, | |
"--attn_mode", attn_mode, | |
"--blocks_to_swap", str(block_swap), | |
"--fp8_llm", | |
"--vae_chunk_size", "32", | |
"--vae_spatial_tile_sample_min_size", "128", | |
"--dit_in_channels", "32", # This is crucial for SkyReels i2v | |
"--image_path", image_path | |
] | |
if use_fp8: | |
cmd.append("--fp8") | |
if split_uncond: | |
cmd.append("--split_uncond") | |
if use_split_attn: | |
cmd.append("--split_attn") | |
if exclude_single_blocks: | |
cmd.append("--exclude_single_blocks") | |
if negative_prompt: | |
cmd.extend(["--negative_prompt", negative_prompt]) | |
if guidance_scale is not None: | |
cmd.extend(["--guidance_scale", str(guidance_scale)]) | |
# Add LoRA weights and multipliers if provided | |
valid_loras = [] | |
for weight, mult in zip(lora_weights, lora_multipliers): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
if valid_loras: | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
cmd.extend(["--lora_weight"] + weights) | |
cmd.extend(["--lora_multiplier"] + multipliers) | |
print(f"Running command: {' '.join(cmd)}") | |
# Run the process | |
p = subprocess.Popen( | |
cmd, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
env=env, | |
text=True, | |
encoding='utf-8', | |
errors='replace', | |
bufsize=1 | |
) | |
while True: | |
if stop_event.is_set(): | |
p.terminate() | |
p.wait() | |
yield all_videos, "Generation stopped by user.", "" | |
return | |
line = p.stdout.readline() | |
if not line: | |
if p.poll() is not None: | |
break | |
continue | |
print(line, end='') | |
if '|' in line and '%' in line and '[' in line and ']' in line: | |
yield all_videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
p.stdout.close() | |
p.wait() | |
# Collect generated video | |
save_path_abs = os.path.abspath(save_path) | |
if os.path.exists(save_path_abs): | |
all_videos_files = sorted( | |
[f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
reverse=True | |
) | |
matching_videos = [v for v in all_videos_files if f"_{current_seed}" in v] | |
if matching_videos: | |
video_path = os.path.join(save_path_abs, matching_videos[0]) | |
all_videos.append((str(video_path), f"Seed: {current_seed}")) | |
# Clear CUDA cache between generations | |
clear_cuda_cache() | |
time.sleep(0.5) | |
yield all_videos, "Batch complete", "" | |
else: | |
# For regular non-SkyReels models, use the original process_batch function | |
regular_args = [ | |
prompt, width, height, batch_size, video_length, fps, infer_steps, | |
seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, guidance_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder | |
] | |
yield from process_batch(*(regular_args + list(lora_params))) | |
def get_dit_models(dit_folder: str) -> List[str]: | |
"""Get list of available DiT models in the specified folder""" | |
if not os.path.exists(dit_folder): | |
return ["mp_rank_00_model_states.pt"] | |
models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] | |
models.sort(key=str.lower) | |
return models if models else ["mp_rank_00_model_states.pt"] | |
def update_dit_and_lora_dropdowns(dit_folder: str, lora_folder: str, *current_values) -> List[gr.update]: | |
"""Update both DiT and LoRA dropdowns""" | |
# Get model lists | |
dit_models = get_dit_models(dit_folder) | |
lora_choices = get_lora_options(lora_folder) | |
# Current values processing | |
dit_value = current_values[0] | |
if dit_value not in dit_models: | |
dit_value = dit_models[0] if dit_models else None | |
weights = current_values[1:5] | |
multipliers = current_values[5:9] | |
results = [gr.update(choices=dit_models, value=dit_value)] | |
# Add LoRA updates | |
for i in range(4): | |
weight = weights[i] if i < len(weights) else "None" | |
multiplier = multipliers[i] if i < len(multipliers) else 1.0 | |
if weight not in lora_choices: | |
weight = "None" | |
results.extend([ | |
gr.update(choices=lora_choices, value=weight), | |
gr.update(value=multiplier) | |
]) | |
return results | |
def extract_video_metadata(video_path: str) -> Dict: | |
"""Extract metadata from video file using ffprobe.""" | |
cmd = [ | |
'ffprobe', | |
'-v', 'quiet', | |
'-print_format', 'json', | |
'-show_format', | |
video_path | |
] | |
try: | |
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) | |
metadata = json.loads(result.stdout.decode('utf-8')) | |
if 'format' in metadata and 'tags' in metadata['format']: | |
comment = metadata['format']['tags'].get('comment', '{}') | |
return json.loads(comment) | |
return {} | |
except Exception as e: | |
print(f"Metadata extraction failed: {str(e)}") | |
return {} | |
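
# Round-trip sketch (hypothetical path): parameters written by add_metadata_to_video are stored
# as JSON in the MP4 'comment' tag, so they can be recovered later with extract_video_metadata.
#
#   params = extract_video_metadata("outputs/example_12345.mp4")
#   print(params.get("prompt"), params.get("seed"))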
def create_parameter_transfer_map(metadata: Dict, target_tab: str) -> Dict: | |
"""Map metadata parameters to Gradio components for different tabs""" | |
mapping = { | |
'common': { | |
'prompt': ('prompt', 'v2v_prompt'), | |
'width': ('width', 'v2v_width'), | |
'height': ('height', 'v2v_height'), | |
'batch_size': ('batch_size', 'v2v_batch_size'), | |
'video_length': ('video_length', 'v2v_video_length'), | |
'fps': ('fps', 'v2v_fps'), | |
'infer_steps': ('infer_steps', 'v2v_infer_steps'), | |
'seed': ('seed', 'v2v_seed'), | |
'model': ('model', 'v2v_model'), | |
'vae': ('vae', 'v2v_vae'), | |
'te1': ('te1', 'v2v_te1'), | |
'te2': ('te2', 'v2v_te2'), | |
'save_path': ('save_path', 'v2v_save_path'), | |
'flow_shift': ('flow_shift', 'v2v_flow_shift'), | |
'cfg_scale': ('cfg_scale', 'v2v_cfg_scale'), | |
'output_type': ('output_type', 'v2v_output_type'), | |
'attn_mode': ('attn_mode', 'v2v_attn_mode'), | |
'block_swap': ('block_swap', 'v2v_block_swap') | |
}, | |
'lora': { | |
'lora_weights': [(f'lora{i+1}', f'v2v_lora_weights[{i}]') for i in range(4)], | |
'lora_multipliers': [(f'lora{i+1}_multiplier', f'v2v_lora_multipliers[{i}]') for i in range(4)] | |
} | |
} | |
results = {} | |
for param, value in metadata.items(): | |
# Handle common parameters | |
if param in mapping['common']: | |
target = mapping['common'][param][0 if target_tab == 't2v' else 1] | |
results[target] = value | |
# Handle LoRA parameters | |
if param == 'lora_weights': | |
for i, weight in enumerate(value[:4]): | |
target = mapping['lora']['lora_weights'][i][1 if target_tab == 'v2v' else 0] | |
results[target] = weight | |
if param == 'lora_multipliers': | |
for i, mult in enumerate(value[:4]): | |
target = mapping['lora']['lora_multipliers'][i][1 if target_tab == 'v2v' else 0] | |
results[target] = float(mult) | |
return results | |
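
# Example (hypothetical metadata): with target_tab='v2v',
#   {"prompt": "a cat", "seed": 42, "lora_weights": ["style.safetensors", "None", "None", "None"]}
# maps to
#   {"v2v_prompt": "a cat", "v2v_seed": 42, "v2v_lora_weights[0]": "style.safetensors", ...}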
def add_metadata_to_video(video_path: str, parameters: dict) -> None: | |
"""Add generation parameters to video metadata using ffmpeg.""" | |
# Convert parameters to JSON string | |
params_json = json.dumps(parameters, indent=2) | |
# Temporary output path | |
temp_path = video_path.replace(".mp4", "_temp.mp4") | |
# FFmpeg command to add metadata without re-encoding | |
cmd = [ | |
'ffmpeg', | |
'-i', video_path, | |
'-metadata', f'comment={params_json}', | |
'-codec', 'copy', | |
temp_path | |
] | |
try: | |
# Execute FFmpeg command | |
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |
# Replace original file with the metadata-enhanced version | |
os.replace(temp_path, video_path) | |
except subprocess.CalledProcessError as e: | |
print(f"Failed to add metadata: {e.stderr.decode()}") | |
if os.path.exists(temp_path): | |
os.remove(temp_path) | |
except Exception as e: | |
print(f"Error: {str(e)}") | |
def count_prompt_tokens(prompt: str) -> int:
    """Return an approximate token count for the prompt using tiktoken's cl100k_base encoding."""
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(prompt)
    return len(tokens)
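
# Note: cl100k_base is an OpenAI tokenizer, so the count above only approximates how the video
# model's own text encoders would tokenize the prompt; it is presumably shown in the UI as a
# rough prompt-length indicator.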
def get_lora_options(lora_folder: str = "lora") -> List[str]: | |
if not os.path.exists(lora_folder): | |
return ["None"] | |
lora_files = [f for f in os.listdir(lora_folder) if f.endswith('.safetensors') or f.endswith('.pt')] | |
lora_files.sort(key=str.lower) | |
return ["None"] + lora_files | |
def update_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: | |
new_choices = get_lora_options(lora_folder) | |
weights = current_values[:4] | |
multipliers = current_values[4:8] | |
results = [] | |
for i in range(4): | |
weight = weights[i] if i < len(weights) else "None" | |
multiplier = multipliers[i] if i < len(multipliers) else 1.0 | |
if weight not in new_choices: | |
weight = "None" | |
results.extend([ | |
gr.update(choices=new_choices, value=weight), | |
gr.update(value=multiplier) | |
]) | |
return results | |
def send_to_v2v(evt: gr.SelectData, gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str, int]: | |
"""Transfer selected video and prompt to Video2Video tab""" | |
if not gallery or evt.index >= len(gallery): | |
return None, "", selected_index.value | |
selected_item = gallery[evt.index] | |
# Handle different gallery item formats | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
# Final cleanup for Gradio Video component | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
# Update the selected index | |
selected_index.value = evt.index | |
return str(video_path), prompt, evt.index | |
def send_selected_to_v2v(gallery: list, prompt: str, selected_index: gr.State) -> Tuple[Optional[str], str]: | |
"""Send the currently selected video to V2V tab""" | |
if not gallery or selected_index.value is None or selected_index.value >= len(gallery): | |
return None, "" | |
selected_item = gallery[selected_index.value] | |
# Handle different gallery item formats | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
# Final cleanup for Gradio Video component | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
return str(video_path), prompt | |
def clear_cuda_cache(): | |
"""Clear CUDA cache if available""" | |
import torch | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
# Optional: synchronize to ensure cache is cleared | |
torch.cuda.synchronize() | |
def wanx_batch_handler( | |
use_random, | |
prompt, | |
negative_prompt, | |
width, | |
height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
seed, | |
batch_size, | |
input_folder_path, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
*lora_params | |
): | |
"""Handle both folder-based batch processing and regular processing for WanX""" | |
global stop_event | |
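    # lora_params follows a fixed positional layout (assumed from the UI wiring):
    # 4 LoRA weight names, 4 multipliers, then (for non-folder mode) the input image path.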
if use_random: | |
# Random image from folder mode | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Ensure batch_size is treated as an integer | |
batch_size = int(batch_size) | |
# Process each item in the batch separately | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
yield all_videos, "Generation stopped by user", "" | |
return | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Get random image from folder | |
random_image, status = get_random_image_from_folder(input_folder_path) | |
if random_image is None: | |
yield all_videos, f"Error in batch {i+1}: {status}", "" | |
continue | |
# Resize image | |
resized_image, size_info = resize_image_keeping_aspect_ratio(random_image, width, height) | |
if resized_image is None: | |
yield all_videos, f"Error resizing image in batch {i+1}: {size_info}", "" | |
continue | |
# Use the dimensions returned from the resize function | |
local_width, local_height = width, height # Default fallback | |
if isinstance(size_info, tuple): | |
local_width, local_height = size_info | |
progress_text = f"Using image: {os.path.basename(random_image)} - Resized to {local_width}x{local_height} (maintaining aspect ratio)" | |
else: | |
progress_text = f"Using image: {os.path.basename(random_image)}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Calculate seed for this batch item | |
current_seed = seed | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
elif batch_size > 1: | |
current_seed = seed + i | |
# Extract LoRA weights and multipliers | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
# Generate video for this image - one at a time | |
for videos, status, progress in wanx_generate_video( | |
prompt, | |
negative_prompt, | |
resized_image, | |
local_width, | |
local_height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
current_seed, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
*lora_weights, | |
*lora_multipliers | |
): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
# Clean up temporary file | |
try: | |
if os.path.exists(resized_image): | |
os.remove(resized_image) | |
            except OSError:
                pass  # Best-effort cleanup of the temporary resized image
# Clear CUDA cache between generations | |
clear_cuda_cache() | |
time.sleep(0.5) | |
yield all_videos, "Batch complete", "" | |
else: | |
# For non-random mode, if batch_size > 1, we need to process multiple times | |
# with the same input image but different seeds | |
if int(batch_size) > 1: | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Extract LoRA weights and multipliers and input image | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None | |
# Process each batch item | |
for i in range(int(batch_size)): | |
if stop_event.is_set(): | |
yield all_videos, "Generation stopped by user", "" | |
return | |
# Calculate seed for this batch item | |
current_seed = seed | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
elif batch_size > 1: | |
current_seed = seed + i | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Generate a single video with the current seed | |
for videos, status, progress in wanx_generate_video( | |
prompt, | |
negative_prompt, | |
input_image, | |
width, | |
height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
current_seed, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
*lora_weights, | |
*lora_multipliers | |
): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
# Clear CUDA cache between generations | |
clear_cuda_cache() | |
time.sleep(0.5) | |
yield all_videos, "Batch complete", "" | |
else: | |
# Single image, single generation - use existing function | |
num_lora_weights = 4 | |
lora_weights = lora_params[:num_lora_weights] | |
lora_multipliers = lora_params[num_lora_weights:num_lora_weights*2] | |
input_image = lora_params[num_lora_weights*2] if len(lora_params) > num_lora_weights*2 else None | |
yield from wanx_generate_video( | |
prompt, | |
negative_prompt, | |
input_image, | |
width, | |
height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
seed, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
*lora_weights, | |
*lora_multipliers | |
) | |
def process_single_video( | |
prompt: str, | |
width: int, | |
height: int, | |
batch_size: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
dit_folder: str, | |
model: str, | |
vae: str, | |
te1: str, | |
te2: str, | |
save_path: str, | |
flow_shift: float, | |
cfg_scale: float, | |
output_type: str, | |
attn_mode: str, | |
block_swap: int, | |
exclude_single_blocks: bool, | |
use_split_attn: bool, | |
lora_folder: str, | |
lora1: str = "", | |
lora2: str = "", | |
lora3: str = "", | |
lora4: str = "", | |
lora1_multiplier: float = 1.0, | |
lora2_multiplier: float = 1.0, | |
lora3_multiplier: float = 1.0, | |
lora4_multiplier: float = 1.0, | |
video_path: Optional[str] = None, | |
image_path: Optional[str] = None, | |
strength: Optional[float] = None, | |
negative_prompt: Optional[str] = None, | |
embedded_cfg_scale: Optional[float] = None, | |
split_uncond: Optional[bool] = None, | |
guidance_scale: Optional[float] = None, | |
use_fp8: bool = True | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Generate a single video with the given parameters""" | |
global stop_event | |
if stop_event.is_set(): | |
yield [], "", "" | |
return | |
# Determine if this is a SkyReels model and what type | |
is_skyreels = "skyreels" in model.lower() | |
is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
is_skyreels_t2v = is_skyreels and "t2v" in model.lower() | |
if is_skyreels: | |
# Force certain parameters for SkyReels | |
if negative_prompt is None: | |
negative_prompt = "" | |
if embedded_cfg_scale is None: | |
embedded_cfg_scale = 1.0 # Force to 1.0 for SkyReels | |
if split_uncond is None: | |
split_uncond = True | |
if guidance_scale is None: | |
guidance_scale = cfg_scale # Use cfg_scale as guidance_scale if not provided | |
# Determine the input channels based on model type | |
if is_skyreels_i2v: | |
dit_in_channels = 32 # SkyReels I2V uses 32 channels | |
else: | |
dit_in_channels = 16 # SkyReels T2V uses 16 channels (same as regular models) | |
else: | |
dit_in_channels = 16 # Regular Hunyuan models use 16 channels | |
embedded_cfg_scale = cfg_scale | |
if os.path.isabs(model): | |
model_path = model | |
else: | |
model_path = os.path.normpath(os.path.join(dit_folder, model)) | |
env = os.environ.copy() | |
env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
env["PYTHONIOENCODING"] = "utf-8" | |
env["BATCH_RUN_ID"] = f"{time.time()}" | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
else: | |
batch_id = int(env.get("BATCH_RUN_ID", "0").split('.')[-1]) | |
if batch_size > 1: # Only modify seed for batch generation | |
current_seed = (seed + batch_id * 100003) % (2**32) | |
else: | |
current_seed = seed | |
clear_cuda_cache() | |
command = [ | |
sys.executable, | |
"hv_generate_video.py", | |
"--dit", model_path, | |
"--vae", vae, | |
"--text_encoder1", te1, | |
"--text_encoder2", te2, | |
"--prompt", prompt, | |
"--video_size", str(height), str(width), | |
"--video_length", str(video_length), | |
"--fps", str(fps), | |
"--infer_steps", str(infer_steps), | |
"--save_path", save_path, | |
"--seed", str(current_seed), | |
"--flow_shift", str(flow_shift), | |
"--embedded_cfg_scale", str(cfg_scale), | |
"--output_type", output_type, | |
"--attn_mode", attn_mode, | |
"--blocks_to_swap", str(block_swap), | |
"--fp8_llm", | |
"--vae_chunk_size", "32", | |
"--vae_spatial_tile_sample_min_size", "128" | |
] | |
if use_fp8: | |
command.append("--fp8") | |
    # Add SkyReels-specific options: input channels, guidance scale, negative prompt, split-uncond
if is_skyreels: | |
command.extend(["--dit_in_channels", str(dit_in_channels)]) | |
command.extend(["--guidance_scale", str(guidance_scale)]) | |
if negative_prompt: | |
command.extend(["--negative_prompt", negative_prompt]) | |
if split_uncond: | |
command.append("--split_uncond") | |
# Add LoRA weights and multipliers if provided | |
valid_loras = [] | |
for weight, mult in zip([lora1, lora2, lora3, lora4], | |
[lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
if valid_loras: | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
command.extend(["--lora_weight"] + weights) | |
command.extend(["--lora_multiplier"] + multipliers) | |
if exclude_single_blocks: | |
command.append("--exclude_single_blocks") | |
if use_split_attn: | |
command.append("--split_attn") | |
# Handle input paths | |
if video_path: | |
command.extend(["--video_path", video_path]) | |
if strength is not None: | |
command.extend(["--strength", str(strength)]) | |
elif image_path: | |
command.extend(["--image_path", image_path]) | |
# Only add strength parameter for non-SkyReels I2V models | |
# SkyReels I2V doesn't use strength parameter for image-to-video generation | |
if strength is not None and not is_skyreels_i2v: | |
command.extend(["--strength", str(strength)]) | |
print(f"{command}") | |
p = subprocess.Popen( | |
command, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
env=env, | |
text=True, | |
encoding='utf-8', | |
errors='replace', | |
bufsize=1 | |
) | |
videos = [] | |
while True: | |
if stop_event.is_set(): | |
p.terminate() | |
p.wait() | |
yield [], "", "Generation stopped by user." | |
return | |
line = p.stdout.readline() | |
if not line: | |
if p.poll() is not None: | |
break | |
continue | |
print(line, end='') | |
if '|' in line and '%' in line and '[' in line and ']' in line: | |
yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
p.stdout.close() | |
p.wait() | |
clear_cuda_cache() | |
time.sleep(0.5) | |
# Collect generated video | |
save_path_abs = os.path.abspath(save_path) | |
if os.path.exists(save_path_abs): | |
all_videos = sorted( | |
[f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
reverse=True | |
) | |
matching_videos = [v for v in all_videos if f"_{current_seed}" in v] | |
if matching_videos: | |
output_video_path = os.path.join(save_path_abs, matching_videos[0])  # keep the input video_path intact for the metadata below
# Collect parameters for metadata | |
parameters = { | |
"prompt": prompt, | |
"width": width, | |
"height": height, | |
"video_length": video_length, | |
"fps": fps, | |
"infer_steps": infer_steps, | |
"seed": current_seed, | |
"model": model, | |
"vae": vae, | |
"te1": te1, | |
"te2": te2, | |
"save_path": save_path, | |
"flow_shift": flow_shift, | |
"cfg_scale": cfg_scale, | |
"output_type": output_type, | |
"attn_mode": attn_mode, | |
"block_swap": block_swap, | |
"lora_weights": [lora1, lora2, lora3, lora4], | |
"lora_multipliers": [lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier], | |
"input_video": video_path if video_path else None, | |
"input_image": image_path if image_path else None, | |
"strength": strength, | |
"negative_prompt": negative_prompt if is_skyreels else None, | |
"embedded_cfg_scale": embedded_cfg_scale if is_skyreels else None | |
} | |
add_metadata_to_video(output_video_path, parameters)
videos.append((str(output_video_path), f"Seed: {current_seed}"))
yield videos, f"Completed (seed: {current_seed})", "" | |
def process_batch( | |
prompt: str, | |
width: int, | |
height: int, | |
batch_size: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
dit_folder: str, | |
model: str, | |
vae: str, | |
te1: str, | |
te2: str, | |
save_path: str, | |
flow_shift: float, | |
cfg_scale: float, | |
output_type: str, | |
attn_mode: str, | |
block_swap: int, | |
exclude_single_blocks: bool, | |
use_split_attn: bool, | |
lora_folder: str, | |
*args | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Process a batch of videos using Gradio's queue""" | |
global stop_event | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Extract additional arguments | |
num_lora_weights = 4 | |
lora_weights = args[:num_lora_weights] | |
lora_multipliers = args[num_lora_weights:num_lora_weights*2] | |
extra_args = args[num_lora_weights*2:] | |
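# Expected *args layout (must match the component list wired to this handler):
#   args[0:4] -> lora1..lora4 dropdown values
#   args[4:8] -> lora1..lora4 multipliers
#   extra_args[0] -> input video/image path, extra_args[1] -> denoise strength,
#   extra_args[2] -> negative prompt, extra_args[3] -> guidance scale (SkyReels only),
#   extra_args[4] -> split_uncond flag, extra_args[-1] -> use_fp8 (always last)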
# Determine if this is a SkyReels model and what type | |
is_skyreels = "skyreels" in model.lower() | |
is_skyreels_i2v = is_skyreels and "i2v" in model.lower() | |
is_skyreels_t2v = is_skyreels and "t2v" in model.lower() | |
# Handle input paths and additional parameters | |
input_path = extra_args[0] if extra_args else None | |
strength = float(extra_args[1]) if len(extra_args) > 1 and extra_args[1] is not None else None
# Get use_fp8 flag (it should be the last parameter) | |
use_fp8 = bool(extra_args[-1]) if extra_args and len(extra_args) >= 3 else True | |
# Get SkyReels specific parameters if applicable | |
if is_skyreels: | |
# Always set embedded_cfg_scale to 1.0 for SkyReels models | |
embedded_cfg_scale = 1.0 | |
negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else "" | |
# Use cfg_scale for guidance_scale parameter | |
guidance_scale = float(extra_args[3]) if len(extra_args) > 3 and extra_args[3] is not None else cfg_scale | |
split_uncond = True if len(extra_args) > 4 and extra_args[4] else False | |
else: | |
negative_prompt = str(extra_args[2]) if len(extra_args) > 2 and extra_args[2] is not None else None | |
guidance_scale = cfg_scale | |
embedded_cfg_scale = cfg_scale | |
split_uncond = bool(extra_args[4]) if len(extra_args) > 4 else None | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
break | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Handle different input types | |
video_path = None | |
image_path = None | |
if input_path: | |
# Check if it's an image file (common image extensions) | |
is_image = False | |
lower_path = input_path.lower() | |
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp') | |
is_image = any(lower_path.endswith(ext) for ext in image_extensions) | |
# Only use image_path for SkyReels I2V models and actual image files | |
if is_skyreels_i2v and is_image: | |
image_path = input_path | |
else: | |
video_path = input_path | |
# Prepare arguments for process_single_video | |
single_video_args = [ | |
prompt, width, height, batch_size, video_length, fps, infer_steps, | |
seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder | |
] | |
single_video_args.extend(lora_weights) | |
single_video_args.extend(lora_multipliers) | |
single_video_args.extend([video_path, image_path, strength, negative_prompt, embedded_cfg_scale, split_uncond, guidance_scale, use_fp8]) | |
for videos, status, progress in process_single_video(*single_video_args): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
yield all_videos, "Batch complete", "" | |
def update_wanx_image_dimensions(image): | |
"""Update dimensions from uploaded image""" | |
if image is None: | |
return "", gr.update(value=832), gr.update(value=480) | |
img = Image.open(image) | |
w, h = img.size | |
w = (w // 32) * 32 | |
h = (h // 32) * 32 | |
return f"{w}x{h}", w, h | |
def calculate_wanx_width(height, original_dims): | |
"""Calculate width based on height maintaining aspect ratio""" | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_width = math.floor((height * aspect_ratio) / 32) * 32 | |
return gr.update(value=new_width) | |
def calculate_wanx_height(width, original_dims): | |
"""Calculate height based on width maintaining aspect ratio""" | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_height = math.floor((width / aspect_ratio) / 32) * 32 | |
return gr.update(value=new_height) | |
def update_wanx_from_scale(scale, original_dims): | |
"""Update dimensions based on scale percentage""" | |
if not original_dims: | |
return gr.update(), gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
new_w = math.floor((orig_w * scale / 100) / 32) * 32 | |
new_h = math.floor((orig_h * scale / 100) / 32) * 32 | |
return gr.update(value=new_w), gr.update(value=new_h) | |
def recommend_wanx_flow_shift(width, height): | |
"""Get recommended flow shift value based on dimensions""" | |
recommended_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 | |
return gr.update(value=recommended_shift) | |
def handle_wanx_gallery_select(evt: gr.SelectData, gallery) -> tuple: | |
"""Track selected index and video path when gallery item is clicked""" | |
if gallery is None: | |
return None, None | |
if evt.index >= len(gallery): | |
return None, None | |
selected_item = gallery[evt.index] | |
video_path = None | |
# Extract the video path based on the item type | |
if isinstance(selected_item, tuple): | |
video_path = selected_item[0] | |
elif isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
else: | |
video_path = selected_item | |
return evt.index, video_path | |
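# The tuple/dict/str extraction above is repeated in several handlers below
# (send_wanx_to_v2v, send_wanx_t2v_to_v2v, concat_batch_videos). A minimal
# consolidated sketch of that logic (illustrative only, not wired in; the
# name is ours):
def _gallery_item_path(item):
    """Best-effort extraction of a file path from a Gradio gallery item,
    which may be a (path, caption) tuple/list, a dict, or a plain path."""
    if isinstance(item, (tuple, list)):
        item = item[0]
    if isinstance(item, dict):
        item = item.get("name", item.get("data"))
    return str(item) if item is not None else None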
def wanx_generate_video( | |
prompt, | |
negative_prompt, | |
input_image, | |
width, | |
height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
seed, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
lora1="None", | |
lora2="None", | |
lora3="None", | |
lora4="None", | |
lora1_multiplier=1.0, | |
lora2_multiplier=1.0, | |
lora3_multiplier=1.0, | |
lora4_multiplier=1.0 | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Generate video with WanX model (supports both i2v and t2v)""" | |
global stop_event | |
if stop_event.is_set(): | |
yield [], "", "" | |
return | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
else: | |
current_seed = seed | |
# Check if we need input image (required for i2v, not for t2v) | |
if "i2v" in task and not input_image: | |
yield [], "Error: No input image provided", "Please provide an input image for image-to-video generation" | |
return | |
# Prepare environment | |
env = os.environ.copy() | |
env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
env["PYTHONIOENCODING"] = "utf-8" | |
clear_cuda_cache() | |
command = [ | |
sys.executable, | |
"wan_generate_video.py", | |
"--task", task, | |
"--prompt", prompt, | |
"--video_size", str(height), str(width), | |
"--video_length", str(video_length), | |
"--fps", str(fps), | |
"--infer_steps", str(infer_steps), | |
"--save_path", save_path, | |
"--seed", str(current_seed), | |
"--flow_shift", str(flow_shift), | |
"--guidance_scale", str(guidance_scale), | |
"--output_type", output_type, | |
"--attn_mode", attn_mode, | |
"--blocks_to_swap", str(block_swap), | |
"--dit", dit_path, | |
"--vae", vae_path, | |
"--t5", t5_path, | |
"--sample_solver", sample_solver | |
] | |
# Add image path only for i2v task and if input image is provided | |
if "i2v" in task and input_image: | |
command.extend(["--image_path", input_image]) | |
command.extend(["--clip", clip_path]) # CLIP is only needed for i2v | |
if negative_prompt: | |
command.extend(["--negative_prompt", negative_prompt]) | |
if fp8: | |
command.append("--fp8") | |
if fp8_t5: | |
command.append("--fp8_t5") | |
if exclude_single_blocks: | |
command.append("--exclude_single_blocks") | |
# Add LoRA weights and multipliers if provided | |
valid_loras = [] | |
for weight, mult in zip([lora1, lora2, lora3, lora4], | |
[lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier]): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(lora_folder, weight), mult)) | |
if valid_loras: | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
command.extend(["--lora_weight"] + weights) | |
command.extend(["--lora_multiplier"] + multipliers) | |
print(f"Running: {' '.join(command)}") | |
p = subprocess.Popen( | |
command, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
env=env, | |
text=True, | |
encoding='utf-8', | |
errors='replace', | |
bufsize=1 | |
) | |
videos = [] | |
while True: | |
if stop_event.is_set(): | |
p.terminate() | |
p.wait() | |
yield [], "", "Generation stopped by user." | |
return | |
line = p.stdout.readline() | |
if not line: | |
if p.poll() is not None: | |
break | |
continue | |
print(line, end='') | |
if '|' in line and '%' in line and '[' in line and ']' in line: | |
yield videos.copy(), f"Processing (seed: {current_seed})", line.strip() | |
p.stdout.close() | |
p.wait() | |
clear_cuda_cache() | |
time.sleep(0.5) | |
# Collect generated video | |
save_path_abs = os.path.abspath(save_path) | |
if os.path.exists(save_path_abs): | |
all_videos = sorted( | |
[f for f in os.listdir(save_path_abs) if f.endswith('.mp4')], | |
key=lambda x: os.path.getmtime(os.path.join(save_path_abs, x)), | |
reverse=True | |
) | |
matching_videos = [v for v in all_videos if f"_{current_seed}" in v] | |
if matching_videos: | |
video_path = os.path.join(save_path_abs, matching_videos[0]) | |
# Collect parameters for metadata | |
parameters = { | |
"prompt": prompt, | |
"width": width, | |
"height": height, | |
"video_length": video_length, | |
"fps": fps, | |
"infer_steps": infer_steps, | |
"seed": current_seed, | |
"task": task, | |
"flow_shift": flow_shift, | |
"guidance_scale": guidance_scale, | |
"output_type": output_type, | |
"attn_mode": attn_mode, | |
"block_swap": block_swap, | |
"input_image": input_image if "i2v" in task else None | |
} | |
add_metadata_to_video(video_path, parameters) | |
videos.append((str(video_path), f"Seed: {current_seed}")) | |
yield videos, f"Completed (seed: {current_seed})", "" | |
def send_wanx_to_v2v( | |
gallery: list, | |
prompt: str, | |
selected_index: int, | |
width: int, | |
height: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
flow_shift: float, | |
guidance_scale: float, | |
negative_prompt: str | |
) -> Tuple: | |
"""Send the selected WanX video to Video2Video tab""" | |
if gallery is None or not gallery: | |
return (None, "", width, height, video_length, fps, infer_steps, seed, | |
flow_shift, guidance_scale, negative_prompt) | |
# If no selection made but we have videos, use the first one | |
if selected_index is None and len(gallery) > 0: | |
selected_index = 0 | |
if selected_index is None or selected_index >= len(gallery): | |
return (None, "", width, height, video_length, fps, infer_steps, seed, | |
flow_shift, guidance_scale, negative_prompt) | |
selected_item = gallery[selected_index] | |
# Handle different gallery item formats | |
if isinstance(selected_item, tuple): | |
video_path = selected_item[0] | |
elif isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
else: | |
video_path = selected_item | |
# Clean up path for Video component | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
# Make sure it's a string | |
video_path = str(video_path) | |
return (video_path, prompt, width, height, video_length, fps, infer_steps, seed, | |
flow_shift, guidance_scale, negative_prompt) | |
def wanx_generate_video_batch( | |
prompt, | |
negative_prompt, | |
width, | |
height, | |
video_length, | |
fps, | |
infer_steps, | |
flow_shift, | |
guidance_scale, | |
seed, | |
task, | |
dit_path, | |
vae_path, | |
t5_path, | |
clip_path, | |
save_path, | |
output_type, | |
sample_solver, | |
exclude_single_blocks, | |
attn_mode, | |
block_swap, | |
fp8, | |
fp8_t5, | |
lora_folder, | |
lora1="None", | |
lora2="None", | |
lora3="None", | |
lora4="None", | |
lora1_multiplier=1.0, | |
lora2_multiplier=1.0, | |
lora3_multiplier=1.0, | |
lora4_multiplier=1.0, | |
batch_size=1, | |
input_image=None # Make input_image optional and place it at the end | |
) -> Generator[Tuple[List[Tuple[str, str]], str, str], None, None]: | |
"""Generate videos with WanX with support for batches""" | |
global stop_event | |
stop_event.clear() | |
all_videos = [] | |
progress_text = "Starting generation..." | |
yield [], "Preparing...", progress_text | |
# Process each item in the batch | |
for i in range(batch_size): | |
if stop_event.is_set(): | |
yield all_videos, "Generation stopped by user", "" | |
return | |
# Calculate seed for this batch item | |
current_seed = seed | |
if seed == -1: | |
current_seed = random.randint(0, 2**32 - 1) | |
elif batch_size > 1: | |
current_seed = seed + i | |
batch_text = f"Generating video {i + 1} of {batch_size}" | |
yield all_videos.copy(), batch_text, progress_text | |
# Generate a single video using the existing function | |
for videos, status, progress in wanx_generate_video( | |
prompt, negative_prompt, input_image, width, height, | |
video_length, fps, infer_steps, flow_shift, guidance_scale, | |
current_seed, task, dit_path, vae_path, t5_path, clip_path, | |
save_path, output_type, sample_solver, exclude_single_blocks, | |
attn_mode, block_swap, fp8, fp8_t5, | |
lora_folder, | |
lora1, | |
lora2, | |
lora3, | |
lora4, | |
lora1_multiplier, | |
lora2_multiplier, | |
lora3_multiplier, | |
lora4_multiplier | |
): | |
if videos: | |
all_videos.extend(videos) | |
yield all_videos.copy(), f"Batch {i+1}/{batch_size}: {status}", progress | |
yield all_videos, "Batch complete", "" | |
def update_wanx_t2v_dimensions(size): | |
"""Update width and height based on selected size""" | |
width, height = map(int, size.split('*')) | |
return gr.update(value=width), gr.update(value=height) | |
def handle_wanx_t2v_gallery_select(evt: gr.SelectData) -> int: | |
"""Track selected index when gallery item is clicked""" | |
return evt.index | |
def send_wanx_t2v_to_v2v( | |
gallery, prompt, selected_index, width, height, video_length, | |
fps, infer_steps, seed, flow_shift, guidance_scale, negative_prompt | |
) -> Tuple: | |
"""Send the selected WanX T2V video to Video2Video tab""" | |
if not gallery or selected_index is None or selected_index >= len(gallery): | |
return (None, "", width, height, video_length, fps, infer_steps, seed, | |
flow_shift, guidance_scale, negative_prompt) | |
selected_item = gallery[selected_index] | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
flow_shift, guidance_scale, negative_prompt) | |
def prepare_for_batch_extension(input_img, base_video, batch_size): | |
"""Prepare inputs for batch video extension""" | |
if input_img is None: | |
return None, None, batch_size, "No input image found", "" | |
if base_video is None: | |
return input_img, None, batch_size, "No base video selected for extension", "" | |
return input_img, base_video, batch_size, "Preparing batch extension...", f"Will create {batch_size} variations of extended video" | |
def concat_batch_videos(base_video_path, generated_videos, save_path, original_video_path=None): | |
"""Concatenate multiple generated videos with the base video""" | |
if not base_video_path: | |
return [], "No base video provided" | |
if not generated_videos or len(generated_videos) == 0: | |
return [], "No new videos generated" | |
# Create output directory if it doesn't exist | |
os.makedirs(save_path, exist_ok=True) | |
# Track all extended videos | |
extended_videos = [] | |
# For each generated video, create an extended version | |
for i, video_item in enumerate(generated_videos): | |
try: | |
# Extract video path from gallery item | |
if isinstance(video_item, tuple): | |
new_video_path = video_item[0] | |
seed_info = video_item[1] if len(video_item) > 1 else "" | |
elif isinstance(video_item, dict): | |
new_video_path = video_item.get("name", video_item.get("data", None)) | |
seed_info = "" | |
else: | |
new_video_path = video_item | |
seed_info = "" | |
if not new_video_path or not os.path.exists(new_video_path): | |
print(f"Skipping missing video: {new_video_path}") | |
continue | |
# Create unique output filename | |
timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") | |
# Extract seed from seed_info if available | |
seed_match = re.search(r"Seed: (\d+)", seed_info) | |
seed_part = f"_seed{seed_match.group(1)}" if seed_match else f"_{i}" | |
output_filename = f"extended_{timestamp}{seed_part}_{Path(base_video_path).stem}.mp4" | |
output_path = os.path.join(save_path, output_filename) | |
# Create a temporary file list for ffmpeg | |
list_file = os.path.join(save_path, f"temp_list_{i}.txt") | |
with open(list_file, "w") as f: | |
f.write(f"file '{os.path.abspath(base_video_path)}'\n") | |
f.write(f"file '{os.path.abspath(new_video_path)}'\n") | |
# Run ffmpeg concatenation | |
command = [ | |
"ffmpeg", | |
"-f", "concat", | |
"-safe", "0", | |
"-i", list_file, | |
"-c", "copy", | |
output_path | |
] | |
subprocess.run(command, check=True, capture_output=True) | |
# Clean up temporary file | |
if os.path.exists(list_file): | |
os.remove(list_file) | |
# Add to extended videos list if successful | |
if os.path.exists(output_path): | |
seed_display = f"Extended {seed_info}" if seed_info else f"Extended video #{i+1}" | |
extended_videos.append((output_path, seed_display)) | |
except Exception as e: | |
print(f"Error processing video {i}: {str(e)}") | |
if not extended_videos: | |
return [], "Failed to create any extended videos" | |
return extended_videos, f"Successfully created {len(extended_videos)} extended videos" | |
def handle_extend_generation(base_video_path: str, new_videos: list, save_path: str, current_gallery: list) -> tuple: | |
"""Combine generated video with base video and update gallery""" | |
if not base_video_path: | |
return current_gallery, "Extend failed: No base video provided" | |
if not new_videos: | |
return current_gallery, "Extend failed: No new video generated" | |
# Ensure save path exists | |
os.makedirs(save_path, exist_ok=True) | |
# Get the first video from new_videos (gallery item) | |
new_video_path = new_videos[0][0] if isinstance(new_videos[0], tuple) else new_videos[0] | |
# Create a unique output filename | |
timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") | |
output_filename = f"extended_{timestamp}_{Path(base_video_path).stem}.mp4" | |
output_path = str(Path(save_path) / output_filename) | |
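# Unlike concat_batch_videos above (stream copy via the concat demuxer), this
# path joins the clips with ffmpeg-python's concat filter, which decodes and
# re-encodes the combined video.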
try: | |
# Concatenate the videos using ffmpeg | |
( | |
ffmpeg | |
.input(base_video_path) | |
.concat( | |
ffmpeg.input(new_video_path) | |
) | |
.output(output_path) | |
.run(overwrite_output=True, quiet=True) | |
) | |
# Create a new gallery entry with the combined video | |
updated_gallery = [(output_path, f"Extended video: {Path(output_path).stem}")] | |
return updated_gallery, f"Successfully extended video to {Path(output_path).name}" | |
except Exception as e: | |
print(f"Error extending video: {str(e)}") | |
return current_gallery, f"Failed to extend video: {str(e)}" | |
# UI setup | |
with gr.Blocks( | |
theme=themes.Default( | |
primary_hue=colors.Color( | |
name="custom", | |
c50="#E6F0FF", | |
c100="#CCE0FF", | |
c200="#99C1FF", | |
c300="#66A3FF", | |
c400="#3384FF", | |
c500="#0060df", # This is your main color | |
c600="#0052C2", | |
c700="#003D91", | |
c800="#002961", | |
c900="#001430", | |
c950="#000A18" | |
) | |
), | |
css=""" | |
.gallery-item:first-child { border: 2px solid #4CAF50 !important; } | |
.gallery-item:first-child:hover { border-color: #45a049 !important; } | |
.green-btn { | |
background: linear-gradient(to bottom right, #2ecc71, #27ae60) !important; | |
color: white !important; | |
border: none !important; | |
} | |
.green-btn:hover { | |
background: linear-gradient(to bottom right, #27ae60, #219651) !important; | |
} | |
.refresh-btn { | |
max-width: 40px !important; | |
min-width: 40px !important; | |
height: 40px !important; | |
border-radius: 50% !important; | |
padding: 0 !important; | |
display: flex !important; | |
align-items: center !important; | |
justify-content: center !important; | |
} | |
""", | |
) as demo: | |
# Add state for tracking selected video indices in both tabs | |
selected_index = gr.State(value=None) # For Text to Video | |
v2v_selected_index = gr.State(value=None) # For Video to Video | |
params_state = gr.State()
i2v_selected_index = gr.State(value=None) | |
skyreels_selected_index = gr.State(value=None) | |
wanx_i2v_selected_index = gr.State(value=None) | |
extended_videos = gr.State(value=[]) | |
wanx_base_video = gr.State(value=None) | |
wanx_sharpest_frame_number = gr.State(value=None) | |
wanx_sharpest_frame_path = gr.State(value=None) | |
wanx_trimmed_video_path = gr.State(value=None) | |
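# Mirror generation progress in the browser tab title: the script below watches
# the progress textboxes and parses the percentage/ETA out of the tqdm text.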
demo.load(None, None, None, js=r"""
() => { | |
document.title = 'H1111'; | |
function updateTitle(text) { | |
if (text && text.trim()) { | |
const progressMatch = text.match(/(\d+)%.*\[.*<(\d+:\d+),/); | |
if (progressMatch) { | |
const percentage = progressMatch[1]; | |
const timeRemaining = progressMatch[2]; | |
document.title = `[${percentage}% ETA: ${timeRemaining}] - H1111`; | |
} | |
} | |
} | |
setTimeout(() => { | |
const progressElements = document.querySelectorAll('textarea.scroll-hide'); | |
progressElements.forEach(element => { | |
if (element) { | |
new MutationObserver(() => { | |
updateTitle(element.value); | |
}).observe(element, { | |
attributes: true, | |
childList: true, | |
characterData: true | |
}); | |
} | |
}); | |
}, 1000); | |
} | |
""") | |
with gr.Tabs() as tabs: | |
# Text to Video Tab | |
with gr.Tab(id=1, label="Hunyuan-t2v"): | |
with gr.Row(): | |
with gr.Column(scale=4): | |
prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
with gr.Column(scale=1): | |
token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
t2v_width = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Width") | |
t2v_height = gr.Slider(minimum=64, maximum=1536, step=16, value=544, label="Video Height") | |
video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25, elem_id="my_special_slider") | |
fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24, elem_id="my_special_slider") | |
infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30, elem_id="my_special_slider") | |
flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0, elem_id="my_special_slider") | |
cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg Scale", value=7.0, elem_id="my_special_slider") | |
with gr.Column(): | |
with gr.Row(): | |
video_output = gr.Gallery( | |
label="Generated Videos (Click to select)", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="auto", | |
show_label=True, | |
elem_id="gallery", | |
allow_preview=True, | |
preview=True | |
) | |
with gr.Row():
send_t2v_to_v2v_btn = gr.Button("Send Selected to Video2Video")
with gr.Row(): | |
refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
lora_weights = [] | |
lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
model = gr.Dropdown( | |
label="DiT Model", | |
choices=get_dit_models("hunyuan"), | |
value="mp_rank_00_model_states.pt", | |
allow_custom_value=True, | |
interactive=True | |
) | |
vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
#Image to Video Tab | |
with gr.Tab(label="Hunyuan-i2v") as i2v_tab: | |
with gr.Row(): | |
with gr.Column(scale=4): | |
i2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
with gr.Column(scale=1): | |
i2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
i2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
i2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
i2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
i2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
i2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
i2v_input = gr.Image(label="Input Image", type="filepath") | |
i2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
# Scale slider as percentage | |
scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
# Width and height inputs | |
with gr.Row(): | |
width = gr.Number(label="New Width", value=544, step=16) | |
calc_height_btn = gr.Button("→") | |
calc_width_btn = gr.Button("←") | |
height = gr.Number(label="New Height", value=544, step=16) | |
i2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
i2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
i2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
i2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
i2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
with gr.Column(): | |
i2v_output = gr.Gallery( | |
label="Generated Videos (Click to select)", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="auto", | |
show_label=True, | |
elem_id="gallery", | |
allow_preview=True, | |
preview=True | |
) | |
i2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
# Add LoRA section for Image2Video | |
i2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
i2v_lora_weights = [] | |
i2v_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
i2v_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
i2v_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
i2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
i2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
i2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
i2v_model = gr.Dropdown( | |
label="DiT Model", | |
choices=get_dit_models("hunyuan"), | |
value="mp_rank_00_model_states.pt", | |
allow_custom_value=True, | |
interactive=True | |
) | |
i2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
i2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
i2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
i2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
i2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
i2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
i2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
i2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
i2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
i2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
# Video to Video Tab | |
with gr.Tab(id=2, label="Hunyuan-v2v") as v2v_tab: | |
with gr.Row(): | |
with gr.Column(scale=4): | |
v2v_prompt = gr.Textbox(scale=3, label="Enter your prompt", value="POV video of a cat chasing a frob.", lines=5) | |
v2v_negative_prompt = gr.Textbox( | |
scale=3, | |
label="Negative Prompt (for SkyReels models)", | |
value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
lines=3 | |
) | |
with gr.Column(scale=1): | |
v2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
v2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
v2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
v2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
v2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
v2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
v2v_input = gr.Video(label="Input Video", format="mp4") | |
v2v_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
v2v_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
v2v_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
# Width and Height Inputs | |
with gr.Row(): | |
v2v_width = gr.Number(label="New Width", value=544, step=16) | |
v2v_calc_height_btn = gr.Button("→") | |
v2v_calc_width_btn = gr.Button("←") | |
v2v_height = gr.Number(label="New Height", value=544, step=16) | |
v2v_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
v2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
v2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
v2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
v2v_cfg_scale = gr.Slider(minimum=0.0, maximum=14.0, step=0.1, label="cfg scale", value=7.0) | |
with gr.Column(): | |
v2v_output = gr.Gallery( | |
label="Generated Videos", | |
columns=[1], | |
rows=[1], | |
object_fit="contain", | |
height="auto" | |
) | |
v2v_send_to_input_btn = gr.Button("Send Selected to Input") # New button | |
v2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
v2v_lora_weights = [] | |
v2v_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
v2v_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
v2v_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
v2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
v2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
v2v_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
v2v_model = gr.Dropdown( | |
label="DiT Model", | |
choices=get_dit_models("hunyuan"), | |
value="mp_rank_00_model_states.pt", | |
allow_custom_value=True, | |
interactive=True | |
) | |
v2v_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
v2v_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
v2v_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
v2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
v2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
v2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
v2v_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
v2v_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
v2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
v2v_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
v2v_split_uncond = gr.Checkbox(label="Split Unconditional (for SkyReels)", value=True) | |
### SKYREELS | |
with gr.Tab(label="SkyReels-i2v") as skyreels_tab: | |
with gr.Row(): | |
with gr.Column(scale=4): | |
skyreels_prompt = gr.Textbox( | |
scale=3, | |
label="Enter your prompt", | |
value="A person walking on a beach at sunset", | |
lines=5 | |
) | |
skyreels_negative_prompt = gr.Textbox( | |
scale=3, | |
label="Negative Prompt", | |
value="Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion", | |
lines=3 | |
) | |
with gr.Column(scale=1): | |
skyreels_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
skyreels_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
skyreels_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
skyreels_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
skyreels_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
skyreels_stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
skyreels_input = gr.Image(label="Input Image (optional)", type="filepath") | |
with gr.Row(): | |
skyreels_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) | |
skyreels_input_folder = gr.Textbox( | |
label="Image Folder Path", | |
placeholder="Path to folder containing images", | |
visible=False | |
) | |
skyreels_folder_status = gr.Textbox( | |
label="Folder Status", | |
placeholder="Status will appear here", | |
interactive=False, | |
visible=False | |
) | |
skyreels_validate_folder_btn = gr.Button("Validate Folder", visible=False) | |
skyreels_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.75, label="Denoise Strength") | |
# Scale slider as percentage | |
skyreels_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
skyreels_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
# Width and height inputs | |
with gr.Row(): | |
skyreels_width = gr.Number(label="New Width", value=544, step=16) | |
skyreels_calc_height_btn = gr.Button("→") | |
skyreels_calc_width_btn = gr.Button("←") | |
skyreels_height = gr.Number(label="New Height", value=544, step=16) | |
skyreels_video_length = gr.Slider(minimum=1, maximum=201, step=1, label="Video Length in Frames", value=25) | |
skyreels_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=24) | |
skyreels_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
skyreels_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=11.0) | |
skyreels_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=6.0) | |
skyreels_embedded_cfg_scale = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, label="Embedded CFG Scale", value=1.0) | |
with gr.Column(): | |
skyreels_output = gr.Gallery( | |
label="Generated Videos (Click to select)", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="auto", | |
show_label=True, | |
elem_id="gallery", | |
allow_preview=True, | |
preview=True | |
) | |
skyreels_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
# Add LoRA section for SKYREELS | |
skyreels_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
skyreels_lora_weights = [] | |
skyreels_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
skyreels_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
skyreels_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
skyreels_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
skyreels_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
skyreels_dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
skyreels_model = gr.Dropdown( | |
label="DiT Model", | |
choices=get_dit_models("skyreels"), | |
value="skyreels_hunyuan_i2v_bf16.safetensors", | |
allow_custom_value=True, | |
interactive=True | |
) | |
skyreels_vae = gr.Textbox(label="vae", value="hunyuan/pytorch_model.pt") | |
skyreels_te1 = gr.Textbox(label="te1", value="hunyuan/llava_llama3_fp16.safetensors") | |
skyreels_te2 = gr.Textbox(label="te2", value="hunyuan/clip_l.safetensors") | |
skyreels_save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
skyreels_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
skyreels_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
skyreels_use_split_attn = gr.Checkbox(label="Use Split Attention", value=False) | |
skyreels_use_fp8 = gr.Checkbox(label="Use FP8 (faster but lower precision)", value=True) | |
skyreels_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
skyreels_block_swap = gr.Slider(minimum=0, maximum=36, step=1, label="Block Swap to Save Vram", value=0) | |
skyreels_split_uncond = gr.Checkbox(label="Split Unconditional", value=True) | |
# WanX Image to Video Tab | |
with gr.Tab(id=4, label="WanX-i2v") as wanx_i2v_tab: | |
with gr.Row(): | |
with gr.Column(scale=4): | |
wanx_prompt = gr.Textbox( | |
scale=3, | |
label="Enter your prompt", | |
value="A person walking on a beach at sunset", | |
lines=5 | |
) | |
wanx_negative_prompt = gr.Textbox( | |
scale=3, | |
label="Negative Prompt", | |
value="", | |
lines=3, | |
info="Leave empty to use default negative prompt" | |
) | |
with gr.Column(scale=1): | |
wanx_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
wanx_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
wanx_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
wanx_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
wanx_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
wanx_stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
wanx_input = gr.Image(label="Input Image", type="filepath") | |
with gr.Row(): | |
wanx_use_random_folder = gr.Checkbox(label="Use Random Images from Folder", value=False) | |
wanx_input_folder = gr.Textbox( | |
label="Image Folder Path", | |
placeholder="Path to folder containing images", | |
visible=False | |
) | |
wanx_folder_status = gr.Textbox( | |
label="Folder Status", | |
placeholder="Status will appear here", | |
interactive=False, | |
visible=False | |
) | |
wanx_validate_folder_btn = gr.Button("Validate Folder", visible=False) | |
wanx_scale_slider = gr.Slider(minimum=1, maximum=200, value=100, step=1, label="Scale %") | |
wanx_original_dims = gr.Textbox(label="Original Dimensions", interactive=False, visible=True) | |
# Width and height display | |
with gr.Row(): | |
wanx_width = gr.Number(label="Width", value=832, interactive=True) | |
wanx_calc_height_btn = gr.Button("→") | |
wanx_calc_width_btn = gr.Button("←") | |
wanx_height = gr.Number(label="Height", value=480, interactive=True) | |
wanx_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
wanx_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
wanx_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
wanx_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
wanx_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=3.0, | |
info="Recommended: 3.0 for 480p, 5.0 for others") | |
wanx_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
with gr.Column(): | |
wanx_output = gr.Gallery( | |
label="Generated Videos (Click to select)", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="auto", | |
show_label=True, | |
elem_id="gallery", | |
allow_preview=True, | |
preview=True | |
) | |
wanx_send_to_v2v_btn = gr.Button("Send Selected to Hunyuan-v2v") | |
wanx_send_last_frame_btn = gr.Button("Send Last Frame to Input") | |
wanx_extend_btn = gr.Button("Extend Video") | |
wanx_frames_to_check = gr.Slider(minimum=1, maximum=100, step=1, value=30, | |
label="Frames to Check from End", | |
info="Number of frames from the end to check for sharpness") | |
wanx_send_sharpest_frame_btn = gr.Button("Extract Sharpest Frame") | |
wanx_trim_and_extend_btn = gr.Button("Trim Video & Prepare for Extension") | |
wanx_sharpest_frame_status = gr.Textbox(label="Status", interactive=False) | |
# Add a new button for directly extending with the trimmed video | |
wanx_extend_with_trimmed_btn = gr.Button("Extend with Trimmed Video") | |
# Add LoRA section for WanX-i2v similar to other tabs | |
wanx_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
wanx_lora_weights = [] | |
wanx_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
wanx_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
wanx_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
wanx_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
wanx_task = gr.Dropdown( | |
label="Task", | |
choices=["i2v-14B"], | |
value="i2v-14B", | |
info="Currently only i2v-14B is supported" | |
) | |
wanx_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_i2v_480p_14B_bf16.safetensors") | |
wanx_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
wanx_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
wanx_clip_path = gr.Textbox(label="CLIP Path", value="wan/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth") | |
wanx_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
wanx_save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
wanx_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
wanx_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
wanx_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
wanx_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0) | |
wanx_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
wanx_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
# WanX Text to Video Tab
with gr.Tab(id=5, label="WanX-t2v") as wanx_t2v_tab: | |
with gr.Row(): | |
with gr.Column(scale=4): | |
wanx_t2v_prompt = gr.Textbox( | |
scale=3, | |
label="Enter your prompt", | |
value="A person walking on a beach at sunset", | |
lines=5 | |
) | |
wanx_t2v_negative_prompt = gr.Textbox( | |
scale=3, | |
label="Negative Prompt", | |
value="", | |
lines=3, | |
info="Leave empty to use default negative prompt" | |
) | |
with gr.Column(scale=1): | |
wanx_t2v_token_counter = gr.Number(label="Prompt Token Count", value=0, interactive=False) | |
wanx_t2v_batch_size = gr.Number(label="Batch Count", value=1, minimum=1, step=1) | |
with gr.Column(scale=2): | |
wanx_t2v_batch_progress = gr.Textbox(label="", visible=True, elem_id="batch_progress") | |
wanx_t2v_progress_text = gr.Textbox(label="", visible=True, elem_id="progress_text") | |
with gr.Row(): | |
wanx_t2v_generate_btn = gr.Button("Generate Video", elem_classes="green-btn") | |
wanx_t2v_stop_btn = gr.Button("Stop Generation", variant="stop") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
wanx_t2v_width = gr.Number(label="Width", value=832, interactive=True, info="Should be divisible by 32") | |
wanx_t2v_height = gr.Number(label="Height", value=480, interactive=True, info="Should be divisible by 32") | |
wanx_t2v_recommend_flow_btn = gr.Button("Recommend Flow Shift", size="sm") | |
wanx_t2v_video_length = gr.Slider(minimum=1, maximum=201, step=4, label="Video Length in Frames", value=81) | |
wanx_t2v_fps = gr.Slider(minimum=1, maximum=60, step=1, label="Frames Per Second", value=16) | |
wanx_t2v_infer_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=20) | |
wanx_t2v_flow_shift = gr.Slider(minimum=0.0, maximum=28.0, step=0.5, label="Flow Shift", value=5.0, | |
info="Recommended: 3.0 for I2V with 480p, 5.0 for others") | |
wanx_t2v_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) | |
with gr.Column(): | |
wanx_t2v_output = gr.Gallery( | |
label="Generated Videos (Click to select)", | |
columns=[2], | |
rows=[2], | |
object_fit="contain", | |
height="auto", | |
show_label=True, | |
elem_id="gallery", | |
allow_preview=True, | |
preview=True | |
) | |
wanx_t2v_send_to_v2v_btn = gr.Button("Send Selected to Video2Video") | |
# Add LoRA section for WanX-t2v | |
wanx_t2v_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
wanx_t2v_lora_weights = [] | |
wanx_t2v_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
wanx_t2v_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
wanx_t2v_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
wanx_t2v_seed = gr.Number(label="Seed (use -1 for random)", value=-1) | |
wanx_t2v_task = gr.Dropdown( | |
label="Task", | |
choices=["t2v-1.3B", "t2v-14B", "t2i-14B"], | |
value="t2v-14B", | |
info="Select model size: t2v-1.3B is faster, t2v-14B has higher quality" | |
) | |
wanx_t2v_dit_path = gr.Textbox(label="DiT Model Path", value="wan/wan2.1_t2v_14B_bf16.safetensors") | |
wanx_t2v_vae_path = gr.Textbox(label="VAE Path", value="wan/Wan2.1_VAE.pth") | |
wanx_t2v_t5_path = gr.Textbox(label="T5 Path", value="wan/models_t5_umt5-xxl-enc-bf16.pth") | |
wanx_t2v_clip_path = gr.Textbox(label="CLIP Path", visible=False, value="") | |
wanx_t2v_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
wanx_t2v_save_path = gr.Textbox(label="Save Path", value="outputs") | |
with gr.Row(): | |
wanx_t2v_output_type = gr.Radio(choices=["video", "images", "latent", "both"], label="Output Type", value="video") | |
wanx_t2v_sample_solver = gr.Radio(choices=["unipc", "dpm++", "vanilla"], label="Sample Solver", value="unipc") | |
wanx_t2v_exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
wanx_t2v_attn_mode = gr.Radio(choices=["sdpa", "flash", "sageattn", "xformers", "torch"], label="Attention Mode", value="sdpa") | |
wanx_t2v_block_swap = gr.Slider(minimum=0, maximum=39, step=1, label="Block Swap to Save VRAM", value=0, | |
info="Max 39 for 14B model, 29 for 1.3B model") | |
wanx_t2v_fp8 = gr.Checkbox(label="Use FP8", value=True) | |
wanx_t2v_fp8_t5 = gr.Checkbox(label="Use FP8 for T5", value=False) | |
#Video Info Tab | |
with gr.Tab("Video Info") as video_info_tab: | |
with gr.Row(): | |
video_input = gr.Video(label="Upload Video", interactive=True) | |
metadata_output = gr.JSON(label="Generation Parameters") | |
with gr.Row(): | |
send_to_t2v_btn = gr.Button("Send to Text2Video", variant="primary") | |
send_to_v2v_btn = gr.Button("Send to Video2Video", variant="primary") | |
send_to_wanx_i2v_btn = gr.Button("Send to WanX-i2v", variant="primary") | |
send_to_wanx_t2v_btn = gr.Button("Send to WanX-t2v", variant="primary") | |
with gr.Row(): | |
status = gr.Textbox(label="Status", interactive=False) | |
# LoRA conversion and model merging tabs
with gr.Tab("Convert LoRA") as convert_lora_tab: | |
def suggest_output_name(file_obj) -> str: | |
"""Generate suggested output name from input file""" | |
if not file_obj: | |
return "" | |
# Get input filename without extension and add MUSUBI | |
base_name = os.path.splitext(os.path.basename(file_obj.name))[0] | |
return f"{base_name}_MUSUBI" | |
def convert_lora(input_file, output_name: str, target_format: str) -> str: | |
"""Convert LoRA file to specified format""" | |
try: | |
if not input_file: | |
return "Error: No input file selected" | |
# Ensure output directory exists | |
os.makedirs("lora", exist_ok=True) | |
# Construct output path | |
output_path = os.path.join("lora", f"{output_name}.safetensors") | |
# Build command | |
cmd = [ | |
sys.executable, | |
"convert_lora.py", | |
"--input", input_file.name, | |
"--output", output_path, | |
"--target", target_format | |
] | |
print(f"Converting {input_file.name} to {output_path}") | |
# Execute conversion | |
result = subprocess.run( | |
cmd, | |
capture_output=True, | |
text=True, | |
check=True | |
) | |
if os.path.exists(output_path): | |
return f"Successfully converted LoRA to {output_path}" | |
else: | |
return "Error: Output file not created" | |
except subprocess.CalledProcessError as e: | |
return f"Error during conversion: {e.stderr}" | |
except Exception as e: | |
return f"Error: {str(e)}" | |
with gr.Row(): | |
input_file = gr.File(label="Input LoRA File", file_types=[".safetensors"]) | |
output_name = gr.Textbox(label="Output Name", placeholder="Output filename (without extension)") | |
format_radio = gr.Radio( | |
choices=["default", "other"], | |
value="default", | |
label="Target Format", | |
info="Choose 'default' for H1111/MUSUBI format or 'other' for diffusion pipe format" | |
) | |
with gr.Row(): | |
convert_btn = gr.Button("Convert LoRA", variant="primary") | |
status_output = gr.Textbox(label="Status", interactive=False) | |
# Automatically update output name when file is selected | |
input_file.change( | |
fn=suggest_output_name, | |
inputs=[input_file], | |
outputs=[output_name] | |
) | |
# Handle conversion | |
convert_btn.click( | |
fn=convert_lora, | |
inputs=[input_file, output_name, format_radio], | |
outputs=status_output | |
) | |
with gr.Tab("Model Merging") as model_merge_tab: | |
with gr.Row(): | |
with gr.Column(): | |
# Model selection | |
dit_model = gr.Dropdown( | |
label="Base DiT Model", | |
choices=["mp_rank_00_model_states.pt"], | |
value="mp_rank_00_model_states.pt", | |
allow_custom_value=True, | |
interactive=True | |
) | |
merge_refresh_btn = gr.Button("🔄", elem_classes="refresh-btn") | |
with gr.Row(): | |
with gr.Column(): | |
# Output model name | |
output_model = gr.Textbox(label="Output Model Name", value="merged_model.safetensors") | |
exclude_single_blocks = gr.Checkbox(label="Exclude Single Blocks", value=False) | |
merge_btn = gr.Button("Merge Models", variant="primary") | |
merge_status = gr.Textbox(label="Status", interactive=False) | |
with gr.Row(): | |
# LoRA selection section (similar to Text2Video) | |
merge_lora_weights = [] | |
merge_lora_multipliers = [] | |
for i in range(4): | |
with gr.Column(): | |
merge_lora_weights.append(gr.Dropdown( | |
label=f"LoRA {i+1}", | |
choices=get_lora_options(), | |
value="None", | |
allow_custom_value=True, | |
interactive=True | |
)) | |
merge_lora_multipliers.append(gr.Slider( | |
label=f"Multiplier", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.05, | |
value=1.0 | |
)) | |
with gr.Row(): | |
merge_lora_folder = gr.Textbox(label="LoRA Folder", value="lora") | |
dit_folder = gr.Textbox(label="DiT Model Folder", value="hunyuan") | |
#Video Extension | |
wanx_send_last_frame_btn.click( | |
fn=send_last_frame_handler, | |
inputs=[wanx_output, wanx_i2v_selected_index], | |
outputs=[wanx_input, wanx_base_video] | |
) | |
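# Extension chain: stage the batch, generate continuation segments with wanx_batch_handler, then concatenate them onto the base video | |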
wanx_extend_btn.click( | |
fn=prepare_for_batch_extension, | |
inputs=[wanx_input, wanx_base_video, wanx_batch_size], | |
outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] | |
).then( | |
fn=wanx_batch_handler, | |
inputs=[ | |
gr.Checkbox(value=False), # Not using random folder | |
wanx_prompt, wanx_negative_prompt, | |
wanx_width, wanx_height, wanx_video_length, | |
wanx_fps, wanx_infer_steps, wanx_flow_shift, | |
wanx_guidance_scale, wanx_seed, wanx_batch_size, | |
wanx_input_folder, # Not used but needed for function signature | |
wanx_task, | |
wanx_dit_path, wanx_vae_path, wanx_t5_path, | |
wanx_clip_path, wanx_save_path, wanx_output_type, | |
wanx_sample_solver, wanx_exclude_single_blocks, | |
wanx_attn_mode, wanx_block_swap, wanx_fp8, | |
wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, | |
*wanx_lora_multipliers, wanx_input # Include input image | |
], | |
outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] | |
).then( | |
fn=concat_batch_videos, | |
inputs=[wanx_base_video, wanx_output, wanx_save_path], | |
outputs=[wanx_output, wanx_progress_text] | |
) | |
# Extract and send sharpest frame to input | |
wanx_send_sharpest_frame_btn.click( | |
fn=send_sharpest_frame_handler, | |
inputs=[wanx_output, wanx_i2v_selected_index, wanx_frames_to_check], | |
outputs=[wanx_input, wanx_base_video, wanx_sharpest_frame_number, wanx_sharpest_frame_status] | |
) | |
# Trim video to sharpest frame and prepare for extension | |
wanx_trim_and_extend_btn.click( | |
fn=trim_and_prepare_for_extension, | |
inputs=[wanx_base_video, wanx_sharpest_frame_number, wanx_save_path], | |
outputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status] | |
).then( | |
fn=lambda path, status: (path, status if "Failed" in status else "Video trimmed successfully and ready for extension"), | |
inputs=[wanx_trimmed_video_path, wanx_sharpest_frame_status], | |
outputs=[wanx_base_video, wanx_sharpest_frame_status] | |
) | |
# Event handler for extending with the trimmed video | |
wanx_extend_with_trimmed_btn.click( | |
fn=prepare_for_batch_extension, | |
inputs=[wanx_input, wanx_trimmed_video_path, wanx_batch_size], | |
outputs=[wanx_input, wanx_base_video, wanx_batch_size, wanx_batch_progress, wanx_progress_text] | |
).then( | |
fn=wanx_batch_handler, | |
inputs=[ | |
gr.Checkbox(value=False), # Not using random folder | |
wanx_prompt, wanx_negative_prompt, | |
wanx_width, wanx_height, wanx_video_length, | |
wanx_fps, wanx_infer_steps, wanx_flow_shift, | |
wanx_guidance_scale, wanx_seed, wanx_batch_size, | |
wanx_input_folder, # Not used but needed for function signature | |
wanx_task, | |
wanx_dit_path, wanx_vae_path, wanx_t5_path, | |
wanx_clip_path, wanx_save_path, wanx_output_type, | |
wanx_sample_solver, wanx_exclude_single_blocks, | |
wanx_attn_mode, wanx_block_swap, wanx_fp8, | |
wanx_fp8_t5, wanx_lora_folder, *wanx_lora_weights, | |
*wanx_lora_multipliers, wanx_input # Include input image | |
], | |
outputs=[wanx_output, wanx_batch_progress, wanx_progress_text] | |
).then( | |
fn=concat_batch_videos, | |
inputs=[wanx_trimmed_video_path, wanx_output, wanx_save_path], | |
outputs=[wanx_output, wanx_progress_text] | |
) | |
#Video Info | |
def handle_send_to_wanx_tab(metadata, target_tab): | |
"""Common handler for sending video parameters to WanX tabs""" | |
if not metadata: | |
return "No parameters to send", {} | |
# Tab names for clearer messages | |
tab_names = { | |
'wanx_i2v': 'WanX-i2v', | |
'wanx_t2v': 'WanX-t2v' | |
} | |
# Just pass through all parameters - we'll use them in the .then() function | |
return f"Parameters ready for {tab_names.get(target_tab, target_tab)}", metadata | |
def change_to_wanx_i2v_tab(): | |
return gr.Tabs(selected=4) # WanX-i2v tab index | |
def change_to_wanx_t2v_tab(): | |
return gr.Tabs(selected=5) # WanX-t2v tab index | |
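# Push the extracted metadata into the WanX-i2v tab, populate its fields, then switch tabs | |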
send_to_wanx_i2v_btn.click( | |
fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_i2v'), | |
inputs=[metadata_output], | |
outputs=[status, params_state] | |
).then( | |
lambda params: [ | |
params.get("prompt", ""), | |
params.get("width", 832), | |
params.get("height", 480), | |
params.get("video_length", 81), | |
params.get("fps", 16), | |
params.get("infer_steps", 40), | |
params.get("seed", -1), | |
params.get("flow_shift", 3.0), | |
params.get("guidance_scale", 5.0), | |
params.get("attn_mode", "sdpa"), | |
params.get("block_swap", 0), | |
params.get("task", "i2v-14B") | |
] if params else [gr.update()]*12, | |
inputs=params_state, | |
outputs=[ | |
wanx_prompt, | |
wanx_width, | |
wanx_height, | |
wanx_video_length, | |
wanx_fps, | |
wanx_infer_steps, | |
wanx_seed, | |
wanx_flow_shift, | |
wanx_guidance_scale, | |
wanx_attn_mode, | |
wanx_block_swap, | |
wanx_task | |
] | |
).then( | |
fn=change_to_wanx_i2v_tab, inputs=None, outputs=[tabs] | |
) | |
# WanX-t2v: push metadata into the tab's fields, then switch to it | |
send_to_wanx_t2v_btn.click( | |
fn=lambda m: handle_send_to_wanx_tab(m, 'wanx_t2v'), | |
inputs=[metadata_output], | |
outputs=[status, params_state] | |
).then( | |
lambda params: [ | |
params.get("prompt", ""), | |
params.get("width", 832), | |
params.get("height", 480), | |
params.get("video_length", 81), | |
params.get("fps", 16), | |
params.get("infer_steps", 50), | |
params.get("seed", -1), | |
params.get("flow_shift", 5.0), | |
params.get("guidance_scale", 5.0), | |
params.get("attn_mode", "sdpa"), | |
params.get("block_swap", 0) | |
] if params else [gr.update()]*11, | |
inputs=params_state, | |
outputs=[ | |
wanx_t2v_prompt, | |
wanx_t2v_width, | |
wanx_t2v_height, | |
wanx_t2v_video_length, | |
wanx_t2v_fps, | |
wanx_t2v_infer_steps, | |
wanx_t2v_seed, | |
wanx_t2v_flow_shift, | |
wanx_t2v_guidance_scale, | |
wanx_t2v_attn_mode, | |
wanx_t2v_block_swap | |
] | |
).then( | |
fn=change_to_wanx_t2v_tab, inputs=None, outputs=[tabs] | |
) | |
# Tab-switch helpers | |
def change_to_tab_one(): | |
return gr.Tabs(selected=1)  # Navigate to the Text2Video tab | |
def change_to_tab_two(): | |
return gr.Tabs(selected=2)  # Navigate to the Video2Video tab | |
def change_to_skyreels_tab(): | |
return gr.Tabs(selected=3) | |
#SKYREELS TAB!!! | |
# SkyReels dimension sync helper | |
def sync_skyreels_dimensions(width, height): | |
return gr.update(value=width), gr.update(value=height) | |
# Refresh the LoRA dropdown choices in the SkyReels tab, preserving current selections where possible | |
def update_skyreels_lora_dropdowns(lora_folder: str, *current_values) -> List[gr.update]: | |
new_choices = get_lora_options(lora_folder) | |
weights = current_values[:4] | |
multipliers = current_values[4:8] | |
results = [] | |
for i in range(4): | |
weight = weights[i] if i < len(weights) else "None" | |
multiplier = multipliers[i] if i < len(multipliers) else 1.0 | |
if weight not in new_choices: | |
weight = "None" | |
results.extend([ | |
gr.update(choices=new_choices, value=weight), | |
gr.update(value=multiplier) | |
]) | |
return results | |
# Refresh the model choices in the SkyReels DiT dropdown | |
def update_skyreels_model_dropdown(dit_folder: str) -> Dict: | |
models = get_dit_models(dit_folder) | |
return gr.update(choices=models, value=models[0] if models else None) | |
# Repopulate the model dropdown whenever the DiT folder path changes | |
skyreels_dit_folder.change( | |
fn=update_skyreels_model_dropdown, | |
inputs=[skyreels_dit_folder], | |
outputs=[skyreels_model] | |
) | |
# Refresh the LoRA dropdowns (a combined DiT + LoRA refresh for the same button is also registered further below) | |
skyreels_refresh_btn.click( | |
fn=update_skyreels_lora_dropdowns, | |
inputs=[skyreels_lora_folder] + skyreels_lora_weights + skyreels_lora_multipliers, | |
outputs=[comp for i in range(4) for comp in (skyreels_lora_weights[i], skyreels_lora_multipliers[i])] | |
) | |
# Skyreels dimension handling | |
def calculate_skyreels_width(height, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_width = math.floor((height * aspect_ratio) / 16) * 16 | |
return gr.update(value=new_width) | |
def calculate_skyreels_height(width, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_height = math.floor((width / aspect_ratio) / 16) * 16 | |
return gr.update(value=new_height) | |
def update_skyreels_from_scale(scale, original_dims): | |
if not original_dims: | |
return gr.update(), gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
new_w = math.floor((orig_w * scale / 100) / 16) * 16 | |
new_h = math.floor((orig_h * scale / 100) / 16) * 16 | |
return gr.update(value=new_w), gr.update(value=new_h) | |
def update_skyreels_dimensions(image): | |
if image is None: | |
return "", gr.update(value=544), gr.update(value=544) | |
img = Image.open(image) | |
w, h = img.size | |
w = (w // 16) * 16 | |
h = (h // 16) * 16 | |
return f"{w}x{h}", w, h | |
def handle_skyreels_gallery_select(evt: gr.SelectData) -> int: | |
return evt.index | |
def send_skyreels_to_v2v( | |
gallery: list, | |
prompt: str, | |
selected_index: int, | |
width: int, | |
height: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
flow_shift: float, | |
cfg_scale: float, | |
lora1: str, | |
lora2: str, | |
lora3: str, | |
lora4: str, | |
lora1_multiplier: float, | |
lora2_multiplier: float, | |
lora3_multiplier: float, | |
lora4_multiplier: float, | |
negative_prompt: str = "" # Add this parameter | |
) -> Tuple: | |
if not gallery or selected_index is None or selected_index >= len(gallery): | |
return (None, "", width, height, video_length, fps, infer_steps, seed, | |
flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
negative_prompt)  # negative prompt included in the returned tuple | |
selected_item = gallery[selected_index] | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
negative_prompt)  # negative prompt included in the returned tuple | |
# Add event handlers for the SKYREELS tab | |
skyreels_prompt.change(fn=count_prompt_tokens, inputs=skyreels_prompt, outputs=skyreels_token_counter) | |
skyreels_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
# Image input handling | |
skyreels_input.change( | |
fn=update_skyreels_dimensions, | |
inputs=[skyreels_input], | |
outputs=[skyreels_original_dims, skyreels_width, skyreels_height] | |
) | |
skyreels_scale_slider.change( | |
fn=update_skyreels_from_scale, | |
inputs=[skyreels_scale_slider, skyreels_original_dims], | |
outputs=[skyreels_width, skyreels_height] | |
) | |
skyreels_calc_width_btn.click( | |
fn=calculate_skyreels_width, | |
inputs=[skyreels_height, skyreels_original_dims], | |
outputs=[skyreels_width] | |
) | |
skyreels_calc_height_btn.click( | |
fn=calculate_skyreels_height, | |
inputs=[skyreels_width, skyreels_original_dims], | |
outputs=[skyreels_height] | |
) | |
# Handle checkbox visibility toggling | |
skyreels_use_random_folder.change( | |
fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), | |
inputs=[skyreels_use_random_folder], | |
outputs=[skyreels_input_folder, skyreels_folder_status, skyreels_input] | |
) | |
# Validate folder button click handler | |
skyreels_validate_folder_btn.click( | |
fn=lambda folder: get_random_image_from_folder(folder)[1], | |
inputs=[skyreels_input_folder], | |
outputs=[skyreels_folder_status] | |
) | |
skyreels_use_random_folder.change( | |
fn=lambda x: gr.update(visible=x), | |
inputs=[skyreels_use_random_folder], | |
outputs=[skyreels_validate_folder_btn] | |
) | |
# Generate: batch_handler dispatches to random-folder processing when folder mode is enabled, otherwise uses the single input image | |
skyreels_generate_btn.click( | |
fn=batch_handler, | |
inputs=[ | |
skyreels_use_random_folder, | |
# Rest of the arguments | |
skyreels_prompt, | |
skyreels_negative_prompt, | |
skyreels_width, | |
skyreels_height, | |
skyreels_video_length, | |
skyreels_fps, | |
skyreels_infer_steps, | |
skyreels_seed, | |
skyreels_flow_shift, | |
skyreels_guidance_scale, | |
skyreels_embedded_cfg_scale, | |
skyreels_batch_size, | |
skyreels_input_folder, | |
skyreels_dit_folder, | |
skyreels_model, | |
skyreels_vae, | |
skyreels_te1, | |
skyreels_te2, | |
skyreels_save_path, | |
skyreels_output_type, | |
skyreels_attn_mode, | |
skyreels_block_swap, | |
skyreels_exclude_single_blocks, | |
skyreels_use_split_attn, | |
skyreels_use_fp8, | |
skyreels_split_uncond, | |
skyreels_lora_folder, | |
*skyreels_lora_weights, | |
*skyreels_lora_multipliers, | |
skyreels_input # Add the input image path | |
], | |
outputs=[skyreels_output, skyreels_batch_progress, skyreels_progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[skyreels_batch_size], | |
outputs=skyreels_selected_index | |
) | |
# Gallery selection handling | |
skyreels_output.select( | |
fn=handle_skyreels_gallery_select, | |
outputs=skyreels_selected_index | |
) | |
# Send to Video2Video handler | |
skyreels_send_to_v2v_btn.click( | |
fn=send_skyreels_to_v2v, | |
inputs=[ | |
skyreels_output, skyreels_prompt, skyreels_selected_index, | |
skyreels_width, skyreels_height, skyreels_video_length, | |
skyreels_fps, skyreels_infer_steps, skyreels_seed, | |
skyreels_flow_shift, skyreels_guidance_scale | |
] + skyreels_lora_weights + skyreels_lora_multipliers + [skyreels_negative_prompt],  # negative prompt appended last to match send_skyreels_to_v2v's signature | |
outputs=[ | |
v2v_input, v2v_prompt, v2v_width, v2v_height, | |
v2v_video_length, v2v_fps, v2v_infer_steps, | |
v2v_seed, v2v_flow_shift, v2v_cfg_scale | |
] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] | |
).then( | |
fn=change_to_tab_two, | |
inputs=None, | |
outputs=[tabs] | |
) | |
# Refresh button handler | |
skyreels_refresh_outputs = [skyreels_model] | |
for i in range(4): | |
skyreels_refresh_outputs.extend([skyreels_lora_weights[i], skyreels_lora_multipliers[i]]) | |
skyreels_refresh_btn.click( | |
fn=update_dit_and_lora_dropdowns, | |
inputs=[skyreels_dit_folder, skyreels_lora_folder, skyreels_model] + skyreels_lora_weights + skyreels_lora_multipliers, | |
outputs=skyreels_refresh_outputs | |
) | |
def calculate_v2v_width(height, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_width = math.floor((height * aspect_ratio) / 16) * 16 # Ensure divisible by 16 | |
return gr.update(value=new_width) | |
def calculate_v2v_height(width, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_height = math.floor((width / aspect_ratio) / 16) * 16 # Ensure divisible by 16 | |
return gr.update(value=new_height) | |
def update_v2v_from_scale(scale, original_dims): | |
if not original_dims: | |
return gr.update(), gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Ensure divisible by 16 | |
new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Ensure divisible by 16 | |
return gr.update(value=new_w), gr.update(value=new_h) | |
def update_v2v_dimensions(video): | |
if video is None: | |
return "", gr.update(value=544), gr.update(value=544) | |
cap = cv2.VideoCapture(video) | |
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
cap.release() | |
# Make dimensions divisible by 16 | |
w = (w // 16) * 16 | |
h = (h // 16) * 16 | |
return f"{w}x{h}", w, h | |
# Event Handlers for Video to Video Tab | |
v2v_input.change( | |
fn=update_v2v_dimensions, | |
inputs=[v2v_input], | |
outputs=[v2v_original_dims, v2v_width, v2v_height] | |
) | |
v2v_scale_slider.change( | |
fn=update_v2v_from_scale, | |
inputs=[v2v_scale_slider, v2v_original_dims], | |
outputs=[v2v_width, v2v_height] | |
) | |
v2v_calc_width_btn.click( | |
fn=calculate_v2v_width, | |
inputs=[v2v_height, v2v_original_dims], | |
outputs=[v2v_width] | |
) | |
v2v_calc_height_btn.click( | |
fn=calculate_v2v_height, | |
inputs=[v2v_width, v2v_original_dims], | |
outputs=[v2v_height] | |
) | |
##Image 2 video dimension logic | |
def calculate_width(height, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_width = math.floor((height * aspect_ratio) / 16) * 16 # Changed from 8 to 16 | |
return gr.update(value=new_width) | |
def calculate_height(width, original_dims): | |
if not original_dims: | |
return gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
aspect_ratio = orig_w / orig_h | |
new_height = math.floor((width / aspect_ratio) / 16) * 16 # Changed from 8 to 16 | |
return gr.update(value=new_height) | |
def update_from_scale(scale, original_dims): | |
if not original_dims: | |
return gr.update(), gr.update() | |
orig_w, orig_h = map(int, original_dims.split('x')) | |
new_w = math.floor((orig_w * scale / 100) / 16) * 16 # Changed from 8 to 16 | |
new_h = math.floor((orig_h * scale / 100) / 16) * 16 # Changed from 8 to 16 | |
return gr.update(value=new_w), gr.update(value=new_h) | |
def update_dimensions(image): | |
if image is None: | |
return "", gr.update(value=544), gr.update(value=544) | |
img = Image.open(image) | |
w, h = img.size | |
# Make dimensions divisible by 16 | |
w = (w // 16) * 16 # Changed from 8 to 16 | |
h = (h // 16) * 16 # Changed from 8 to 16 | |
return f"{w}x{h}", w, h | |
i2v_input.change( | |
fn=update_dimensions, | |
inputs=[i2v_input], | |
outputs=[original_dims, width, height] | |
) | |
scale_slider.change( | |
fn=update_from_scale, | |
inputs=[scale_slider, original_dims], | |
outputs=[width, height] | |
) | |
calc_width_btn.click( | |
fn=calculate_width, | |
inputs=[height, original_dims], | |
outputs=[width] | |
) | |
calc_height_btn.click( | |
fn=calculate_height, | |
inputs=[width, original_dims], | |
outputs=[height] | |
) | |
# Function to get available DiT models | |
def get_dit_models(dit_folder: str) -> List[str]: | |
if not os.path.exists(dit_folder): | |
return ["mp_rank_00_model_states.pt"] | |
models = [f for f in os.listdir(dit_folder) if f.endswith('.pt') or f.endswith('.safetensors')] | |
models.sort(key=str.lower) | |
return models if models else ["mp_rank_00_model_states.pt"] | |
# Function to perform model merging | |
def merge_models( | |
dit_folder: str, | |
dit_model: str, | |
output_model: str, | |
exclude_single_blocks: bool, | |
merge_lora_folder: str, | |
*lora_params # Will contain both weights and multipliers | |
) -> str: | |
try: | |
# Separate weights and multipliers | |
num_loras = len(lora_params) // 2 | |
weights = list(lora_params[:num_loras]) | |
multipliers = list(lora_params[num_loras:]) | |
# Filter out "None" selections | |
valid_loras = [] | |
for weight, mult in zip(weights, multipliers): | |
if weight and weight != "None": | |
valid_loras.append((os.path.join(merge_lora_folder, weight), mult)) | |
if not valid_loras: | |
return "No LoRA models selected for merging" | |
# Create output path in the dit folder | |
os.makedirs(dit_folder, exist_ok=True) | |
output_path = os.path.join(dit_folder, output_model) | |
# Prepare command | |
cmd = [ | |
sys.executable, | |
"merge_lora.py", | |
"--dit", os.path.join(dit_folder, dit_model), | |
"--save_merged_model", output_path | |
] | |
# Add LoRA weights and multipliers | |
weights = [weight for weight, _ in valid_loras] | |
multipliers = [str(mult) for _, mult in valid_loras] | |
cmd.extend(["--lora_weight"] + weights) | |
cmd.extend(["--lora_multiplier"] + multipliers) | |
if exclude_single_blocks: | |
cmd.append("--exclude_single_blocks") | |
# Execute merge operation | |
result = subprocess.run( | |
cmd, | |
capture_output=True, | |
text=True, | |
check=True | |
) | |
if os.path.exists(output_path): | |
return f"Successfully merged model and saved to {output_path}" | |
else: | |
return "Error: Output file not created" | |
except subprocess.CalledProcessError as e: | |
return f"Error during merging: {e.stderr}" | |
except Exception as e: | |
return f"Error: {str(e)}" | |
# Update DiT model dropdown | |
def update_dit_dropdown(dit_folder: str) -> Dict: | |
models = get_dit_models(dit_folder) | |
return gr.update(choices=models, value=models[0] if models else None) | |
# Connect events | |
merge_btn.click( | |
fn=merge_models, | |
inputs=[ | |
dit_folder, | |
dit_model, | |
output_model, | |
exclude_single_blocks, | |
merge_lora_folder, | |
*merge_lora_weights, | |
*merge_lora_multipliers | |
], | |
outputs=merge_status | |
) | |
# Refresh buttons for both DiT and LoRA dropdowns | |
merge_refresh_btn.click( | |
fn=update_dit_dropdown, | |
inputs=[dit_folder], | |
outputs=[dit_model] | |
) | |
# LoRA refresh handling | |
merge_refresh_outputs = [] | |
for i in range(4): | |
merge_refresh_outputs.extend([merge_lora_weights[i], merge_lora_multipliers[i]]) | |
merge_refresh_btn.click( | |
fn=update_lora_dropdowns, | |
inputs=[merge_lora_folder] + merge_lora_weights + merge_lora_multipliers, | |
outputs=merge_refresh_outputs | |
) | |
# Event handlers | |
prompt.change(fn=count_prompt_tokens, inputs=prompt, outputs=token_counter) | |
v2v_prompt.change(fn=count_prompt_tokens, inputs=v2v_prompt, outputs=v2v_token_counter) | |
stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
v2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
#Image_to_Video | |
def image_to_video(image_path, output_path, width, height, frames=240): | |
"""Loop a still image into a temporary H.264 video at the requested resolution""" | |
img = Image.open(image_path) | |
# Resize to the specified dimensions | |
img_resized = img.resize((width, height), Image.LANCZOS) | |
temp_image_path = os.path.join(os.path.dirname(output_path), "temp_resized_image.png") | |
img_resized.save(temp_image_path) | |
# Encode the resized still as a looped clip at a fixed 24 fps | |
frame_rate = 24 | |
duration = frames / frame_rate | |
command = [ | |
"ffmpeg", "-loop", "1", "-i", temp_image_path, "-c:v", "libx264", | |
"-t", str(duration), "-pix_fmt", "yuv420p", | |
"-vf", f"fps={frame_rate}", output_path | |
] | |
try: | |
subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | |
print(f"Video saved to {output_path}") | |
return True | |
except subprocess.CalledProcessError as e: | |
print(f"An error occurred while creating the video: {e}") | |
return False | |
finally: | |
# Clean up the temporary image file | |
if os.path.exists(temp_image_path): | |
os.remove(temp_image_path) | |
img.close() # Make sure to close the image file explicitly | |
def generate_from_image( | |
image_path, | |
prompt, width, height, video_length, fps, infer_steps, | |
seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder, strength, batch_size, *lora_params | |
): | |
"""Generate video from input image with progressive updates""" | |
global stop_event | |
stop_event.clear() | |
# Create temporary video path | |
temp_video_path = os.path.join(save_path, f"temp_{os.path.basename(image_path)}.mp4") | |
try: | |
# Convert image to video | |
if not image_to_video(image_path, temp_video_path, width, height, frames=video_length): | |
yield [], "Failed to create temporary video", "Error in video creation" | |
return | |
# Ensure video is fully written before proceeding | |
time.sleep(1) | |
if not os.path.exists(temp_video_path) or os.path.getsize(temp_video_path) == 0: | |
yield [], "Failed to create temporary video", "Temporary video file is empty or missing" | |
return | |
# Get video dimensions | |
try: | |
probe = ffmpeg.probe(temp_video_path) | |
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) | |
if video_stream is None: | |
raise ValueError("No video stream found") | |
width = int(video_stream['width']) | |
height = int(video_stream['height']) | |
except Exception as e: | |
yield [], f"Error reading video dimensions: {str(e)}", "Video processing error" | |
return | |
# Generate the video using the temporary file | |
try: | |
generator = process_single_video( | |
prompt, width, height, batch_size, video_length, fps, infer_steps, | |
seed, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder, *lora_params, video_path=temp_video_path, strength=strength | |
) | |
# Forward all generator updates | |
for videos, batch_text, progress_text in generator: | |
yield videos, batch_text, progress_text | |
except Exception as e: | |
yield [], f"Error in video generation: {str(e)}", "Generation error" | |
return | |
except Exception as e: | |
yield [], f"Unexpected error: {str(e)}", "Error occurred" | |
return | |
finally: | |
# Clean up temporary file | |
try: | |
if os.path.exists(temp_video_path): | |
os.remove(temp_video_path) | |
except Exception: | |
pass # Ignore cleanup errors | |
# Add event handlers | |
i2v_prompt.change(fn=count_prompt_tokens, inputs=i2v_prompt, outputs=i2v_token_counter) | |
i2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
def handle_i2v_gallery_select(evt: gr.SelectData) -> int: | |
"""Track selected index when I2V gallery item is clicked""" | |
return evt.index | |
def send_i2v_to_v2v( | |
gallery: list, | |
prompt: str, | |
selected_index: int, | |
width: int, | |
height: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
flow_shift: float, | |
cfg_scale: float, | |
lora1: str, | |
lora2: str, | |
lora3: str, | |
lora4: str, | |
lora1_multiplier: float, | |
lora2_multiplier: float, | |
lora3_multiplier: float, | |
lora4_multiplier: float | |
) -> Tuple[Optional[str], str, int, int, int, int, int, int, float, float, str, str, str, str, float, float, float, float]: | |
"""Send the selected video and parameters from Image2Video tab to Video2Video tab""" | |
if not gallery or selected_index is None or selected_index >= len(gallery): | |
return None, "", width, height, video_length, fps, infer_steps, seed, flow_shift, cfg_scale, \ | |
lora1, lora2, lora3, lora4, lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier | |
selected_item = gallery[selected_index] | |
# Handle different gallery item formats | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
# Final cleanup for Gradio Video component | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
# Use the original width and height without doubling | |
return (str(video_path), prompt, width, height, video_length, fps, infer_steps, seed, | |
flow_shift, cfg_scale, lora1, lora2, lora3, lora4, | |
lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier) | |
# Generate button handler | |
i2v_generate_btn.click( | |
fn=process_batch, | |
inputs=[ | |
i2v_prompt, width, height, | |
i2v_batch_size, i2v_video_length, | |
i2v_fps, i2v_infer_steps, i2v_seed, i2v_dit_folder, i2v_model, i2v_vae, i2v_te1, i2v_te2, | |
i2v_save_path, i2v_flow_shift, i2v_cfg_scale, i2v_output_type, i2v_attn_mode, | |
i2v_block_swap, i2v_exclude_single_blocks, i2v_use_split_attn, i2v_lora_folder, | |
*i2v_lora_weights, *i2v_lora_multipliers, i2v_input, i2v_strength, i2v_use_fp8 | |
], | |
outputs=[i2v_output, i2v_batch_progress, i2v_progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[i2v_batch_size], | |
outputs=i2v_selected_index | |
) | |
# Send to Video2Video | |
i2v_output.select( | |
fn=handle_i2v_gallery_select, | |
outputs=i2v_selected_index | |
) | |
i2v_send_to_v2v_btn.click( | |
fn=send_i2v_to_v2v, | |
inputs=[ | |
i2v_output, i2v_prompt, i2v_selected_index, | |
width, height, | |
i2v_video_length, i2v_fps, i2v_infer_steps, | |
i2v_seed, i2v_flow_shift, i2v_cfg_scale | |
] + i2v_lora_weights + i2v_lora_multipliers, | |
outputs=[ | |
v2v_input, v2v_prompt, | |
v2v_width, v2v_height, | |
v2v_video_length, v2v_fps, v2v_infer_steps, | |
v2v_seed, v2v_flow_shift, v2v_cfg_scale | |
] + v2v_lora_weights + v2v_lora_multipliers | |
).then( | |
fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
) | |
#Video Info | |
def clean_video_path(video_path) -> str: | |
"""Extract clean video path from Gradio's various return formats""" | |
print(f"Input video_path: {video_path}, type: {type(video_path)}") | |
if isinstance(video_path, dict): | |
path = video_path.get("name", "") | |
elif isinstance(video_path, (tuple, list)): | |
path = video_path[0] | |
elif isinstance(video_path, str): | |
path = video_path | |
else: | |
path = "" | |
print(f"Cleaned path: {path}") | |
return path | |
def handle_video_upload(video_path: str) -> Tuple[Dict, str]: | |
"""Handle video upload and metadata extraction""" | |
if not video_path: | |
return {}, "No video uploaded" | |
metadata = extract_video_metadata(video_path) | |
if not metadata: | |
return {}, "No metadata found in video" | |
return metadata, "Metadata extracted successfully" | |
def get_video_info(video_path: str) -> dict: | |
try: | |
probe = ffmpeg.probe(video_path) | |
video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video') | |
width = int(video_info['width']) | |
height = int(video_info['height']) | |
rate = video_info['r_frame_rate']  # e.g. '30000/1001' or '30/1' | |
num, den = (rate.split('/') + ['1'])[:2]  # avoid eval on ffprobe output | |
fps = float(num) / float(den) | |
# Calculate total frames | |
duration = float(probe['format']['duration']) | |
total_frames = int(duration * fps) | |
# Ensure video length does not exceed 201 frames | |
if total_frames > 201: | |
total_frames = 201 | |
duration = total_frames / fps # Adjust duration accordingly | |
return { | |
'width': width, | |
'height': height, | |
'fps': fps, | |
'total_frames': total_frames, | |
'duration': duration # Might be useful in some contexts | |
} | |
except Exception as e: | |
print(f"Error extracting video info: {e}") | |
return {} | |
def extract_video_details(video_path: str) -> Tuple[dict, str]: | |
metadata = extract_video_metadata(video_path) | |
video_details = get_video_info(video_path) | |
# Combine metadata with video details | |
for key, value in video_details.items(): | |
if key not in metadata: | |
metadata[key] = value | |
# Ensure video length does not exceed 201 frames | |
if 'video_length' in metadata: | |
metadata['video_length'] = min(metadata['video_length'], 201) | |
else: | |
metadata['video_length'] = min(video_details.get('total_frames', 0), 201) | |
# Return both the updated metadata and a status message | |
return metadata, "Video details extracted successfully" | |
def send_parameters_to_tab(metadata: Dict, target_tab: str) -> Tuple[str, Dict]: | |
"""Create parameter mapping for target tab""" | |
if not metadata: | |
return "No parameters to send", {} | |
tab_name = "Text2Video" if target_tab == "t2v" else "Video2Video" | |
try: | |
mapping = create_parameter_transfer_map(metadata, target_tab) | |
return f"Parameters ready for {tab_name}", mapping | |
except Exception as e: | |
return f"Error: {str(e)}", {} | |
video_input.upload( | |
fn=extract_video_details, | |
inputs=video_input, | |
outputs=[metadata_output, status] | |
) | |
send_to_t2v_btn.click( | |
fn=lambda m: send_parameters_to_tab(m, "t2v"), | |
inputs=metadata_output, | |
outputs=[status, params_state] | |
).then( | |
fn=change_to_tab_one, inputs=None, outputs=[tabs] | |
).then( | |
lambda params: [ | |
params.get("prompt", ""), | |
params.get("width", 544), | |
params.get("height", 544), | |
params.get("batch_size", 1), | |
params.get("video_length", 25), | |
params.get("fps", 24), | |
params.get("infer_steps", 30), | |
params.get("seed", -1), | |
params.get("model", "hunyuan/mp_rank_00_model_states.pt"), | |
params.get("vae", "hunyuan/pytorch_model.pt"), | |
params.get("te1", "hunyuan/llava_llama3_fp16.safetensors"), | |
params.get("te2", "hunyuan/clip_l.safetensors"), | |
params.get("save_path", "outputs"), | |
params.get("flow_shift", 11.0), | |
params.get("cfg_scale", 7.0), | |
params.get("output_type", "video"), | |
params.get("attn_mode", "sdpa"), | |
params.get("block_swap", "0"), | |
*[params.get(f"lora{i+1}", "") for i in range(4)], | |
*[params.get(f"lora{i+1}_multiplier", 1.0) for i in range(4)] | |
] if params else [gr.update()]*26, | |
inputs=params_state, | |
outputs=[prompt, width, height, batch_size, video_length, fps, infer_steps, seed, | |
model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap] + lora_weights + lora_multipliers | |
) | |
# Text to Video generation | |
generate_btn.click( | |
fn=process_batch, | |
inputs=[ | |
prompt, t2v_width, t2v_height, batch_size, video_length, fps, infer_steps, | |
seed, dit_folder, model, vae, te1, te2, save_path, flow_shift, cfg_scale, | |
output_type, attn_mode, block_swap, exclude_single_blocks, use_split_attn, | |
lora_folder, *lora_weights, *lora_multipliers, gr.Textbox(visible=False), gr.Number(visible=False), use_fp8 | |
], | |
outputs=[video_output, batch_progress, progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[batch_size], | |
outputs=selected_index | |
) | |
# Update gallery selection handling | |
def handle_gallery_select(evt: gr.SelectData) -> int: | |
return evt.index | |
# Track selected index when gallery item is clicked | |
video_output.select( | |
fn=handle_gallery_select, | |
outputs=selected_index | |
) | |
# Track selected index when Video2Video gallery item is clicked | |
def handle_v2v_gallery_select(evt: gr.SelectData) -> int: | |
"""Handle gallery selection without automatically updating the input""" | |
return evt.index | |
# Update the gallery selection event | |
v2v_output.select( | |
fn=handle_v2v_gallery_select, | |
outputs=v2v_selected_index | |
) | |
# Send button handler with gallery selection | |
def handle_send_button( | |
gallery: list, | |
prompt: str, | |
idx: int, | |
width: int, | |
height: int, | |
batch_size: int, | |
video_length: int, | |
fps: int, | |
infer_steps: int, | |
seed: int, | |
flow_shift: float, | |
cfg_scale: float, | |
lora1: str, | |
lora2: str, | |
lora3: str, | |
lora4: str, | |
lora1_multiplier: float, | |
lora2_multiplier: float, | |
lora3_multiplier: float, | |
lora4_multiplier: float | |
) -> tuple: | |
# Auto-select the only item if nothing has been selected yet | |
if idx is None and gallery and len(gallery) == 1: | |
idx = 0 | |
if not gallery or idx is None or idx >= len(gallery): | |
return (None, "", width, height, batch_size, video_length, fps, infer_steps, | |
seed, flow_shift, cfg_scale, | |
lora1, lora2, lora3, lora4, | |
lora1_multiplier, lora2_multiplier, lora3_multiplier, lora4_multiplier, | |
"")  # empty string stands in for negative_prompt | |
selected_item = gallery[idx] | |
# Handle different gallery item formats | |
if isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, (tuple, list)): | |
video_path = selected_item[0] | |
else: | |
video_path = selected_item | |
# Final cleanup for Gradio Video component | |
if isinstance(video_path, tuple): | |
video_path = video_path[0] | |
return ( | |
str(video_path), | |
prompt, | |
width, | |
height, | |
batch_size, | |
video_length, | |
fps, | |
infer_steps, | |
seed, | |
flow_shift, | |
cfg_scale, | |
lora1, | |
lora2, | |
lora3, | |
lora4, | |
lora1_multiplier, | |
lora2_multiplier, | |
lora3_multiplier, | |
lora4_multiplier, | |
"" # Add empty string for negative_prompt | |
) | |
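# Transfer the selected Text2Video result and its generation settings to the Video2Video tab, then switch tabs | |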
send_t2v_to_v2v_btn.click( | |
fn=handle_send_button, | |
inputs=[ | |
video_output, prompt, selected_index, | |
t2v_width, t2v_height, batch_size, video_length, | |
fps, infer_steps, seed, flow_shift, cfg_scale | |
] + lora_weights + lora_multipliers, | |
outputs=[ | |
v2v_input, | |
v2v_prompt, | |
v2v_width, | |
v2v_height, | |
v2v_batch_size, | |
v2v_video_length, | |
v2v_fps, | |
v2v_infer_steps, | |
v2v_seed, | |
v2v_flow_shift, | |
v2v_cfg_scale | |
] + v2v_lora_weights + v2v_lora_multipliers + [v2v_negative_prompt] | |
).then( | |
fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
) | |
def handle_send_to_v2v(metadata: dict, video_path: str) -> Tuple[str, dict, str]: | |
"""Handle both parameters and video transfer""" | |
status_msg, params = send_parameters_to_tab(metadata, "v2v") | |
return status_msg, params, video_path | |
def handle_info_to_v2v(metadata: dict, video_path: str) -> Tuple[str, Dict, str]: | |
"""Handle both parameters and video transfer from Video Info to V2V tab""" | |
if not video_path: | |
return "No video selected", {}, None | |
status_msg, params = send_parameters_to_tab(metadata, "v2v") | |
# Just return the path directly | |
return status_msg, params, video_path | |
# Send button click handler | |
send_to_v2v_btn.click( | |
fn=handle_info_to_v2v, | |
inputs=[metadata_output, video_input], | |
outputs=[status, params_state, v2v_input] | |
).then( | |
lambda params: [ | |
params.get("v2v_prompt", ""), | |
params.get("v2v_width", 544), | |
params.get("v2v_height", 544), | |
params.get("v2v_batch_size", 1), | |
params.get("v2v_video_length", 25), | |
params.get("v2v_fps", 24), | |
params.get("v2v_infer_steps", 30), | |
params.get("v2v_seed", -1), | |
params.get("v2v_model", "hunyuan/mp_rank_00_model_states.pt"), | |
params.get("v2v_vae", "hunyuan/pytorch_model.pt"), | |
params.get("v2v_te1", "hunyuan/llava_llama3_fp16.safetensors"), | |
params.get("v2v_te2", "hunyuan/clip_l.safetensors"), | |
params.get("v2v_save_path", "outputs"), | |
params.get("v2v_flow_shift", 11.0), | |
params.get("v2v_cfg_scale", 7.0), | |
params.get("v2v_output_type", "video"), | |
params.get("v2v_attn_mode", "sdpa"), | |
params.get("v2v_block_swap", "0"), | |
*[params.get(f"v2v_lora_weights[{i}]", "") for i in range(4)], | |
*[params.get(f"v2v_lora_multipliers[{i}]", 1.0) for i in range(4)] | |
] if params else [gr.update()] * 26, | |
inputs=params_state, | |
outputs=[ | |
v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, | |
v2v_fps, v2v_infer_steps, v2v_seed, v2v_model, v2v_vae, v2v_te1, | |
v2v_te2, v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, | |
v2v_attn_mode, v2v_block_swap | |
] + v2v_lora_weights + v2v_lora_multipliers | |
).then( | |
lambda: print(f"Tabs object: {tabs}"), # Debug print | |
outputs=None | |
).then( | |
fn=change_to_tab_two, inputs=None, outputs=[tabs] | |
) | |
# Handler for sending selected video from Video2Video gallery to input | |
def handle_v2v_send_button(gallery: list, prompt: str, idx: int) -> Tuple[Optional[str], str]: | |
"""Send the currently selected video in V2V gallery to V2V input""" | |
if not gallery or idx is None or idx >= len(gallery): | |
return None, "" | |
selected_item = gallery[idx] | |
video_path = None | |
# Handle different gallery item formats | |
if isinstance(selected_item, tuple): | |
video_path = selected_item[0] # Gallery returns (path, caption) | |
elif isinstance(selected_item, dict): | |
video_path = selected_item.get("name", selected_item.get("data", None)) | |
elif isinstance(selected_item, str): | |
video_path = selected_item | |
if not video_path: | |
return None, "" | |
# Check if the file exists and is accessible | |
if not os.path.exists(video_path): | |
print(f"Warning: Video file not found at {video_path}") | |
return None, "" | |
return video_path, prompt | |
v2v_send_to_input_btn.click( | |
fn=handle_v2v_send_button, | |
inputs=[v2v_output, v2v_prompt, v2v_selected_index], | |
outputs=[v2v_input, v2v_prompt] | |
).then( | |
lambda: gr.update(visible=True), # Ensure the video input is visible | |
outputs=v2v_input | |
) | |
# Video to Video generation | |
v2v_generate_btn.click( | |
fn=process_batch, | |
inputs=[ | |
v2v_prompt, v2v_width, v2v_height, v2v_batch_size, v2v_video_length, | |
v2v_fps, v2v_infer_steps, v2v_seed, v2v_dit_folder, v2v_model, v2v_vae, v2v_te1, v2v_te2, | |
v2v_save_path, v2v_flow_shift, v2v_cfg_scale, v2v_output_type, v2v_attn_mode, | |
v2v_block_swap, v2v_exclude_single_blocks, v2v_use_split_attn, v2v_lora_folder, | |
*v2v_lora_weights, *v2v_lora_multipliers, v2v_input, v2v_strength, | |
v2v_negative_prompt, v2v_cfg_scale, v2v_split_uncond, v2v_use_fp8 | |
], | |
outputs=[v2v_output, v2v_batch_progress, v2v_progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[v2v_batch_size], | |
outputs=v2v_selected_index | |
) | |
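# Text2Video refresh: rebuild the DiT model and LoRA dropdown choices from the configured folders | |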
refresh_outputs = [model] # Add model dropdown to outputs | |
for i in range(4): | |
refresh_outputs.extend([lora_weights[i], lora_multipliers[i]]) | |
refresh_btn.click( | |
fn=update_dit_and_lora_dropdowns, | |
inputs=[dit_folder, lora_folder, model] + lora_weights + lora_multipliers, | |
outputs=refresh_outputs | |
) | |
# Image2Video refresh | |
i2v_refresh_outputs = [i2v_model] # Add model dropdown to outputs | |
for i in range(4): | |
i2v_refresh_outputs.extend([i2v_lora_weights[i], i2v_lora_multipliers[i]]) | |
i2v_refresh_btn.click( | |
fn=update_dit_and_lora_dropdowns, | |
inputs=[i2v_dit_folder, i2v_lora_folder, i2v_model] + i2v_lora_weights + i2v_lora_multipliers, | |
outputs=i2v_refresh_outputs | |
) | |
# Video2Video refresh | |
v2v_refresh_outputs = [v2v_model] # Add model dropdown to outputs | |
for i in range(4): | |
v2v_refresh_outputs.extend([v2v_lora_weights[i], v2v_lora_multipliers[i]]) | |
v2v_refresh_btn.click( | |
fn=update_dit_and_lora_dropdowns, | |
inputs=[v2v_dit_folder, v2v_lora_folder, v2v_model] + v2v_lora_weights + v2v_lora_multipliers, | |
outputs=v2v_refresh_outputs | |
) | |
# WanX-i2v tab connections | |
wanx_prompt.change(fn=count_prompt_tokens, inputs=wanx_prompt, outputs=wanx_token_counter) | |
wanx_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
# Image input handling for WanX-i2v | |
wanx_input.change( | |
fn=update_wanx_image_dimensions, | |
inputs=[wanx_input], | |
outputs=[wanx_original_dims, wanx_width, wanx_height] | |
) | |
# Scale slider handling for WanX-i2v | |
wanx_scale_slider.change( | |
fn=update_wanx_from_scale, | |
inputs=[wanx_scale_slider, wanx_original_dims], | |
outputs=[wanx_width, wanx_height] | |
) | |
# Width/height calculation buttons for WanX-i2v | |
wanx_calc_width_btn.click( | |
fn=calculate_wanx_width, | |
inputs=[wanx_height, wanx_original_dims], | |
outputs=[wanx_width] | |
) | |
wanx_calc_height_btn.click( | |
fn=calculate_wanx_height, | |
inputs=[wanx_width, wanx_original_dims], | |
outputs=[wanx_height] | |
) | |
# Add visibility toggle for the folder input components | |
wanx_use_random_folder.change( | |
fn=lambda x: (gr.update(visible=x), gr.update(visible=x), gr.update(visible=x), gr.update(visible=not x)), | |
inputs=[wanx_use_random_folder], | |
outputs=[wanx_input_folder, wanx_folder_status, wanx_validate_folder_btn, wanx_input] | |
) | |
# Validate folder button handler | |
wanx_validate_folder_btn.click( | |
fn=lambda folder: get_random_image_from_folder(folder)[1], | |
inputs=[wanx_input_folder], | |
outputs=[wanx_folder_status] | |
) | |
# Flow shift recommendation buttons | |
wanx_recommend_flow_btn.click( | |
fn=recommend_wanx_flow_shift, | |
inputs=[wanx_width, wanx_height], | |
outputs=[wanx_flow_shift] | |
) | |
wanx_t2v_recommend_flow_btn.click( | |
fn=recommend_wanx_flow_shift, | |
inputs=[wanx_t2v_width, wanx_t2v_height], | |
outputs=[wanx_t2v_flow_shift] | |
) | |
# Generate button handler | |
wanx_generate_btn.click( | |
fn=wanx_batch_handler, | |
inputs=[ | |
wanx_use_random_folder, | |
wanx_prompt, | |
wanx_negative_prompt, | |
wanx_width, | |
wanx_height, | |
wanx_video_length, | |
wanx_fps, | |
wanx_infer_steps, | |
wanx_flow_shift, | |
wanx_guidance_scale, | |
wanx_seed, | |
wanx_batch_size, | |
wanx_input_folder, | |
wanx_task, | |
wanx_dit_path, | |
wanx_vae_path, | |
wanx_t5_path, | |
wanx_clip_path, | |
wanx_save_path, | |
wanx_output_type, | |
wanx_sample_solver, | |
wanx_exclude_single_blocks, | |
wanx_attn_mode, | |
wanx_block_swap, | |
wanx_fp8, | |
wanx_fp8_t5, | |
wanx_lora_folder, | |
*wanx_lora_weights, | |
*wanx_lora_multipliers, | |
wanx_input # Include input image path for non-batch mode | |
], | |
outputs=[wanx_output, wanx_batch_progress, wanx_progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[wanx_batch_size], | |
outputs=wanx_i2v_selected_index # Update to use correct state | |
) | |
# Add refresh button handler for WanX-i2v tab | |
wanx_refresh_outputs = [] | |
for i in range(4): | |
wanx_refresh_outputs.extend([wanx_lora_weights[i], wanx_lora_multipliers[i]]) | |
wanx_refresh_btn.click( | |
fn=update_lora_dropdowns, | |
inputs=[wanx_lora_folder] + wanx_lora_weights + wanx_lora_multipliers, | |
outputs=wanx_refresh_outputs | |
) | |
# Gallery selection handling | |
wanx_output.select( | |
fn=handle_wanx_gallery_select, | |
inputs=[wanx_output], | |
outputs=[wanx_i2v_selected_index, wanx_base_video] | |
) | |
# Send to Video2Video handler | |
wanx_send_to_v2v_btn.click( | |
fn=send_wanx_to_v2v, | |
inputs=[ | |
wanx_output, # Gallery with videos | |
wanx_prompt, # Prompt text | |
wanx_i2v_selected_index, # Use the correct selected index state | |
wanx_width, | |
wanx_height, | |
wanx_video_length, | |
wanx_fps, | |
wanx_infer_steps, | |
wanx_seed, | |
wanx_flow_shift, | |
wanx_guidance_scale, | |
wanx_negative_prompt | |
], | |
outputs=[ | |
v2v_input, # Video input in V2V tab | |
v2v_prompt, # Prompt in V2V tab | |
v2v_width, | |
v2v_height, | |
v2v_video_length, | |
v2v_fps, | |
v2v_infer_steps, | |
v2v_seed, | |
v2v_flow_shift, | |
v2v_cfg_scale, | |
v2v_negative_prompt | |
] | |
).then( | |
fn=change_to_tab_two, # Function to switch to Video2Video tab | |
inputs=None, | |
outputs=[tabs] | |
) | |
# Add state for T2V tab selected index | |
wanx_t2v_selected_index = gr.State(value=None) | |
# Connect prompt token counter | |
wanx_t2v_prompt.change(fn=count_prompt_tokens, inputs=wanx_t2v_prompt, outputs=wanx_t2v_token_counter) | |
# Stop button handler | |
wanx_t2v_stop_btn.click(fn=lambda: stop_event.set(), queue=False) | |
# Flow shift recommendation button | |
wanx_t2v_recommend_flow_btn.click( | |
fn=recommend_wanx_flow_shift, | |
inputs=[wanx_t2v_width, wanx_t2v_height], | |
outputs=[wanx_t2v_flow_shift] | |
) | |
# Task change handler to update CLIP visibility and path | |
def update_clip_visibility(task): | |
is_i2v = "i2v" in task | |
return gr.update(visible=is_i2v) | |
wanx_t2v_task.change( | |
fn=update_clip_visibility, | |
inputs=[wanx_t2v_task], | |
outputs=[wanx_t2v_clip_path] | |
) | |
# Generate button handler for T2V | |
wanx_t2v_generate_btn.click( | |
fn=wanx_generate_video_batch, | |
inputs=[ | |
wanx_t2v_prompt, | |
wanx_t2v_negative_prompt, | |
wanx_t2v_width, | |
wanx_t2v_height, | |
wanx_t2v_video_length, | |
wanx_t2v_fps, | |
wanx_t2v_infer_steps, | |
wanx_t2v_flow_shift, | |
wanx_t2v_guidance_scale, | |
wanx_t2v_seed, | |
wanx_t2v_task, | |
wanx_t2v_dit_path, | |
wanx_t2v_vae_path, | |
wanx_t2v_t5_path, | |
wanx_t2v_clip_path, | |
wanx_t2v_save_path, | |
wanx_t2v_output_type, | |
wanx_t2v_sample_solver, | |
wanx_t2v_exclude_single_blocks, | |
wanx_t2v_attn_mode, | |
wanx_t2v_block_swap, | |
wanx_t2v_fp8, | |
wanx_t2v_fp8_t5, | |
wanx_t2v_lora_folder, | |
*wanx_t2v_lora_weights, | |
*wanx_t2v_lora_multipliers, | |
wanx_t2v_batch_size, | |
# no input image is passed for text-to-video generation | |
], | |
outputs=[wanx_t2v_output, wanx_t2v_batch_progress, wanx_t2v_progress_text], | |
queue=True | |
).then( | |
fn=lambda batch_size: 0 if batch_size == 1 else None, | |
inputs=[wanx_t2v_batch_size], | |
outputs=wanx_t2v_selected_index | |
) | |
# Add refresh button handler for WanX-t2v tab | |
wanx_t2v_refresh_outputs = [] | |
for i in range(4): | |
wanx_t2v_refresh_outputs.extend([wanx_t2v_lora_weights[i], wanx_t2v_lora_multipliers[i]]) | |
wanx_t2v_refresh_btn.click( | |
fn=update_lora_dropdowns, | |
inputs=[wanx_t2v_lora_folder] + wanx_t2v_lora_weights + wanx_t2v_lora_multipliers, | |
outputs=wanx_t2v_refresh_outputs | |
) | |
# Gallery selection handling | |
wanx_t2v_output.select( | |
fn=handle_wanx_t2v_gallery_select, | |
outputs=wanx_t2v_selected_index | |
) | |
# Send to Video2Video handler | |
wanx_t2v_send_to_v2v_btn.click( | |
fn=send_wanx_t2v_to_v2v, | |
inputs=[ | |
wanx_t2v_output, | |
wanx_t2v_prompt, | |
wanx_t2v_selected_index, | |
wanx_t2v_width, | |
wanx_t2v_height, | |
wanx_t2v_video_length, | |
wanx_t2v_fps, | |
wanx_t2v_infer_steps, | |
wanx_t2v_seed, | |
wanx_t2v_flow_shift, | |
wanx_t2v_guidance_scale, | |
wanx_t2v_negative_prompt | |
], | |
outputs=[ | |
v2v_input, | |
v2v_prompt, | |
v2v_width, | |
v2v_height, | |
v2v_video_length, | |
v2v_fps, | |
v2v_infer_steps, | |
v2v_seed, | |
v2v_flow_shift, | |
v2v_cfg_scale, | |
v2v_negative_prompt | |
] | |
).then( | |
fn=change_to_tab_two, | |
inputs=None, | |
outputs=[tabs] | |
) | |
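# Launch the app; queue() is required so the generator-based handlers can stream progress updates | |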
demo.queue().launch(server_name="0.0.0.0", share=False) |