import gradio as gr
import asyncio
import aiohttp
import time
from datetime import datetime
import plotly.graph_objects as go
from typing import Dict, List
import os
from dotenv import load_dotenv
import json
from PIL import Image, ImageDraw, ImageFont
import uuid
import threading

# Load environment variables first so API_KEY is populated below.
load_dotenv()

# Constants
API_BASE_URL = "https://api.wavespeed.ai/api/v2"
API_KEY = os.getenv("WAVESPEED_API_KEY")  # Global scope: shared by submit/poll helpers
if not API_KEY:
    raise ValueError("WAVESPEED_API_KEY not found in environment variables")

# One entry per model backend: API endpoint, display name, and chart color.
BACKENDS = {
    "flux-dev": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/flux-dev-ultra-fast",
        "name": "Flux-dev",
        "color": "#FF9800",
    },
    "hidream-dev": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-dev",
        "name": "HiDream-dev",
        "color": "#2196F3",
    },
    "hidream-full": {
        "endpoint": f"{API_BASE_URL}/wavespeed-ai/hidream-i1-full",
        "name": "HiDream-full",
        "color": "#4CAF50",
    },
}


class BackendStatus:
    """Lifecycle (idle/processing/completed/failed) and timing tracker for one backend."""

    def __init__(self):
        self.reset()
        # One {"timestamp": datetime, "duration": seconds} record per completed run.
        self.history: List[Dict] = []

    def reset(self):
        """Return to the idle state; history is intentionally preserved."""
        self.status = "idle"
        self.progress = 0
        self.start_time = None
        self.end_time = None

    def start(self):
        """Mark the beginning of a generation run."""
        self.status = "processing"
        self.progress = 0
        self.start_time = time.time()
        self.end_time = None

    def complete(self):
        """Mark success and append the run duration to history."""
        self.status = "completed"
        self.progress = 100
        self.end_time = time.time()
        self.history.append({
            "timestamp": datetime.now(),
            "duration": self.end_time - self.start_time
        })

    def fail(self):
        """Mark failure; failed runs are not recorded in history."""
        self.status = "failed"
        self.end_time = time.time()


class SessionManager:
    """Process-wide registry mapping session ids to GenerationManager instances."""

    _instances = {}  # session_id -> GenerationManager
    _lock = threading.Lock()  # guards _instances across Gradio worker threads

    @classmethod
    def get_manager(cls, session_id=None):
        """Return (session_id, manager), minting a fresh id/manager when needed."""
        if session_id is None:
            session_id = str(uuid.uuid4())
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = GenerationManager()
            return session_id, cls._instances[session_id]

    @classmethod
    def cleanup_old_sessions(cls, max_age=3600):  # 1 hour default
        """Drop managers that have been idle longer than max_age seconds."""
        current_time = time.time()
        with cls._lock:
            to_remove = []
            for session_id, manager in cls._instances.items():
                if (hasattr(manager, "last_activity")
                        and current_time - manager.last_activity > max_age):
                    to_remove.append(session_id)
            for session_id in to_remove:
                del cls._instances[session_id]


class GenerationManager:
    """Per-session state: one BackendStatus per backend plus an idle timestamp."""

    def __init__(self):
        self.backend_statuses = {
            backend: BackendStatus()
            for backend in BACKENDS
        }
        self.last_activity = time.time()

    def update_activity(self):
        """Refresh the idle timer consulted by SessionManager.cleanup_old_sessions."""
        self.last_activity = time.time()

    def get_performance_plot(self):
        """Build a Plotly bar chart of average generation time per backend.

        Returns a valid figure even when no timing data exists yet.
        """
        fig = go.Figure()
        has_data = False
        for backend, status in self.backend_statuses.items():
            durations = [h["duration"] for h in status.history]
            if durations:
                has_data = True
                avg_duration = sum(durations) / len(durations)
                # Use bar chart instead of box plot
                fig.add_trace(
                    go.Bar(
                        y=[avg_duration],  # Average duration
                        x=[BACKENDS[backend]["name"]],  # Backend name
                        name=BACKENDS[backend]["name"],
                        marker_color=BACKENDS[backend]["color"],
                        text=[f"{avg_duration:.2f}s"],  # Show time in seconds
                        textposition="auto",
                        width=[0.5],  # Make bars narrower
                    ))
        # Set a minimum y-axis range if we have data
        if has_data:
            max_duration = max([
                max([h["duration"] for h in status.history] or [0])
                for status in self.backend_statuses.values()
            ])
            # Add 20% padding to the top
            y_max = max_duration * 1.2
            # Ensure the y-axis always starts at 0
            fig.update_yaxes(range=[0, y_max])

        fig.update_layout(
            title="Average Generation Time",
            yaxis_title="Seconds",
            xaxis_title="",
            showlegend=False,
            template="simple_white",
            height=400,  # Increase height
            margin=dict(l=50, r=50, t=50, b=50),  # Add margins
            font=dict(size=14),  # Larger font
        )

        # Make sure we have a valid figure even if no data
        if not has_data:
            fig.add_annotation(
                text="No timing data available yet",
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
                showarrow=False,
                font=dict(size=16),
            )
        return fig

    async def submit_task(self, backend: str, prompt: str) -> str:
        """Submit a generation task to the given backend; return its request id.

        Marks the backend's status as started/failed and wraps any error in a
        new Exception so callers see a uniform message.
        """
        status = self.backend_statuses[backend]
        status.start()
        try:
            url = BACKENDS[backend]["endpoint"]
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {API_KEY}",
            }
            payload = {
                "prompt": prompt,
                "enable_safety_checker": False,
                "enable_base64_output": True,  # Enable base64 output
                "size": "1024*1024",
                "seed": -1,
            }
            # flux-dev takes additional sampler parameters the HiDream models don't.
            if backend == "flux-dev":
                payload.update({
                    "guidance_scale": 3.5,
                    "num_images": 1,
                    "num_inference_steps": 28,
                    "strength": 0.8,
                })

            print(f"Submitting task to {backend}")
            print(f"URL: {url}")
            print(f"Payload: {json.dumps(payload, indent=2)}")

            # Use aiohttp instead of requests for async
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers,
                                        json=payload) as response:
                    if response.status == 200:
                        result = await response.json()
                        request_id = result["data"]["id"]
                        print(
                            f"Task submitted successfully. Request ID: {request_id}"
                        )
                        return request_id
                    else:
                        text = await response.text()
                        raise Exception(
                            f"API error: {response.status}, {text}")
        except Exception as e:
            status.fail()
            raise Exception(f"Failed to submit task: {str(e)}")

    def reset_history(self):
        """Reset history for all backends"""
        for status in self.backend_statuses.values():
            status.history = []  # Clear history data
        return self


def create_error_image(backend, error_message):
    """Render error_message onto a 512x512 placeholder image.

    Returns a PNG data URI suitable for a gr.Image, or a plain
    "Error: ..." string if Pillow rendering itself fails.
    """
    try:
        import base64
        from io import BytesIO

        # Create an in-memory image
        img = Image.new("RGB", (512, 512), color="#ffdddd")
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("Arial", 20)
        except OSError:
            # BUG FIX: was a bare `except:`; truetype raises OSError when the
            # font file is missing — fall back to Pillow's builtin font.
            font = ImageFont.load_default()

        # Naive word wrap at roughly 40 characters per line
        words = error_message.split(" ")
        lines = []
        line = ""
        for word in words:
            if len(line + word) < 40:
                line += word + " "
            else:
                lines.append(line)
                line = word + " "
        if line:
            lines.append(line)

        y_position = 100
        for line in lines:
            draw.text((50, y_position), line, fill="black", font=font)
            y_position += 30

        # Save to a BytesIO object instead of a file
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        img_bytes = buffer.getvalue()

        # BUG FIX: the bytes are PNG-encoded, so the data URI must declare
        # image/png (previously mislabelled as image/jpeg).
        return f"data:image/png;base64,{base64.b64encode(img_bytes).decode('utf-8')}"
    except Exception as e:
        print(f"Failed to create error image: {e}")
        # Return a simple error message as fallback
        return "Error: " + error_message


async def poll_once(manager, backend, request_id):
    """Poll once and return result if complete, otherwise None.

    On completion, marks the backend completed on the supplied manager and
    returns either a URL or a data-URI image string. On API-reported failure,
    marks the backend failed and raises. Raises on non-200 poll responses.
    """
    headers = {"Authorization": f"Bearer {API_KEY}"}
    url = f"{API_BASE_URL}/predictions/{request_id}/result"

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                result = await response.json()
                data = result["data"]
                current_status = data["status"]

                if current_status == "completed":
                    # IMPORTANT: Update status BEFORE returning - using the passed manager
                    manager.backend_statuses[backend].complete()
                    manager.update_activity()

                    # Handle base64 output
                    output = data["outputs"][0]

                    # Check if it's a base64 string or URL
                    if isinstance(output, str) and output.startswith("http"):
                        # It's a URL - return as is
                        return output
                    else:
                        # It's base64 data - format it as a data URI if needed
                        try:
                            # Format as data URI for Gradio to display directly
                            if isinstance(
                                    output,
                                    str) and not output.startswith("data:image"):
                                # Convert raw base64 to data URI format
                                return f"data:image/png;base64,{output}"
                            else:
                                # Already in data URI format
                                return output
                        except Exception as e:
                            print(f"Error processing base64 image: {e}")
                            raise Exception(
                                f"Failed to process base64 image: {str(e)}")
                elif current_status == "failed":
                    manager.backend_statuses[backend].fail()
                    manager.update_activity()
                    error = data.get("error", "Unknown error")
                    raise Exception(error)
                # Still processing
                return None
            else:
                raise Exception(f"Poll error: {response.status}")


# Use a state variable to store session ID
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    session_id = gr.State(None)  # Add this to store session ID

    gr.Markdown("# 🌊 HiDream Arena powered by WaveSpeed AI Image Generator")
# Add the introduction with link to WaveSpeedAI gr.Markdown( "[WaveSpeedAI](https://wavespeed.ai/) is the global pioneer in accelerating AI-powered video and image generation." ) gr.Markdown( "Our in-house inference accelerator provides lossless speedup on image & video generation based on our rich inference optimization software stack, including our in-house inference compiler, CUDA kernel libraries and parallel computing libraries." ) with gr.Row(): with gr.Column(scale=3): input_text = gr.Textbox( label="Enter your prompt", placeholder="Type here...", lines=3, ) with gr.Column(scale=1): generate_btn = gr.Button("Generate", variant="primary") # Two status boxes - small (default) and big (during generation) small_status_box = gr.Markdown("Ready to generate images", elem_id="small-status") # Big status box in its own row with styling with gr.Row(elem_id="big-status-row"): big_status_box = gr.Markdown("", elem_id="big-status", visible=False, elem_classes="big-status-box") with gr.Row(): with gr.Column(): draft_output = gr.Image(label="Flux-dev") with gr.Column(): quick_output = gr.Image(label="HiDream-dev") with gr.Column(): best_output = gr.Image(label="HiDream-full") performance_plot = gr.Plot(label="Performance Metrics") # Add custom CSS for the big status box css = """ #big-status-row { margin: 20px 0; } #big-status { font-size: 28px; /* Even larger font size */ font-weight: bold; padding: 30px; /* More padding */ background-color: #0D47A1; /* Deeper blue background */ color: white; /* White text */ border-radius: 10px; text-align: center; margin: 0 auto; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2); /* Stronger shadow */ animation: deep-breath 3s infinite; /* Slower, deeper breathing animation */ width: 100%; /* Full width */ max-width: 800px; /* Maximum width */ transition: all 0.3s ease; /* Smooth transitions */ border-left: 6px solid #64B5F6; /* Add a colored border */ border-right: 6px solid #64B5F6; /* Add a colored border */ } /* Deeper breathing animation 
*/ @keyframes deep-breath { 0% { opacity: 0.7; transform: scale(0.98); box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); } 50% { opacity: 1; transform: scale(1.01); box-shadow: 0 8px 16px rgba(0, 0, 0, 0.3); } 100% { opacity: 0.7; transform: scale(0.98); box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); } } """ gr.HTML(f"") # Update the generation function to use session manager async def generate_all_backends_with_status_boxes(prompt, current_session_id): """Generate images with big status box during generation""" # Get or create a session manager session_id, manager = SessionManager.get_manager(current_session_id) manager.update_activity() # IMPORTANT: Reset history when starting a new generation if prompt and prompt.strip() != "": manager.reset_history() # Clear previous performance metrics if not prompt or prompt.strip() == "": # Handle empty prompt case yield ( "⚠️ Please enter a prompt first", "⚠️ Please enter a prompt first", gr.update(visible=True), gr.update(visible=False), None, None, None, None, session_id, # Return the session ID ) return # Status message status_message = f"🔄 PROCESSING: '{prompt}'" # Initial state - clear all images, show big status box yield ( status_message, status_message, gr.update(visible=True), gr.update(visible=False), None, None, None, None, session_id, # Return the session ID ) # For production mode: completed_backends = set() results = {"flux-dev": None, "hidream-dev": None, "hidream-full": None} try: # Submit all tasks request_ids = {} for backend in BACKENDS: try: request_id = await manager.submit_task(backend, prompt) request_ids[backend] = request_id except Exception as e: # Handle submission error print(f"Error submitting task for {backend}: {e}") results[backend] = create_error_image(backend, str(e)) completed_backends.add(backend) # Poll all backends until they complete max_poll_attempts = 300 poll_attempt = 0 # Main polling loop while len(completed_backends ) < 3 and poll_attempt < max_poll_attempts: poll_attempt += 1 # Poll each 
pending backend for backend in list(BACKENDS.keys()): if backend in completed_backends: continue try: # Only do actual API calls every few attempts to reduce load if poll_attempt % 2 == 0 or backend == "flux-dev": # Use the session manager instead of global manager result = await poll_once(manager, backend, request_ids[backend]) if result: # Backend completed results[backend] = result completed_backends.add(backend) # Yield updated state when an image completes yield ( status_message, status_message, gr.update(visible=True), gr.update(visible=False), results["flux-dev"], results["hidream-dev"], results["hidream-full"], (manager.get_performance_plot() if any(completed_backends) else None), session_id, ) except Exception as e: print(f"Error polling {backend}: {str(e)}") # Wait between poll attempts await asyncio.sleep(0.1) # Final status final_status = ("✅ All generations completed!" if len(completed_backends) == 3 else "⚠️ Some generations timed out") # Final yield yield ( final_status, final_status, gr.update(visible=False), gr.update(visible=True), results["flux-dev"], results["hidream-dev"], results["hidream-full"], manager.get_performance_plot(), session_id, ) except Exception as e: # Error handling error_message = f"❌ Error: {str(e)}" yield ( error_message, error_message, gr.update(visible=False), gr.update(visible=True), None, None, None, None, session_id, ) # Schedule periodic cleanup of old sessions def cleanup_task(): SessionManager.cleanup_old_sessions() # Schedule the next cleanup threading.Timer(3600, cleanup_task).start() # Run every hour # Start the cleanup task cleanup_task() # Update the click handler to include session_id generate_btn.click( fn=generate_all_backends_with_status_boxes, inputs=[input_text, session_id], outputs=[ small_status_box, big_status_box, big_status_box, # visibility small_status_box, # visibility draft_output, quick_output, best_output, performance_plot, session_id, # Update the session ID ], api_name="generate", 
max_batch_size=10, # Process up to 10 requests at once concurrency_limit=20, # Allow up to 20 concurrent requests concurrency_id="generation", # Group concurrent requests under this ID ) # Launch with increased max_threads if __name__ == "__main__": demo.queue(max_size=50).launch( server_name="0.0.0.0", max_threads=16, # Increase thread count for better concurrency )