File size: 10,338 Bytes
ee2f3a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import os
import sys
import json
import time
import requests
import gradio as gr
import numpy as np
from PIL import Image
import io
import base64
import spaces

# ComfyUI API endpoint and workflow template location. Either can be
# overridden with an environment variable of the same name; the defaults are
# the original hard-coded values.
COMFY_API = os.environ.get("COMFY_API", "http://127.0.0.1:8188/api")
WORKFLOW_PATH = os.environ.get("WORKFLOW_PATH", "/app/workflows/Workflow_12_11.json")

# Load the workflow template once at import time. If loading fails the app
# still starts, but ``workflow_template`` is an empty dict.
try:
    # Explicit encoding: JSON is UTF-8, and the locale default may not be.
    with open(WORKFLOW_PATH, "r", encoding="utf-8") as f:
        workflow_template = json.load(f)
    print(f"Loaded workflow template from {WORKFLOW_PATH}")
except (OSError, ValueError) as e:
    # OSError: missing/unreadable file. ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError.
    print(f"Error loading workflow template: {str(e)}")
    workflow_template = {}

def queue_prompt(prompt, timeout=30):
    """Send a workflow graph to ComfyUI's ``/prompt`` endpoint for execution.

    Args:
        prompt: The workflow graph (node id -> node dict) to queue.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely if the server stops answering.

    Returns:
        The decoded JSON response body, or ``{"error": message}`` if the
        request or JSON decoding raised.
    """
    try:
        response = requests.post(f"{COMFY_API}/prompt", json={"prompt": prompt}, timeout=timeout)
        # No raise_for_status(): ComfyUI's rejection responses carry a JSON
        # body that callers inspect for an "error" key.
        return response.json()
    except Exception as e:
        print(f"Error queuing prompt: {str(e)}")
        return {"error": str(e)}

def get_image(filename, subfolder, folder_type, timeout=30):
    """Download an image from ComfyUI's ``/view`` endpoint.

    Args:
        filename: Image file name as reported by ComfyUI.
        subfolder: Subfolder within ``folder_type`` ("" for its root).
        folder_type: ComfyUI folder type, e.g. ``"output"``.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A ``PIL.Image.Image``, or ``None`` if the request failed, returned a
        non-2xx status, or the body was not a readable image.
    """
    try:
        # Pass the query via ``params`` so requests URL-encodes it; filenames
        # containing spaces, '&' or '#' would otherwise corrupt the URL.
        response = requests.get(
            f"{COMFY_API}/view",
            params={"filename": filename, "subfolder": subfolder, "type": folder_type},
            timeout=timeout,
        )
        # Turn e.g. a 404 into a clear error instead of a PIL decode failure.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except Exception as e:
        print(f"Error getting image {filename}: {str(e)}")
        return None

def upload_image(image, filename, timeout=30):
    """Upload an image to ComfyUI's input folder via ``/upload/image``.

    Args:
        image: A PIL image, a numpy array (converted with ``Image.fromarray``),
            or a base64 string with or without a ``data:...;base64,`` prefix.
        filename: File name sent with the multipart upload.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The decoded JSON response body, or ``{"error": message}`` on failure
        (including when ``image`` is ``None``).
    """
    if image is None:
        # e.g. the user left a Gradio image input empty.
        return {"error": f"No image provided for {filename}"}
    try:
        if isinstance(image, str):  # Base64 string
            # Take the text after the first comma when present (data URL),
            # otherwise the whole string; base64 never contains a comma.
            image_data = base64.b64decode(image.split(",", 1)[-1])
        else:  # PIL Image or numpy array
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image_data = buffer.getvalue()

        response = requests.post(
            f"{COMFY_API}/upload/image",
            files={"image": (filename, image_data)},
            timeout=timeout,
        )
        # Non-2xx responses may not be JSON; report the HTTP error instead.
        response.raise_for_status()
        return response.json()
    except Exception as e:
        print(f"Error uploading image: {str(e)}")
        return {"error": str(e)}

def check_progress(prompt_id, timeout=30):
    """Fetch ComfyUI's ``/history/<prompt_id>`` entry.

    Args:
        prompt_id: Id returned by ``/prompt`` when the workflow was queued.
        timeout: Seconds to wait for the HTTP response, so a stalled server
            cannot block the polling loop forever.

    Returns:
        The decoded JSON response body, or ``{"error": message}`` if the
        request or JSON decoding raised.
    """
    try:
        response = requests.get(f"{COMFY_API}/history/{prompt_id}", timeout=timeout)
        return response.json()
    except Exception as e:
        print(f"Error checking progress: {str(e)}")
        return {"error": str(e)}

def _fetch_generated_avatar(outputs):
    """Download the avatar referenced by a finished prompt's ``outputs`` mapping.

    Returns an ``(image, status_message)`` tuple shaped like ``generate_avatar``'s.
    """
    for node_id, output in outputs.items():
        if node_id not in ("308", "679"):  # Save Image nodes
            continue
        images = output.get("images") or []
        if not images or not images[0].get("filename"):
            continue
        info = images[0]
        filename = info["filename"]
        # Use the location ComfyUI reports instead of assuming the root of "output".
        subfolder = info.get("subfolder", "")
        folder_type = info.get("type", "output")

        result_image = get_image(filename, subfolder, folder_type)
        # A "<name>_Masked.png" companion is requested from the same location;
        # presumably written by the workflow. get_image returns None if absent.
        masked_image = get_image(filename.replace(".png", "_Masked.png"), subfolder, folder_type)

        # Prefer the masked variant; compare with None rather than truthiness.
        if masked_image is not None:
            return masked_image, "Avatar generated successfully!"
        if result_image is not None:
            return result_image, "Avatar generated successfully!"
        return None, f"Avatar {filename} was generated but could not be downloaded."

    return None, "Completed, but couldn't find output image."


# NOTE(review): generation runs inside the separate ComfyUI server; confirm the
# @spaces.GPU default duration is long enough for the up-to-5-minute poll below.
@spaces.GPU
def generate_avatar(player_image, pose_image, shirt_image, player_name, team_name, age_gender_bg, shirt_color, style):
    """Generate a football player avatar using the ComfyUI workflow.

    Uploads the three input images, fills the user's choices into a private
    copy of ``workflow_template``, queues it, then polls ComfyUI's history
    every 5 seconds (at most 60 times, about 5 minutes) for the result.

    Args:
        player_image, pose_image, shirt_image: Images accepted by ``upload_image``.
        player_name: Written to node 471, prefixed with ``_``.
        team_name: Written to node 667.
        age_gender_bg: Prompt fragment written to node 420.
        shirt_color: Colour embedded in the shirt prompt of node 528.
        style: Style prompt written to node 422.

    Returns:
        ``(image, status_message)``. ``image`` is ``None`` on any failure.
    """
    import copy  # function-scope import keeps this block self-contained

    # Upload images to ComfyUI
    player_upload = upload_image(player_image, "player.png")
    pose_upload = upload_image(pose_image, "pose.png")
    shirt_upload = upload_image(shirt_image, "shirt.png")

    if "error" in player_upload or "error" in pose_upload or "error" in shirt_upload:
        return None, f"Error uploading images: {player_upload.get('error', '')} {pose_upload.get('error', '')} {shirt_upload.get('error', '')}"

    # Deep copy: a shallow .copy() shares the nested node dicts, so the
    # assignments below would overwrite the module-level template and leak
    # values between (possibly concurrent) requests.
    workflow = copy.deepcopy(workflow_template)

    try:
        workflow["391"]["inputs"]["image"] = player_upload["name"]  # Player image node
        workflow["310"]["inputs"]["image"] = pose_upload["name"]  # Pose image node
        workflow["636"]["inputs"]["image"] = shirt_upload["name"]  # Shirt image node
        workflow["471"]["inputs"]["string"] = f"_{player_name}"  # Player name node
        workflow["667"]["inputs"]["string"] = team_name  # Team name node
        workflow["420"]["inputs"]["string"] = age_gender_bg  # Age, gender, background prompt node
        workflow["528"]["inputs"]["string"] = f"({shirt_color}:1.2) blank t-shirt, black shorts, "  # Shirt color node
        workflow["422"]["inputs"]["string"] = style  # Style node
    except KeyError as e:
        # e.g. the template failed to load (empty dict) or an upload
        # response had no "name" field.
        return None, f"Error preparing workflow: missing key {e}"

    # Queue the prompt in ComfyUI
    prompt_response = queue_prompt(workflow)
    if "error" in prompt_response:
        return None, f"Error queuing prompt: {prompt_response['error']}"

    prompt_id = prompt_response.get("prompt_id")
    if not prompt_id:
        return None, f"Error queuing prompt: unexpected response {prompt_response}"

    max_retries = 60  # 60 polls x 5 s sleep = 5 minutes timeout
    for _ in range(max_retries):
        time.sleep(5)
        progress = check_progress(prompt_id)
        if "error" in progress:
            continue  # request failed; counts as one attempt

        entry = progress.get(prompt_id)
        if not entry:
            continue  # not in history yet

        # Failed runs are expected to carry status.status_str == "error"
        # (ComfyUI history format; verify against the deployed version).
        # Stop early instead of waiting out the full timeout.
        if (entry.get("status") or {}).get("status_str") == "error":
            return None, "ComfyUI reported an error while generating the avatar."

        outputs = entry.get("outputs") or {}
        if outputs:
            return _fetch_generated_avatar(outputs)

    return None, "Timed out waiting for the avatar generation to complete."

def create_interface():
    """Build the Gradio Blocks UI for the avatar generator.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    input_dir = "/app/ComfyUI/input"

    def _load_default(path):
        # Best-effort: a missing default only leaves that input empty.
        try:
            return Image.open(path)
        except Exception as e:
            print(f"Error loading default images: {str(e)}")
            return None

    # Defaults are loaded up front and passed via ``value=``, the documented
    # way to set an initial value. Assigning ``.value`` after construction
    # bypasses Gradio's postprocessing. Each default loads independently, so
    # one missing file does not drop the other.
    default_pose = _load_default(os.path.join(input_dir, "pose4.jpg"))
    default_shirt = _load_default(os.path.join(input_dir, "paita2.jpg"))

    with gr.Blocks(title="Football Player Avatar Generator") as demo:
        gr.Markdown("# Football Player Avatar Generator")
        gr.Markdown("Create stylized football player avatars from photos")

        with gr.Row():
            with gr.Column():
                player_image = gr.Image(label="Upload Player Photo", type="pil")

                with gr.Row():
                    pose_image = gr.Image(label="Select Pose Template", type="pil", value=default_pose)
                    shirt_image = gr.Image(label="Select Shirt Template", type="pil", value=default_shirt)

                player_name = gr.Textbox(label="Player Name", value="Player")
                team_name = gr.Textbox(label="Team Name", value="Kurjet")

                # NOTE(review): "with on" in these prompts looks like a typo,
                # but they are sent verbatim to the model, so they are unchanged.
                age_gender_options = [
                    "9 year old boy with on light grey studio background, upper body portrait",
                    "10 year old boy with on light grey studio background, upper body portrait",
                    "adult man with on light grey studio background, upper body portrait",
                    "adult woman with on light grey studio background, upper body portrait",
                ]
                age_gender_bg = gr.Dropdown(label="Age, Gender & Background", choices=age_gender_options, value=age_gender_options[0])

                shirt_color_options = ["black", "red", "blue", "green", "yellow", "white"]
                shirt_color = gr.Dropdown(label="Shirt Color", choices=shirt_color_options, value="black")

                style_options = [
                    "3d pixar character portrait, award winning, 3d animation, octane rendering",
                    "digital painting, detailed, concept art, smooth, sharp focus, illustration, trending on artstation",
                    "cartoon drawing, hand drawn, pencil on paper, sketch art",
                    "watercolor painting, beautiful, smooth, sharp focus, colorful, professional",
                ]
                style = gr.Dropdown(label="Art Style", choices=style_options, value=style_options[0])

                generate_button = gr.Button("Generate Avatar", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Generated Avatar")
                status_text = gr.Textbox(label="Status", interactive=False)

        # Set up the button click event
        generate_button.click(
            fn=generate_avatar,
            inputs=[player_image, pose_image, shirt_image, player_name, team_name, age_gender_bg, shirt_color, style],
            outputs=[output_image, status_text],
        )

        # Add up to three example player photos. Sort the listing so the
        # choice is deterministic; os.listdir order is arbitrary.
        if os.path.exists(input_dir):
            example_images = [
                os.path.join(input_dir, name)
                for name in sorted(os.listdir(input_dir))
                if name.lower().endswith((".png", ".jpg", ".jpeg"))
            ]
            if example_images:
                gr.Examples(
                    examples=[[img] for img in example_images[:3]],
                    inputs=[player_image],
                )

    return demo

# Check for ComfyUI server and launch the app
if __name__ == "__main__":
    # Poll ComfyUI's system_stats endpoint up to max_retries times, 10 s apart.
    # The UI is launched either way; only a warning is printed if it never answers.
    max_retries = 5
    comfy_running = False

    for attempt in range(1, max_retries + 1):
        try:
            # Timeout so one unresponsive request cannot stall startup forever.
            response = requests.get(f"{COMFY_API}/system_stats", timeout=5)
            if response.status_code == 200:
                print("ComfyUI is running, starting Gradio interface...")
                comfy_running = True
                break
        except requests.RequestException:
            # Server not reachable yet; retry. A narrow catch (instead of a
            # bare except) still lets Ctrl-C and real bugs propagate.
            pass

        print(f"Waiting for ComfyUI to start... (attempt {attempt}/{max_retries})")
        if attempt < max_retries:  # no pointless sleep after the final attempt
            time.sleep(10)

    if not comfy_running:
        print("WARNING: Could not connect to ComfyUI server. The application may not function correctly.")

    # Create and launch the interface
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)