"""Gradio demo for a virtual try-on (VTO) image-generation API.

The user uploads a photo and picks a garment type. The app uploads that photo
together with a fixed reference garment image and a logo, then polls the API
until the generated image is ready (or the attempt budget runs out).
"""

import io
import os
import time
import uuid
from contextlib import ExitStack

import gradio as gr
import requests
from PIL import Image

# request_id -> "still wanted" flag. cancel_request() flips flags to False and
# the polling loop checks the flag before each attempt.
# NOTE(review): this dict is process-global, so "Cancel" stops every in-flight
# request from every user of the app, not only the caller's own.
active_requests = {}

# Predefined assets sent with every request. The "prompt" entries are not
# currently sent to the API by this module.
ASSETS = {
    "top": {
        "reference_image": "pullover.png",
        "prompt": "black pullover with the logo on the chest",
    },
    "bottom": {
        "reference_image": "sweatpants.png",
        "prompt": "black sweatpants with the silver logo on it",
    },
    "logo": "logo.png",
}

# Garment types offered in the UI; each must have an entry in ASSETS.
GARMENT_TYPES = ("top", "bottom")

# Uploaded photos with a longer short edge than this are downscaled first.
MAX_SHORT_EDGE = 1280

POST_URL = "https://api.stability.ai/private/alo/v1/vto-acolade"
RESULTS_URL = "https://api.stability.ai/private/alo/v1/results/{job_id}"

# Polling statuses worth retrying: request timeout, rate limit (5xx handled separately).
_RETRYABLE_4XX = {408, 429}


def needs_resizing(image_path, max_short_edge=MAX_SHORT_EDGE):
    """Return True if the image's short edge is longer than ``max_short_edge`` pixels."""
    with Image.open(image_path) as img:
        width, height = img.size
    return min(width, height) > max_short_edge


def resize_image_to_short_edge(image_path, max_short_edge=MAX_SHORT_EDGE):
    """Rescale an image so its short edge is ``max_short_edge`` pixels.

    The aspect ratio is preserved and the result is re-encoded in the source
    file's format. This scales up as well as down; callers check
    ``needs_resizing`` first.

    Returns:
        An ``io.BytesIO`` positioned at 0 holding the encoded image. Its
        ``name`` is set to the source file's basename so ``requests`` uses it
        as the multipart filename, just as it does for files opened from disk.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        image_format = img.format
        scaling_factor = max_short_edge / min(width, height)
        # round() rather than int(): truncation can yield e.g. 1279 for the
        # short edge because of floating-point error in scaling_factor.
        new_size = (round(width * scaling_factor), round(height * scaling_factor))
        resized_img = img.resize(new_size, Image.LANCZOS)

    buffer = io.BytesIO()
    resized_img.save(buffer, format=image_format)
    buffer.seek(0)
    buffer.name = os.path.basename(image_path)
    return buffer


def validate_assets():
    """Check that every required asset file exists.

    Raises:
        FileNotFoundError: listing all missing asset paths.
    """
    required = [ASSETS[garment]["reference_image"] for garment in GARMENT_TYPES]
    required.append(ASSETS["logo"])
    missing = [path for path in required if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Missing required asset files: {', '.join(missing)}")


def _open_request_files(stack, input_image, garment_type):
    """Open the multipart file parts for one job; every handle is registered on *stack*."""
    files = {
        "logo_image": stack.enter_context(open(ASSETS["logo"], "rb")),
        "reference_image": stack.enter_context(
            open(ASSETS[garment_type]["reference_image"], "rb")
        ),
    }
    # Only re-encode the user's photo when it is larger than the API limit.
    if needs_resizing(input_image):
        files["input_image"] = stack.enter_context(resize_image_to_short_edge(input_image))
    else:
        files["input_image"] = stack.enter_context(open(input_image, "rb"))
    return files


def _submit_job(api_key, files, garment_type):
    """POST the job and return its ID (None if the response has no ``id``).

    Raises:
        requests.exceptions.RequestException: on network errors or HTTP 4xx/5xx.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    data = {
        "reference_image_type": garment_type,
        "output_format": "png",  # hardcoded to PNG
    }
    response = requests.post(POST_URL, headers=headers, files=files, data=data, timeout=10)
    response.raise_for_status()
    return response.json().get("id")


def _poll_for_result(
    api_key,
    job_id,
    request_id,
    progress,
    max_attempts=20,
    initial_delay=1.0,
    max_delay=5.0,
):
    """Poll the results endpoint with exponential backoff.

    Returns:
        ``(PIL.Image, status_message)`` on success, otherwise
        ``(None, status_message)`` for cancellation, API errors or timeout.
    """
    headers = {
        # Same bearer scheme as the submit request.
        "Authorization": f"Bearer {api_key}",
        "Accept": "*/*",
    }
    get_url = RESULTS_URL.format(job_id=job_id)
    delay = initial_delay

    # One session for all attempts so the connection can be reused.
    with requests.Session() as session:
        for attempt in range(max_attempts):
            if not active_requests.get(request_id, False):
                return None, "Request cancelled by user"

            time.sleep(delay)
            progress(
                0.3 + (0.7 * attempt / max_attempts),
                desc=f"Checking status (attempt {attempt + 1}/{max_attempts})",
            )

            try:
                response = session.get(get_url, headers=headers, timeout=10)
            except requests.exceptions.RequestException:
                # Transient network problem: back off and retry.
                delay = min(delay * 1.5, max_delay)
                continue

            status = response.status_code
            if status == 200:
                if "image" in response.headers.get("Content-Type", ""):
                    img = Image.open(io.BytesIO(response.content))
                    progress(1.0, desc="Done!")
                    return img, f"Success! Job ID: {job_id}"
                try:
                    json_response = response.json()
                except ValueError:
                    # Malformed body: treat like a transient failure.
                    delay = min(delay * 1.5, max_delay)
                    continue
                if json_response.get("status") != "processing":
                    return None, f"API response: {json_response}"
            elif 400 <= status < 500 and status not in _RETRYABLE_4XX:
                # Auth/not-found/bad-request errors will not fix themselves;
                # report instead of silently burning the remaining attempts.
                return None, (
                    f"Error: HTTP {status} while polling job {job_id}: "
                    f"{response.text[:500]}"
                )
            # 202, "processing", retryable 4xx and 5xx: wait longer and retry.
            delay = min(delay * 1.5, max_delay)

    return None, f"Timeout after {max_attempts} attempts. Job ID: {job_id}"


def generate_and_wait_for_image(
    api_key: str,
    input_image: str,
    garment_type: str,
    progress=gr.Progress(),
):
    """Submit a try-on job and wait for the generated image.

    Args:
        api_key: API key, sent as a Bearer token.
        input_image: Path to the user's uploaded photo (None if nothing uploaded).
        garment_type: One of ``GARMENT_TYPES``; selects the reference image.
        progress: Gradio progress tracker injected by Gradio.

    Returns:
        ``(image, status_message)``; ``image`` is a PIL image on success and
        None otherwise. Errors are reported via the message, never raised.
    """
    if not api_key:
        return None, "Error: Please enter an API key"
    if not input_image:
        return None, "Error: Please upload a photo"
    if garment_type not in GARMENT_TYPES:
        return None, f"Error: Unknown garment type {garment_type!r}"
    try:
        validate_assets()
    except FileNotFoundError as e:
        return None, str(e)

    # uuid avoids collisions between requests started in the same instant.
    request_id = uuid.uuid4().hex
    active_requests[request_id] = True
    try:
        progress(0, desc="Starting image generation...")
        # ExitStack closes every handle opened so far even if a later open or
        # the upload fails; files are released as soon as the upload is done.
        with ExitStack() as stack:
            files = _open_request_files(stack, input_image, garment_type)
            job_id = _submit_job(api_key, files, garment_type)

        if not job_id:
            return None, "Error: No job ID received in response"

        progress(0.3, desc="Processing your image...")
        return _poll_for_result(api_key, job_id, request_id, progress)
    except requests.exceptions.HTTPError as e:
        # Include the API's error body; the bare HTTPError text lacks details.
        body = e.response.text[:500] if e.response is not None else ""
        return None, f"Error: {e} {body}".rstrip()
    except Exception as e:
        return None, f"Error: {e}"
    finally:
        active_requests.pop(request_id, None)


def cancel_request():
    """Flag every in-flight request as cancelled.

    The polling loop sees the flag before its next attempt.
    """
    for req_id in list(active_requests):
        active_requests[req_id] = False


with gr.Blocks(title="Virtual Try-On Demo") as demo:
    gr.Markdown("""
    # Virtual Try-On Demo v1
    Upload your photo and select garment type to generate your VTon image.
    """)

    with gr.Row():
        with gr.Column():
            api_key = gr.Textbox(
                label="API Key",
                value="",
                type="password",
            )
            input_image = gr.Image(
                label="Upload Your Photo",
                type="filepath",
                sources=["upload"],
                height=300,
            )
            garment_type = gr.Dropdown(
                label="Garment Type",
                choices=list(GARMENT_TYPES),
                value="top",
            )
            with gr.Row():
                submit_btn = gr.Button("Generate Image", variant="primary")
                cancel_btn = gr.Button("Cancel Request")

        with gr.Column():
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
                height=400,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
            )

    submit_btn.click(
        fn=generate_and_wait_for_image,
        inputs=[api_key, input_image, garment_type],
        outputs=[output_image, status_output],
    )

    cancel_btn.click(
        fn=cancel_request,
        inputs=None,
        outputs=None,
        queue=False,
    )

    gr.Markdown(f"""
    # Note:
    Photos whose short edge is longer than {MAX_SHORT_EDGE}px are downscaled to
    {MAX_SHORT_EDGE}px on the short edge before upload.
    """)


if __name__ == "__main__":
    # Verify required files exist before launching.
    try:
        validate_assets()
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please make sure these files exist in the current working directory:")
        print(f"- {ASSETS['logo']}")
        for garment in GARMENT_TYPES:
            print(f"- {ASSETS[garment]['reference_image']} (for {garment} selection)")
        raise SystemExit(1)
    # Launch outside the try so unrelated FileNotFoundErrors are not misreported.
    demo.launch()