Spaces:
Running
Running
import io
import os
import time
import uuid

import gradio as gr
import requests
from PIL import Image
# In-flight generation requests, keyed by request id. The value is True while
# the request is active; cancel_request() flips it to False and the polling loop
# in generate_and_wait_for_image() stops when it sees that.
active_requests = dict()

# Bundled asset files (paths relative to the working directory) and per-garment
# metadata. NOTE(review): the "prompt" entries are not read by any code visible
# in this file — confirm whether they are still needed.
ASSETS = dict(
    top=dict(
        reference_image="pullover.png",
        prompt="black pullover with the logo on the chest",
    ),
    bottom=dict(
        reference_image="sweatpants.png",
        prompt="black sweatpants with the silver logo on it",
    ),
    logo="logo.png",
)
def needs_resizing(image_path, max_short_edge=1280):
    """Report whether an image is too large for upload.

    Args:
        image_path: Path to the image file to inspect.
        max_short_edge: Largest allowed length, in pixels, of the shorter side.

    Returns:
        True when the image's shorter side is strictly greater than
        ``max_short_edge``, otherwise False.
    """
    with Image.open(image_path) as picture:
        shorter_side = min(picture.size)
    return shorter_side > max_short_edge
def resize_image_to_short_edge(image_path, max_short_edge=1280):
    """Downscale an image so that its shorter side is at most ``max_short_edge`` px.

    The aspect ratio is preserved. Images whose shorter side is already within
    the limit are re-encoded at their original size and are never upscaled.

    Args:
        image_path: Path to the source image file.
        max_short_edge: Upper bound, in pixels, for the shorter side.

    Returns:
        An ``io.BytesIO`` rewound to offset 0. It holds the image re-encoded in
        the same format as the source file. Its ``name`` attribute is set to the
        source file's basename, so ``requests`` sends that as the multipart
        filename, just as it would for an open file handle.
    """
    with Image.open(image_path) as img:
        image_format = img.format
        width, height = img.size
        # Cap the factor at 1.0 so callers that skip needs_resizing() never
        # get an upscaled image.
        scale = min(1.0, max_short_edge / min(width, height))
        # round() instead of int(): floating-point error in `scale` could
        # otherwise truncate the short edge to one pixel under the target.
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        resized = img.resize(new_size, Image.LANCZOS)
        buffer = io.BytesIO()
        try:
            resized.save(buffer, format=image_format)
        finally:
            resized.close()
    buffer.seek(0)
    # requests derives the uploaded filename from `.name`; without it the
    # field key ("input_image") would be sent with no file extension.
    buffer.name = os.path.basename(image_path)
    return buffer
def validate_assets():
    """Ensure every bundled asset file referenced by ASSETS exists on disk.

    The garment reference images are checked first (top, then bottom),
    followed by the logo.

    Raises:
        FileNotFoundError: if one or more files are absent. The message lists
            all missing paths, comma-separated.
    """
    required = [ASSETS[kind]["reference_image"] for kind in ("top", "bottom")]
    required.append(ASSETS["logo"])
    absent = [path for path in required if not os.path.exists(path)]
    if absent:
        raise FileNotFoundError(f"Missing required asset files: {', '.join(absent)}")
def _submit_job(api_key, input_image, garment_type):
    """Upload logo, garment reference and user photo; return the job ID or None."""
    post_url = "https://api.stability.ai/private/alo/v1/vto-acolade"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    data = {
        "reference_image_type": garment_type,
        "output_format": "png",  # Hardcoded to PNG
    }
    files = {}
    try:
        # Open the handles one at a time. If a later open fails, the finally
        # clause still closes the ones that were already opened.
        files["logo_image"] = open(ASSETS["logo"], "rb")
        files["reference_image"] = open(ASSETS[garment_type]["reference_image"], "rb")
        # Only re-encode the user's photo when it exceeds the size limit.
        if needs_resizing(input_image):
            files["input_image"] = resize_image_to_short_edge(input_image)
        else:
            files["input_image"] = open(input_image, "rb")
        response = requests.post(post_url, headers=headers, files=files, data=data, timeout=10)
        response.raise_for_status()
        return response.json().get("id")
    finally:
        # Uploads are finished here; release the handles before polling starts.
        for handle in files.values():
            handle.close()


def _poll_for_result(api_key, job_id, request_id, progress,
                     max_attempts=20, initial_delay=1.0, max_delay=5.0):
    """Poll the results endpoint with back-off until the image is ready.

    Returns an ``(image_or_None, status_message)`` tuple. Polling stops early
    when the request is cancelled via ``active_requests`` or when the API
    answers with a non-retryable 4xx status.
    """
    get_url = f"https://api.stability.ai/private/alo/v1/results/{job_id}"
    # Use the same Bearer scheme as the submit request. A bare key is not a
    # well-formed Authorization value.
    headers = {
        "authorization": f"Bearer {api_key}",
        "accept": "*/*",
    }

    def cancelled():
        return not active_requests.get(request_id, False)

    delay = initial_delay
    with requests.Session() as session:  # reuse one connection across polls
        for attempt in range(max_attempts):
            if cancelled():
                return None, "Request cancelled by user"
            time.sleep(delay)
            if cancelled():
                return None, "Request cancelled by user"
            progress(0.3 + (0.7 * attempt / max_attempts),
                     desc=f"Checking status (attempt {attempt + 1}/{max_attempts})")
            try:
                response = session.get(get_url, headers=headers, timeout=10)
            except requests.exceptions.RequestException:
                # Transient network failure: back off and try again.
                delay = min(delay * 1.5, max_delay)
                continue

            status = response.status_code
            if status == 200:
                if "image" in response.headers.get("Content-Type", ""):
                    img = Image.open(io.BytesIO(response.content))
                    progress(1.0, desc="Done!")
                    return img, f"Success! Job ID: {job_id}"
                try:
                    payload = response.json()
                except ValueError:
                    payload = None  # unparseable body: treat as "not ready yet"
                still_processing = payload is None or (
                    isinstance(payload, dict) and payload.get("status") == "processing"
                )
                if not still_processing:
                    return None, f"API response: {payload}"
            elif 400 <= status < 500 and status != 429:
                # Client errors such as a bad key or an unknown job will not
                # resolve by waiting. Report now instead of retrying until timeout.
                return None, f"Error: HTTP {status} while polling job {job_id}: {response.text[:200]}"
            # 202 (still running), 429 (rate limited), 5xx or an in-progress
            # JSON body: back off and poll again.
            delay = min(delay * 1.5, max_delay)
    return None, f"Timeout after {max_attempts} attempts. Job ID: {job_id}"


def generate_and_wait_for_image(
    api_key: str,
    input_image: str,
    garment_type: str,
    progress=gr.Progress()
):
    """Submit a virtual try-on job and wait for the resulting image.

    Args:
        api_key: API key, sent as a Bearer token. It is never logged.
        input_image: Filesystem path of the user's uploaded photo (Gradio
            ``type="filepath"``). It may be None when nothing was uploaded.
        garment_type: ``"top"`` or ``"bottom"``. Selects the reference image
            in ASSETS.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        ``(PIL image or None, status message)``. Errors are reported through
        the message instead of being raised, so the UI always gets a status.
    """
    try:
        validate_assets()
    except FileNotFoundError as e:
        return None, str(e)
    if not api_key:
        return None, "Error: Please enter your API key."
    if not input_image:
        return None, "Error: Please upload a photo first."
    if garment_type not in ("top", "bottom"):
        return None, f"Error: Unknown garment type {garment_type!r}"

    # A random ID cannot collide between concurrent users, unlike a timestamp.
    request_id = uuid.uuid4().hex
    active_requests[request_id] = True
    try:
        progress(0, desc="Starting image generation...")
        job_id = _submit_job(api_key, input_image, garment_type)
        if not job_id:
            return None, "Error: No job ID received in response"
        progress(0.3, desc="Processing your image...")
        return _poll_for_result(api_key, job_id, request_id, progress)
    except Exception as e:  # UI boundary: surface any failure as a status message
        return None, f"Error: {str(e)}"
    finally:
        # Clean up the request tracking
        active_requests.pop(request_id, None)
def cancel_request():
    """Mark every in-flight request as cancelled.

    Polling loops observe the False flag at their next check and return a
    "cancelled" status. NOTE(review): ``active_requests`` is a single
    module-level dict, so this also cancels requests started by other users
    of the same app process.
    """
    for request_key in list(active_requests):
        active_requests[request_key] = False
# Gradio UI: inputs on the left (API key, photo, garment type, buttons),
# generated image and status text on the right.
with gr.Blocks(title="Virtual Try-On Demo") as demo:
    gr.Markdown("""
    # Virtual Try-On Demo v1
    Upload your photo and select garment type to generate your VTon image.
    """)
    with gr.Row():
        with gr.Column():
            api_key = gr.Textbox(
                label="API Key",
                value="",
                type="password"
            )
            input_image = gr.Image(
                label="Upload Your Photo",
                type="filepath",  # the handler expects a path on disk
                sources=["upload"],
                height=300
            )
            garment_type = gr.Dropdown(
                label="Garment Type",
                choices=["top", "bottom"],
                value="top"
            )
            with gr.Row():
                submit_btn = gr.Button("Generate Image", variant="primary")
                cancel_btn = gr.Button("Cancel Request")
        with gr.Column():
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
                height=400
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )
    submit_btn.click(
        fn=generate_and_wait_for_image,
        inputs=[api_key, input_image, garment_type],
        outputs=[output_image, status_output]
    )
    # queue=False so the cancel click runs right away instead of waiting
    # behind the queued generation job it is meant to stop.
    cancel_btn.click(
        fn=cancel_request,
        inputs=None,
        outputs=None,
        queue=False
    )
    # Only images larger than the limit are downscaled (see needs_resizing),
    # so the note must not claim that every image is resized.
    gr.Markdown(
        "# Note:\n"
        "Photos whose short edge is larger than 1280px are downscaled to 1280px "
        "on the short edge before upload; smaller photos are sent unchanged."
    )
if __name__ == "__main__":
    # Verify required files exist before launching. Only the check sits in the
    # try, so a launch-time error is not misreported as a missing asset.
    try:
        validate_assets()
    except FileNotFoundError as e:
        print(f"Error: {str(e)}")
        print("Please make sure these files exist in the same directory:")
        # Derive the hints from ASSETS so they cannot drift from the real paths.
        print(f"- {ASSETS['logo']}")
        print(f"- {ASSETS['top']['reference_image']} (for top selection)")
        print(f"- {ASSETS['bottom']['reference_image']} (for bottom selection)")
        raise SystemExit(1)  # signal failure to the caller / process supervisor
    else:
        demo.launch()