# VTon-Demo-v1 / app.py
# saidennis's picture
# Update app.py
# e72c3c8 verified
# raw
# history blame
# 7.7 kB
import gradio as gr
import requests
import time
import io
import os
from PIL import Image
# Registry of in-flight generation requests, keyed by request ID.
# A value of True means "keep going"; cancel_request() flips it to False.
active_requests = {}

# Bundled garment assets. Each garment entry names its reference image file
# and a text prompt; "logo" is the path of the shared logo image.
ASSETS = {
    "top": {
        "reference_image": "pullover.png",
        "prompt": "black pullover with the logo on the chest",
    },
    "bottom": {
        "reference_image": "sweatpants.png",
        "prompt": "black sweatpants with the silver logo on it",
    },
    "logo": "logo.png",
}
def validate_assets():
    """Ensure every bundled asset file is present on disk.

    Checks the reference images of the "top" and "bottom" garments and then
    the logo, in that order.

    Raises:
        FileNotFoundError: If one or more files are missing; the message
            lists every missing path, comma-separated.
    """
    required_paths = [
        ASSETS[garment]["reference_image"] for garment in ("top", "bottom")
    ]
    required_paths.append(ASSETS["logo"])

    absent = [path for path in required_paths if not os.path.exists(path)]
    if absent:
        raise FileNotFoundError(f"Missing required asset files: {', '.join(absent)}")
def _submit_tryon_job(api_key, input_image, garment_type):
    """Upload the user photo, logo and garment reference; return the job ID.

    Returns whatever the JSON response carries under ``id`` (None if absent).
    Raises ``requests`` exceptions on transport or HTTP errors and ``OSError``
    if one of the image files cannot be opened.
    """
    post_url = "https://api.stability.ai/private/alo/v1/vto-acolade"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }
    data = {
        'reference_image_type': f'{garment_type}',
        'output_format': "png"  # Hardcoded to PNG
    }
    upload_paths = {
        'input_image': input_image,
        'logo_image': ASSETS["logo"],
        'reference_image': ASSETS[garment_type]["reference_image"]
    }

    files = {}
    try:
        # Open one file at a time so that a failure part-way through still
        # lets the finally clause close the handles that were already opened.
        for field, path in upload_paths.items():
            files[field] = open(path, 'rb')
        # NOTE: the headers are deliberately not logged; they carry the API key.
        response = requests.post(post_url, headers=headers, files=files, data=data, timeout=10)
    finally:
        # The upload is finished (or failed); the handles are not needed
        # while polling.
        for handle in files.values():
            handle.close()

    response.raise_for_status()
    return response.json().get('id')


def _poll_for_result(api_key, job_id, request_id, progress,
                     max_attempts=20, initial_delay=1.0, max_delay=5.0):
    """Poll the results endpoint until an image, a final answer or timeout.

    Returns the same ``(image_or_None, status_message)`` tuple as
    ``generate_and_wait_for_image``.
    """
    get_url = f"https://api.stability.ai/private/alo/v1/results/{job_id}"
    # NOTE(review): unlike the submit request, this header is sent without a
    # "Bearer " prefix. Kept as-is; confirm which form the endpoint expects.
    headers = {
        "authorization": f"{api_key}",
        "accept": "*/*"
    }

    current_delay = initial_delay
    last_error = None

    # One session for the whole loop so the connection can be reused.
    with requests.Session() as session:
        for attempt in range(max_attempts):
            if not active_requests.get(request_id, False):
                return None, "Request cancelled by user"

            time.sleep(current_delay)
            # Every outcome that does not return below is retried, so back
            # off now for the next round (x1.5, capped at max_delay).
            current_delay = min(current_delay * 1.5, max_delay)

            progress(0.3 + (0.7 * attempt / max_attempts),
                     desc=f"Checking status (attempt {attempt + 1}/{max_attempts})")

            try:
                response = session.get(get_url, headers=headers, timeout=10)
                if response.status_code == 200:
                    if 'image' in response.headers.get('Content-Type', ''):
                        img = Image.open(io.BytesIO(response.content))
                        progress(1.0, desc="Done!")
                        return img, f"Success! Job ID: {job_id}"
                    json_response = response.json()
                    if json_response.get('status') == 'processing':
                        continue
                    return None, f"API response: {json_response}"
                if response.status_code != 202:
                    # 202 means "still working"; anything else that is an
                    # HTTP error is raised here and retried below.
                    response.raise_for_status()
            except requests.exceptions.RequestException as exc:
                # Best-effort retry, but remember why so that a final
                # timeout can report the underlying cause.
                last_error = exc

    message = f"Timeout after {max_attempts} attempts. Job ID: {job_id}"
    if last_error is not None:
        message += f" (last error: {last_error})"
    return None, message


def generate_and_wait_for_image(
    api_key: str,
    input_image: str,
    garment_type: str,
    progress=gr.Progress()
):
    """Make POST request and automatically poll for the result image.

    Args:
        api_key: API key; sent as a Bearer token when submitting the job.
        input_image: Filesystem path of the uploaded photo.
        garment_type: Key of a garment entry in ASSETS ("top" or "bottom").
        progress: Gradio progress tracker used to report status.

    Returns:
        A ``(image, status)`` tuple. ``image`` is a PIL image on success and
        None otherwise; ``status`` is a human-readable message. Errors are
        reported through ``status`` rather than raised.
    """
    import uuid  # local import: only needed to mint request IDs

    # Validate assets first
    try:
        validate_assets()
    except FileNotFoundError as e:
        return None, str(e)

    # Reject unusable inputs up front with a clear message instead of
    # letting them surface as TypeError/KeyError/HTTP 401 text later on.
    api_key = (api_key or "").strip()
    if not api_key:
        return None, "Error: API key is required"
    if not input_image:
        return None, "Error: Please upload a photo first"
    if not isinstance(ASSETS.get(garment_type), dict):
        return None, f"Error: Unsupported garment type: {garment_type!r}"

    # Create a unique ID for this request. A UUID avoids the collisions that
    # a timestamp can produce when two requests start at the same instant.
    request_id = uuid.uuid4().hex
    active_requests[request_id] = True

    try:
        progress(0, desc="Starting image generation...")
        job_id = _submit_tryon_job(api_key, input_image, garment_type)
        if not job_id:
            return None, "Error: No job ID received in response"

        progress(0.3, desc="Processing your image...")
        return _poll_for_result(api_key, job_id, request_id, progress)
    except Exception as e:
        # UI boundary: report any failure in the status box.
        return None, f"Error: {str(e)}"
    finally:
        # Clean up the request tracking
        active_requests.pop(request_id, None)
def cancel_request():
    """Flag every tracked request as cancelled.

    Sets the value of each key currently in ``active_requests`` to False.
    Note that this touches all tracked requests, not only the caller's own.
    """
    pending_ids = list(active_requests)
    active_requests.update(dict.fromkeys(pending_ids, False))
# Page layout: inputs and action buttons in the left column, the generated
# image and status text in the right column.
with gr.Blocks(title="Virtual Try-On Demo") as demo:
    gr.Markdown("""
# Virtual Try-On Demo v1
Upload your photo and select garment type to generate your VTon image.
""")

    with gr.Row():
        # Left column: credentials, photo upload, garment choice, buttons.
        with gr.Column():
            api_key = gr.Textbox(label="API Key", value="", type="password")
            input_image = gr.Image(
                label="Upload Your Photo",
                type="filepath",
                sources=["upload"],
                height=300,
            )
            garment_type = gr.Dropdown(
                label="Garment Type",
                choices=["top", "bottom"],
                value="top",
            )
            with gr.Row():
                submit_btn = gr.Button("Generate Image", variant="primary")
                cancel_btn = gr.Button("Cancel Request")

        # Right column: result image and status message.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
                height=400,
            )
            status_output = gr.Textbox(label="Status", interactive=False)

    # Wire the buttons to their handlers.
    submit_btn.click(
        fn=generate_and_wait_for_image,
        inputs=[api_key, input_image, garment_type],
        outputs=[output_image, status_output],
    )
    cancel_btn.click(
        fn=cancel_request,
        inputs=None,
        outputs=None,
        queue=False,
    )

    gr.Markdown("""
# Input image requirements:\n
- Supported formats: jpeg, png, webp
- Minimum Resolution: at least 64 on each side
- Total pixel count: 4,096 to 9,437,184 pixels
- Aspect ratio: 1:2.5 to 2.5:1
""")
if __name__ == "__main__":
    # Verify required files exist before launching
    try:
        validate_assets()
    except FileNotFoundError as e:
        print(f"Error: {str(e)}")
        print("Please make sure these files exist in the same directory:")
        print("- logo.png")
        print("- pullover.png (for top selection)")
        print("- sweatpants.png (for bottom selection)")
    else:
        # Launch outside the try block so that a FileNotFoundError raised by
        # the launch itself is not misreported as a missing asset file.
        demo.launch()