# VTon-Demo-v1 / app.py
# Author: saidennis
# Last commit: ab3fb39 ("Update app.py", verified)
import contextlib
import io
import os
import time
import uuid

import gradio as gr
import requests
from PIL import Image
from PIL import ImageOps
# In-flight generation requests: request id -> "still wanted" flag.
# cancel_request() flips flags to False; the polling loop watches them.
active_requests = {}

# Static files and prompts shared by every try-on request.
ASSETS = dict(
    top=dict(
        reference_image="pullover.png",
        prompt="black pullover with the logo on the chest",
    ),
    bottom=dict(
        reference_image="sweatpants.png",
        prompt="black sweatpants with the silver logo on it",
    ),
    logo="logo.png",
)
def needs_resizing(image_path, max_short_edge=1280):
    """Report whether an image's shorter side is larger than ``max_short_edge``.

    Args:
        image_path: Path of the image file to inspect.
        max_short_edge: Pixel threshold for the shorter side.

    Returns:
        True when ``min(width, height) > max_short_edge``, otherwise False.
    """
    with Image.open(image_path) as picture:
        shorter_side = min(picture.size)
    return shorter_side > max_short_edge
def resize_image_to_short_edge(image_path, max_short_edge=1280):
    """Downscale an image so its shorter side is at most ``max_short_edge`` px.

    The result is re-encoded in the source file's format into an in-memory
    buffer. Before encoding, the EXIF orientation tag is applied to the pixel
    data. The re-encoded copy does not carry the original EXIF block, so
    photos that rely on that tag would otherwise come out rotated.

    Args:
        image_path: Path to the source image file.
        max_short_edge: Upper bound, in pixels, for the shorter side. Images
            that are already within the bound are re-encoded but not upscaled.

    Returns:
        io.BytesIO positioned at offset 0 containing the encoded image.
    """
    with Image.open(image_path) as img:
        # Capture the format now: transposed/resized copies have format=None.
        # Fall back to PNG if Pillow could not identify the source format.
        image_format = img.format or "PNG"
        oriented = ImageOps.exif_transpose(img)
        width, height = oriented.size
        short_edge = min(width, height)
        if short_edge > max_short_edge:
            scale = max_short_edge / short_edge
            # Pin the short side to exactly max_short_edge. Computing it as
            # int(w * (m / w)) can truncate to m - 1 through float rounding.
            if width <= height:
                new_size = (max_short_edge, max(1, round(height * scale)))
            else:
                new_size = (max(1, round(width * scale)), max_short_edge)
            oriented = oriented.resize(new_size, Image.LANCZOS)
        buffer = io.BytesIO()
        oriented.save(buffer, format=image_format)
    buffer.seek(0)
    return buffer
def validate_assets():
    """Ensure every configured asset file is present on disk.

    Checks the "top" and "bottom" reference images, then the logo, and
    reports all absent paths together.

    Raises:
        FileNotFoundError: listing each missing path, comma-separated.
    """
    required = [ASSETS[kind]["reference_image"] for kind in ("top", "bottom")]
    required.append(ASSETS["logo"])
    missing = [path for path in required if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Missing required asset files: {', '.join(missing)}")
_VTO_SUBMIT_URL = "https://api.stability.ai/private/alo/v1/vto-acolade"
_VTO_RESULT_URL = "https://api.stability.ai/private/alo/v1/results/{job_id}"


def _http_error_message(err):
    """Format an HTTPError plus (a prefix of) the server's reply body."""
    body = err.response.text.strip()[:500] if err.response is not None else ""
    return f"Error: {err}" + (f" - {body}" if body else "")


def _submit_job(session, api_key, input_image, garment_type):
    """Upload the try-on inputs and return the job id from the reply (or None).

    Every file handle opened here is closed before returning, including when
    opening a later file or the request itself fails.

    Raises:
        requests.exceptions.RequestException: on network failure or a 4xx/5xx
            reply (via ``raise_for_status``).
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    with contextlib.ExitStack() as stack:
        files = {
            "logo_image": stack.enter_context(open(ASSETS["logo"], "rb")),
            "reference_image": stack.enter_context(
                open(ASSETS[garment_type]["reference_image"], "rb")
            ),
        }
        if needs_resizing(input_image):
            # Name the in-memory buffer after the source file, matching the
            # filename requests derives for a real file handle.
            files["input_image"] = (
                os.path.basename(input_image),
                resize_image_to_short_edge(input_image),
            )
        else:
            files["input_image"] = stack.enter_context(open(input_image, "rb"))
        data = {
            "reference_image_type": garment_type,
            "output_format": "png",  # result format is fixed to PNG
        }
        response = session.post(
            _VTO_SUBMIT_URL, headers=headers, files=files, data=data, timeout=10
        )
    response.raise_for_status()
    return response.json().get("id")


def _poll_for_result(session, api_key, job_id, request_id, progress,
                     max_attempts=20, initial_delay=1.0, max_delay=5.0):
    """Poll the results endpoint with exponential backoff until done.

    Returns:
        (PIL.Image, status) on success, or (None, status) on cancellation,
        an unrecoverable API reply, or after ``max_attempts`` polls.
    """
    url = _VTO_RESULT_URL.format(job_id=job_id)
    headers = {
        "authorization": f"Bearer {api_key}",
        "accept": "*/*",
    }
    # Client errors that retrying cannot fix (bad request, bad/unauthorized key).
    fatal_statuses = {400, 401, 403}
    delay = initial_delay

    def cancelled():
        return not active_requests.get(request_id, False)

    for attempt in range(max_attempts):
        if cancelled():
            return None, "Request cancelled by user"
        time.sleep(delay)
        # The user may have pressed Cancel while we were sleeping.
        if cancelled():
            return None, "Request cancelled by user"
        progress(0.3 + (0.7 * attempt / max_attempts),
                 desc=f"Checking status (attempt {attempt + 1}/{max_attempts})")
        delay = min(delay * 1.5, max_delay)

        try:
            response = session.get(url, headers=headers, timeout=10)
        except requests.exceptions.RequestException:
            continue  # transient network problem: back off and retry

        status = response.status_code
        if status == 200:
            if "image" in response.headers.get("Content-Type", ""):
                img = Image.open(io.BytesIO(response.content))
                img.load()  # surface decode errors here, not later in Gradio
                progress(1.0, desc="Done!")
                return img, f"Success! Job ID: {job_id}"
            try:
                payload = response.json()
            except ValueError:
                continue  # malformed/partial body: treat as not ready yet
            if isinstance(payload, dict) and payload.get("status") == "processing":
                continue
            return None, f"API response: {payload}"
        if status in fatal_statuses:
            body = response.text.strip()[:500]
            return None, f"Error: HTTP {status} while polling job {job_id}: {body}"
        # 202 (still running), 429, 5xx and anything else: back off and retry.

    return None, f"Timeout after {max_attempts} attempts. Job ID: {job_id}"


def generate_and_wait_for_image(
    api_key: str,
    input_image: str,
    garment_type: str,
    progress=gr.Progress()
):
    """Submit a virtual try-on job and wait for the generated image.

    Args:
        api_key: Stability API key, sent as a Bearer token.
        input_image: Filesystem path of the user's photo (None if not uploaded).
        garment_type: "top" or "bottom"; selects the reference garment image.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        Tuple ``(image, status)``: a PIL image, or None on failure, plus a
        human-readable status message for the UI.

    The request registers itself in ``active_requests`` so ``cancel_request``
    can stop the polling loop; the entry is removed when this call finishes.
    """
    try:
        validate_assets()
    except FileNotFoundError as e:
        return None, str(e)
    if garment_type not in ("top", "bottom"):
        return None, f"Error: Unknown garment type {garment_type!r}"
    if not api_key:
        return None, "Error: Please enter an API key."
    if not input_image:
        return None, "Error: Please upload a photo."

    # uuid4 avoids collisions between requests started in the same instant.
    request_id = uuid.uuid4().hex
    active_requests[request_id] = True
    try:
        with requests.Session() as session:
            progress(0, desc="Starting image generation...")
            job_id = _submit_job(session, api_key, input_image, garment_type)
            if not job_id:
                return None, "Error: No job ID received in response"
            progress(0.3, desc="Processing your image...")
            return _poll_for_result(session, api_key, job_id, request_id, progress)
    except requests.exceptions.HTTPError as e:
        return None, _http_error_message(e)
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        return None, f"Error: {str(e)}"
    finally:
        active_requests.pop(request_id, None)
def cancel_request():
    """Mark every tracked generation request as cancelled.

    Sets each flag in ``active_requests`` to False. The polling loop in
    generate_and_wait_for_image checks that flag and stops. This affects all
    tracked requests, not just the caller's.
    """
    active_requests.update(dict.fromkeys(active_requests, False))
with gr.Blocks(title="Virtual Try-On Demo") as demo:
    # Plain concatenated strings (not an indented triple-quoted block) so
    # Markdown never sees leading indentation as a code block.
    gr.Markdown(
        "# Virtual Try-On Demo v1\n"
        "Upload your photo and select garment type to generate your VTon image."
    )
    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column():
            api_key = gr.Textbox(
                label="API Key",
                value="",
                type="password"
            )
            input_image = gr.Image(
                label="Upload Your Photo",
                type="filepath",  # the handler receives a path on disk
                sources=["upload"],
                height=300
            )
            garment_type = gr.Dropdown(
                label="Garment Type",
                choices=["top", "bottom"],
                value="top"
            )
            with gr.Row():
                submit_btn = gr.Button("Generate Image", variant="primary")
                cancel_btn = gr.Button("Cancel Request")
        # Right column: result image and status text.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
                height=400
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )
    submit_btn.click(
        fn=generate_and_wait_for_image,
        inputs=[api_key, input_image, garment_type],
        outputs=[output_image, status_output]
    )
    # queue=False so the cancel click runs immediately instead of waiting
    # behind the in-flight generation.
    cancel_btn.click(
        fn=cancel_request,
        inputs=None,
        outputs=None,
        queue=False
    )
    gr.Markdown(
        "**Note:** photos whose short edge is larger than 1280px are "
        "downscaled to 1280px on the short edge before upload; smaller "
        "photos are sent unchanged."
    )
if __name__ == "__main__":
    # Verify required files before launching so a missing asset is reported
    # once at startup rather than on every request in the UI. launch() stays
    # outside the try so its own errors are not misreported as missing assets.
    try:
        validate_assets()
    except FileNotFoundError as e:
        print(f"Error: {str(e)}")
        # Asset paths are relative, so they resolve against the CWD.
        print(f"Please make sure these files exist in the current working directory ({os.getcwd()}):")
        print(f"- {ASSETS['logo']}")
        for garment in ("top", "bottom"):
            print(f"- {ASSETS[garment]['reference_image']} (for {garment} selection)")
        raise SystemExit(1)
    demo.launch()