|
import os |
|
import sys |
|
import json |
|
import time |
|
import requests |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
import io |
|
import base64 |
|
import spaces |
|
|
|
|
|
# Base URL of the ComfyUI HTTP API used by every request helper in this file.
COMFY_API = "http://localhost:8188/api"
# JSON workflow graph that generate_avatar copies and fills in for each request.
WORKFLOW_PATH = "/app/workflows/Workflow_12_11.json"

# Load the workflow once at import time. A missing or malformed file is not
# fatal here: the template falls back to an empty dict so the module (and the
# UI) can still be imported and started.
try:
    # JSON text is UTF-8; decode it explicitly rather than relying on the
    # process locale, which decides open()'s default encoding.
    with open(WORKFLOW_PATH, "r", encoding="utf-8") as f:
        workflow_template = json.load(f)
    print(f"Loaded workflow template from {WORKFLOW_PATH}")
except (OSError, ValueError) as e:
    # OSError: file missing/unreadable. ValueError: covers json.JSONDecodeError
    # and UnicodeDecodeError.
    print(f"Error loading workflow template: {str(e)}")
    workflow_template = {}
|
|
|
def is_comfyui_running():
    """Report whether the ComfyUI API answers its system_stats endpoint.

    Returns:
        bool: True only when the GET request completes within 5 seconds with
        HTTP status 200. Any exception raised while probing is printed and
        reported as False.
    """
    try:
        stats_url = f"{COMFY_API}/system_stats"
        reply = requests.get(stats_url, timeout=5)
        reachable = reply.status_code == 200
    except Exception as exc:
        # Best-effort probe: every failure simply means "not running".
        print(f"ComfyUI server check failed: {str(exc)}")
        reachable = False
    return reachable
|
|
|
def queue_prompt(prompt):
    """Submit a workflow to ComfyUI's /prompt endpoint.

    Args:
        prompt: The workflow to run; it is sent as the "prompt" field of the
            JSON request body.

    Returns:
        The decoded JSON body of the response, or ``{"error": <message>}``
        when the server check fails or the request/decoding raises.
    """
    if not is_comfyui_running():
        return {"error": "ComfyUI server is not running"}

    body = {"prompt": prompt}
    try:
        reply = requests.post(f"{COMFY_API}/prompt", json=body, timeout=30)
        decoded = reply.json()
    except Exception as exc:
        message = str(exc)
        print(f"Error queuing prompt: {message}")
        return {"error": message}
    return decoded
|
|
|
def get_image(filename, subfolder, folder_type):
    """Download an image through ComfyUI's /view endpoint.

    Args:
        filename: Name of the image file to fetch.
        subfolder: Sub-directory to look in; pass "" for none.
        folder_type: Value of the ``type`` query parameter (the caller in
            this file passes "output").

    Returns:
        PIL.Image.Image | None: The opened image, or None when the server
        check fails, the server answers with an HTTP error status, or the
        response body cannot be opened as an image.
    """
    if not is_comfyui_running():
        return None

    # Hand the query to requests instead of formatting it into the URL, so
    # spaces, "&", "#" and non-ASCII characters in a filename are
    # percent-encoded rather than corrupting the query string.
    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    try:
        response = requests.get(f"{COMFY_API}/view", params=query, timeout=30)
        # Report a 4xx/5xx as an HTTP error instead of trying to parse the
        # error body as image data.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except Exception as e:
        # Best-effort fetch: callers treat None as "image not available".
        print(f"Error getting image {filename}: {str(e)}")
        return None
|
|
|
def upload_image(image, filename):
    """Upload an image to ComfyUI through its /upload/image endpoint.

    Args:
        image: One of
            * a base64 string, either bare or as a data URL
              ("data:...;base64,<payload>"),
            * a numpy array (converted with ``Image.fromarray``),
            * an object with a PIL-style ``save(fp, format=...)`` method.
            ``None`` is rejected with an error result.
        filename: File name to send along with the image bytes.

    Returns:
        The decoded JSON body of the server's response (generate_avatar reads
        its "name" key), or ``{"error": <message>}`` on any failure.
    """
    # Nothing selected in the UI: fail with a clear message before touching
    # the network, instead of an incidental AttributeError further down.
    if image is None:
        error_msg = f"No image provided for {filename}"
        print(error_msg)
        return {"error": error_msg}

    if not is_comfyui_running():
        return {"error": "ComfyUI server is not running"}

    try:
        if isinstance(image, str):
            # Keep only what follows the first comma of a data URL; a bare
            # base64 string (no comma) is used as-is.
            image_data = base64.b64decode(image.split(",", 1)[-1])
        else:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image_data = buffer.getvalue()

        files = {"image": (filename, image_data)}
        response = requests.post(f"{COMFY_API}/upload/image", files=files, timeout=60)
        # Surface HTTP failures as errors rather than parsing the error body
        # as if it were a successful upload result.
        response.raise_for_status()
        return response.json()
    except requests.exceptions.ConnectionError as e:
        error_msg = f"Connection error when uploading image: {str(e)}"
        print(error_msg)
        return {"error": error_msg}
    except Exception as e:
        error_msg = f"Error uploading image: {str(e)}"
        print(error_msg)
        return {"error": error_msg}
|
|
|
def check_progress(prompt_id):
    """Fetch ComfyUI's /history record for a queued prompt.

    Args:
        prompt_id: Identifier appended to the /history/ URL.

    Returns:
        The decoded JSON body of the response, or ``{"error": <message>}``
        when the server check fails or the request/decoding raises.
    """
    if not is_comfyui_running():
        return {"error": "ComfyUI server is not running"}

    try:
        history_url = f"{COMFY_API}/history/{prompt_id}"
        history = requests.get(history_url, timeout=30).json()
    except Exception as exc:
        message = str(exc)
        print(f"Error checking progress: {message}")
        return {"error": message}
    return history
|
|
|
# Ids of the workflow nodes whose outputs are searched for the finished image.
_OUTPUT_NODE_IDS = ("308", "679")


def _build_workflow(player_upload, pose_upload, shirt_upload, player_name,
                    team_name, age_gender_bg, shirt_color, style):
    """Return a private copy of the workflow template filled in for one request.

    Raises:
        KeyError: if an expected node/input is missing from the template or
            an upload result has no "name" key.
    """
    import copy

    # Deep copy: dict.copy() would share the nested node dicts, so the
    # assignments below would rewrite the module-level template for every
    # later (or concurrent) request.
    workflow = copy.deepcopy(workflow_template)

    # Nodes that receive the uploaded file names.
    workflow["391"]["inputs"]["image"] = player_upload["name"]
    workflow["310"]["inputs"]["image"] = pose_upload["name"]
    workflow["636"]["inputs"]["image"] = shirt_upload["name"]

    # Nodes that receive the text parameters.
    workflow["471"]["inputs"]["string"] = f"_{player_name}"
    workflow["667"]["inputs"]["string"] = team_name
    workflow["420"]["inputs"]["string"] = age_gender_bg
    workflow["528"]["inputs"]["string"] = f"({shirt_color}:1.2) blank t-shirt, black shorts, "
    workflow["422"]["inputs"]["string"] = style
    return workflow


def _fetch_output_image(outputs):
    """Return ``(image, message)`` for the first usable image in *outputs*.

    *outputs* is the "outputs" mapping of a history entry. Only the nodes in
    ``_OUTPUT_NODE_IDS`` are considered; the "_Masked" variant of the file is
    preferred over the plain one.
    """
    for node_id, output in outputs.items():
        if node_id not in _OUTPUT_NODE_IDS:
            continue

        # Tolerate a missing or empty "images" list instead of indexing into it.
        images = output.get("images") or [{}]
        image_info = images[0]
        image_filename = image_info.get("filename", "")
        if not image_filename:
            continue

        print(f"Found output image: {image_filename}")
        # Use the location reported for the image when present; the fallbacks
        # are the values this code previously hard-coded.
        subfolder = image_info.get("subfolder", "")
        folder_type = image_info.get("type", "output")

        masked_filename = image_filename.replace(".png", "_Masked.png")
        masked_image = get_image(masked_filename, subfolder, folder_type)
        if masked_image is not None:
            print("Returning masked image")
            return masked_image, "Avatar generated successfully!"

        result_image = get_image(image_filename, subfolder, folder_type)
        if result_image is not None:
            print("Returning regular image")
            return result_image, "Avatar generated successfully!"

        return None, "Generated image could not be retrieved."

    return None, "Completed, but couldn't find output image."


@spaces.GPU
def generate_avatar(player_image, pose_image, shirt_image, player_name, team_name, age_gender_bg, shirt_color, style):
    """Generate a football player avatar using the ComfyUI workflow.

    Uploads the three images, fills the workflow template with the uploaded
    file names and text parameters, queues it, then polls the prompt history
    every 5 seconds (at most 60 polls) until outputs appear.

    Args:
        player_image: Player photo, in any form ``upload_image`` accepts.
        pose_image: Pose template image.
        shirt_image: Shirt template image.
        player_name: Written to the workflow with a leading underscore.
        team_name: Team name text.
        age_gender_bg: Age/gender/background prompt text.
        shirt_color: Colour word inserted into the clothing prompt.
        style: Art-style prompt text.

    Returns:
        tuple: ``(image, status_message)``. ``image`` is a PIL image on
        success and ``None`` on any failure; ``status_message`` always
        describes the outcome.
    """
    if not is_comfyui_running():
        return None, "Error: ComfyUI server is not running. Please try again later."

    print("Uploading player image...")
    player_upload = upload_image(player_image, "player.png")

    print("Uploading pose image...")
    pose_upload = upload_image(pose_image, "pose.png")

    print("Uploading shirt image...")
    shirt_upload = upload_image(shirt_image, "shirt.png")

    if "error" in player_upload or "error" in pose_upload or "error" in shirt_upload:
        error_msg = f"Error uploading images: {player_upload.get('error', '')} {pose_upload.get('error', '')} {shirt_upload.get('error', '')}"
        print(error_msg)
        return None, error_msg

    print("All images uploaded successfully")

    try:
        workflow = _build_workflow(
            player_upload, pose_upload, shirt_upload,
            player_name, team_name, age_gender_bg, shirt_color, style,
        )
    except KeyError as e:
        error_msg = f"Error updating workflow parameters: {str(e)}. The workflow structure may have changed."
        print(error_msg)
        return None, error_msg

    print("Queuing prompt in ComfyUI...")
    prompt_response = queue_prompt(workflow)

    if "error" in prompt_response:
        error_msg = f"Error queuing prompt: {prompt_response['error']}"
        print(error_msg)
        return None, error_msg

    # A reply with neither "error" nor "prompt_id" must not crash the handler.
    prompt_id = prompt_response.get("prompt_id")
    if not prompt_id:
        error_msg = f"Error queuing prompt: unexpected response {prompt_response}"
        print(error_msg)
        return None, error_msg

    print(f"Prompt queued with ID: {prompt_id}")

    max_retries = 60
    # Failed polls and "not ready yet" polls both consume one attempt.
    for attempt in range(1, max_retries + 1):
        time.sleep(5)
        progress = check_progress(prompt_id)

        if "error" in progress:
            print(f"Error checking progress (retry {attempt}/{max_retries}): {progress['error']}")
            continue

        entry = progress.get(prompt_id)
        if entry and entry.get("outputs"):
            return _fetch_output_image(entry["outputs"])

        print(f"Generating avatar... (attempt {attempt}/{max_retries})")

    return None, "Timed out waiting for the avatar generation to complete."
|
|
|
def _load_default_image(path):
    """Return the image at *path* fully read into memory, or None on failure."""
    try:
        image = Image.open(path)
        # Image.open is lazy; decode now so a corrupt file is reported here.
        image.load()
        return image
    except Exception as e:
        # Deliberately best-effort: a missing template must not stop the UI.
        print(f"Error loading default images: {str(e)}")
        return None


def create_interface():
    """Create the Gradio interface for the avatar generator.

    The layout has two columns: inputs (player photo, pose and shirt
    templates, text fields, dropdowns, generate button) and outputs
    (generated avatar, status text, ComfyUI server status check). Up to three
    image files found in the ComfyUI input directory are offered as examples
    for the player photo.

    Returns:
        gr.Blocks: The assembled, not yet launched, interface.
    """
    input_dir = "/app/ComfyUI/input"

    # Load the defaults before building the components so they can be given
    # to the constructors via ``value=``, the documented way to set a
    # component's initial value. Each image is loaded independently, so one
    # missing file does not discard the other.
    default_pose = _load_default_image(os.path.join(input_dir, "pose4.jpg"))
    default_shirt = _load_default_image(os.path.join(input_dir, "paita2.jpg"))

    age_gender_options = [
        "9 year old boy with on light grey studio background, upper body portrait",
        "10 year old boy with on light grey studio background, upper body portrait",
        "adult man with on light grey studio background, upper body portrait",
        "adult woman with on light grey studio background, upper body portrait",
    ]
    shirt_color_options = ["black", "red", "blue", "green", "yellow", "white"]
    style_options = [
        "3d pixar character portrait, award winning, 3d animation, octane rendering",
        "digital painting, detailed, concept art, smooth, sharp focus, illustration, trending on artstation",
        "cartoon drawing, hand drawn, pencil on paper, sketch art",
        "watercolor painting, beautiful, smooth, sharp focus, colorful, professional",
    ]

    def check_comfy_server():
        """Return a one-line, human-readable ComfyUI server status."""
        if is_comfyui_running():
            return "✅ ComfyUI server is running"
        return "❌ ComfyUI server is not running"

    with gr.Blocks(title="Football Player Avatar Generator") as demo:
        gr.Markdown("# Football Player Avatar Generator")
        gr.Markdown("Create stylized football player avatars from photos")

        with gr.Row():
            with gr.Column():
                player_image = gr.Image(label="Upload Player Photo", type="pil")

                with gr.Row():
                    pose_image = gr.Image(label="Select Pose Template", type="pil", value=default_pose)
                    shirt_image = gr.Image(label="Select Shirt Template", type="pil", value=default_shirt)

                player_name = gr.Textbox(label="Player Name", value="Player")
                team_name = gr.Textbox(label="Team Name", value="Kurjet")

                age_gender_bg = gr.Dropdown(label="Age, Gender & Background", choices=age_gender_options, value=age_gender_options[0])
                shirt_color = gr.Dropdown(label="Shirt Color", choices=shirt_color_options, value="black")
                style = gr.Dropdown(label="Art Style", choices=style_options, value=style_options[0])

                generate_button = gr.Button("Generate Avatar", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Generated Avatar")
                status_text = gr.Textbox(label="Status", interactive=False)

                comfy_status = gr.Textbox(label="ComfyUI Server Status", interactive=False)
                check_comfy_button = gr.Button("Check ComfyUI Server Status")

        check_comfy_button.click(
            fn=check_comfy_server,
            inputs=[],
            outputs=[comfy_status],
        )

        generate_button.click(
            fn=generate_avatar,
            inputs=[player_image, pose_image, shirt_image, player_name, team_name, age_gender_bg, shirt_color, style],
            outputs=[output_image, status_text],
        )

        # isdir (not exists) so a stray file at this path cannot make
        # listdir raise; sorted so the same three examples are chosen on
        # every start (os.listdir order is arbitrary).
        if os.path.isdir(input_dir):
            example_images = sorted(
                os.path.join(input_dir, name)
                for name in os.listdir(input_dir)
                if name.lower().endswith(('.png', '.jpg', '.jpeg'))
            )

            if example_images:
                gr.Examples(
                    examples=[[img] for img in example_images[:3]],
                    inputs=[player_image],
                )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Give ComfyUI time to come up before serving the UI: probe it up to 12
    # times, 10 seconds apart. The probe itself is the shared
    # is_comfyui_running() helper, so start-up and request handling agree on
    # what "running" means (and failures are logged by the helper).
    max_retries = 12
    comfy_running = False

    for attempt in range(1, max_retries + 1):
        print(f"Checking if ComfyUI is running (attempt {attempt}/{max_retries})...")
        if is_comfyui_running():
            print("ComfyUI is running, starting Gradio interface...")
            comfy_running = True
            break

        # No point in waiting after the final attempt; go straight to launch.
        if attempt < max_retries:
            print(f"Waiting for ComfyUI to start... (attempt {attempt}/{max_retries})")
            time.sleep(10)

    if not comfy_running:
        print("WARNING: Could not connect to ComfyUI server. The application may not function correctly.")
        print("Starting Gradio interface anyway...")

    # The interface is started either way; each handler re-checks the server.
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|