Spaces:
Running
Running
File size: 8,843 Bytes
e72c3c8 ab3fb39 becf435 ab3fb39 becf435 e72c3c8 ab3fb39 e72c3c8 ab3fb39 e72c3c8 ab3fb39 e72c3c8 becf435 e72c3c8 b321854 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 |
import contextlib
import io
import os
import sys
import time
import uuid

import gradio as gr
import requests
from PIL import Image
# In-flight request ids mapped to a "keep going" flag. The generation
# handler adds its id with True and removes it when it finishes;
# cancel_request() flips flags to False so the polling loop stops.
active_requests = {}

# Bundled asset files (paths are relative to the working directory) and
# a descriptive prompt per selectable garment type.
# NOTE(review): the "prompt" strings are not read anywhere in this file.
ASSETS = {
    "top": {
        "reference_image": "pullover.png",
        "prompt": "black pullover with the logo on the chest",
    },
    "bottom": {
        "reference_image": "sweatpants.png",
        "prompt": "black sweatpants with the silver logo on it",
    },
    "logo": "logo.png",
}
def needs_resizing(image_path, max_short_edge=1280):
    """Return True when the image's shorter side is longer than *max_short_edge* pixels."""
    with Image.open(image_path) as picture:
        shorter_side = min(picture.size)
    return shorter_side > max_short_edge
def resize_image_to_short_edge(image_path, max_short_edge=1280):
    """Return the image at *image_path* with its short edge capped at *max_short_edge* px.

    The aspect ratio is preserved and the resized image is re-encoded in
    the source file's own format. Images whose short edge already fits are
    never upscaled: their original file bytes are returned unchanged.

    Args:
        image_path: Path of the image file to read.
        max_short_edge: Maximum length, in pixels, of the shorter side.

    Returns:
        io.BytesIO positioned at offset 0 holding the encoded image.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        short_edge = min(width, height)
        if short_edge > max_short_edge:
            # Capture the format first: images produced by resize() report
            # ``format`` as None.
            img_format = img.format
            scale = max_short_edge / short_edge
            # Pin the short edge to exactly max_short_edge (truncating a
            # float product with int() can land one pixel short) and round
            # the long edge.
            if width < height:
                new_size = (max_short_edge, round(height * scale))
            else:
                new_size = (round(width * scale), max_short_edge)
            buffer = io.BytesIO()
            img.resize(new_size, Image.LANCZOS).save(buffer, format=img_format)
            buffer.seek(0)
            return buffer
    # Already within the limit: pass the original bytes through untouched.
    with open(image_path, "rb") as source:
        return io.BytesIO(source.read())
def validate_assets():
    """Raise FileNotFoundError unless every bundled asset file exists.

    The garment reference images ("top", then "bottom") are checked before
    the logo, and all missing paths are reported together in one message.
    """
    expected = [ASSETS[garment]["reference_image"] for garment in ("top", "bottom")]
    expected.append(ASSETS["logo"])
    missing = [path for path in expected if not os.path.exists(path)]
    if not missing:
        return
    raise FileNotFoundError(f"Missing required asset files: {', '.join(missing)}")
def _build_upload_files(stack, input_image, garment_type):
    """Open the multipart uploads for one job, registering each handle on *stack*.

    The logo and the garment's reference image are sent as stored on disk;
    the user's photo is downscaled first when needs_resizing() says so.
    Each handle is entered into *stack* as soon as it is opened, so a
    failure part-way through still closes the handles opened before it.
    """
    files = {
        "logo_image": stack.enter_context(open(ASSETS["logo"], "rb")),
        "reference_image": stack.enter_context(
            open(ASSETS[garment_type]["reference_image"], "rb")
        ),
    }
    if needs_resizing(input_image):
        files["input_image"] = stack.enter_context(resize_image_to_short_edge(input_image))
    else:
        files["input_image"] = stack.enter_context(open(input_image, "rb"))
    return files


def _submit_job(session, api_key, input_image, garment_type):
    """POST a try-on job and return the ``id`` field of the JSON reply (None if absent).

    Raises requests.HTTPError for a 4xx/5xx reply, plus whatever opening or
    resizing the upload files raises.
    """
    post_url = "https://api.stability.ai/private/alo/v1/vto-acolade"
    # Never print/log these headers: they carry the API key.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    data = {
        "reference_image_type": garment_type,
        "output_format": "png",  # Hardcoded to PNG
    }
    # Upload handles are closed as soon as the POST returns instead of being
    # held open for the whole polling phase.
    with contextlib.ExitStack() as stack:
        files = _build_upload_files(stack, input_image, garment_type)
        response = session.post(post_url, headers=headers, files=files, data=data, timeout=10)
    response.raise_for_status()
    return response.json().get("id")


def _poll_for_result(session, api_key, job_id, request_id, progress,
                     max_attempts=20, initial_delay=1.0, max_delay=5.0):
    """Poll the results endpoint for *job_id* with growing delays.

    Stops when an image arrives, the API returns a non-"processing" JSON
    body, the request's flag in ``active_requests`` is cleared (cancel), or
    *max_attempts* polls have been made. Returns ``(image_or_None, message)``.
    """
    get_url = f"https://api.stability.ai/private/alo/v1/results/{job_id}"
    # NOTE(review): unlike the POST, this sends the raw key with no "Bearer "
    # prefix. Kept unchanged; confirm which form the results endpoint expects.
    headers = {
        "authorization": f"{api_key}",
        "accept": "*/*",
    }
    delay = initial_delay
    last_error = None
    for attempt in range(max_attempts):
        if not active_requests.get(request_id, False):
            return None, "Request cancelled by user"
        time.sleep(delay)
        # Re-check so a cancel issued during the sleep needs no extra request.
        if not active_requests.get(request_id, False):
            return None, "Request cancelled by user"
        progress(0.3 + (0.7 * attempt / max_attempts),
                 desc=f"Checking status (attempt {attempt + 1}/{max_attempts})")
        try:
            response = session.get(get_url, headers=headers, timeout=10)
            if response.status_code == 200:
                if "image" in response.headers.get("Content-Type", ""):
                    img = Image.open(io.BytesIO(response.content))
                    progress(1.0, desc="Done!")
                    return img, f"Success! Job ID: {job_id}"
                json_response = response.json()
                if json_response.get("status") != "processing":
                    return None, f"API response: {json_response}"
            elif response.status_code != 202:
                response.raise_for_status()
        except requests.exceptions.RequestException as exc:
            # Network/HTTP failures are retried; keep the last one so the
            # timeout message can say what went wrong.
            last_error = exc
        delay = min(delay * 1.5, max_delay)
    message = f"Timeout after {max_attempts} attempts. Job ID: {job_id}"
    if last_error is not None:
        message += f" (last error: {last_error})"
    return None, message


def generate_and_wait_for_image(
    api_key: str,
    input_image: str,
    garment_type: str,
    progress=gr.Progress()
):
    """Submit a virtual try-on job and block until its result image is ready.

    Args:
        api_key: API key for api.stability.ai.
        input_image: Filesystem path of the user's photo (the UI uses
            ``type="filepath"``); None/empty when nothing was uploaded.
        garment_type: Key into ``ASSETS`` selecting the reference garment.
        progress: Gradio progress tracker (supplied by Gradio via this default).

    Returns:
        ``(PIL image, status)`` on success, otherwise ``(None, message)``.
        Failures are reported through the message instead of being raised,
        so the UI always receives both outputs.

    NOTE(review): cancellation is tracked in the process-global
    ``active_requests``; cancel_request() stops every in-flight request,
    not just the caller's.
    """
    try:
        validate_assets()
    except FileNotFoundError as e:
        return None, str(e)
    # Fail fast with a clear message instead of an opaque PIL/HTTP error.
    if not input_image:
        return None, "Error: please upload a photo first"
    if not api_key:
        return None, "Error: please enter an API key"

    # Random id: time.time() can repeat for near-simultaneous clicks, letting
    # one request's cleanup delete (and so "cancel") another's entry.
    request_id = uuid.uuid4().hex
    active_requests[request_id] = True
    try:
        progress(0, desc="Starting image generation...")
        with requests.Session() as session:
            job_id = _submit_job(session, api_key, input_image, garment_type)
            if not job_id:
                return None, "Error: No job ID received in response"
            progress(0.3, desc="Processing your image...")
            return _poll_for_result(session, api_key, job_id, request_id, progress)
    except Exception as e:
        # UI boundary: surface any failure in the status box.
        return None, f"Error: {e}"
    finally:
        active_requests.pop(request_id, None)
def cancel_request():
    """Flag every in-flight generation request in this process as cancelled.

    Flags are set to False rather than removed: each polling loop notices
    the False flag at its next check and removes its own entry. Because
    ``active_requests`` is module-global, this affects every request being
    served by this process, not only the caller's.
    """
    active_requests.update(dict.fromkeys(active_requests, False))
with gr.Blocks(title="Virtual Try-On Demo") as demo:
    gr.Markdown(
        "\n"
        "# Virtual Try-On Demo v1\n"
        "Upload your photo and select garment type to generate your VTon image.\n"
    )
    with gr.Row():
        # Left column: inputs and actions.
        with gr.Column():
            api_key = gr.Textbox(
                label="API Key",
                value="",
                type="password",
            )
            input_image = gr.Image(
                label="Upload Your Photo",
                type="filepath",  # the handler receives a path on disk
                sources=["upload"],
                height=300,
            )
            garment_type = gr.Dropdown(
                label="Garment Type",
                choices=["top", "bottom"],
                value="top",
            )
            with gr.Row():
                submit_btn = gr.Button("Generate Image", variant="primary")
                cancel_btn = gr.Button("Cancel Request")
        # Right column: result and status text.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
                height=400,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
            )
    submit_btn.click(
        fn=generate_and_wait_for_image,
        inputs=[api_key, input_image, garment_type],
        outputs=[output_image, status_output],
    )
    # queue=False: the cancel handler runs outside Gradio's request queue.
    cancel_btn.click(
        fn=cancel_request,
        inputs=None,
        outputs=None,
        queue=False,
    )
    # Only photos larger than the limit are resized (see needs_resizing).
    gr.Markdown(
        "\n"
        "# Note:\n"
        "\n"
        "Photos whose short edge exceeds 1280px are downscaled so the short "
        "edge is 1280px; smaller photos are uploaded without resizing.\n"
    )
if __name__ == "__main__":
    # Verify the bundled assets before launching. The try covers only the
    # check, so a FileNotFoundError raised inside launch() is not
    # misreported as a missing asset.
    try:
        validate_assets()
    except FileNotFoundError as e:
        print(f"Error: {str(e)}", file=sys.stderr)
        print("Please make sure these files exist in the same directory:", file=sys.stderr)
        # Listed from ASSETS so the hint stays in sync with the configured names.
        print(f"- {ASSETS['logo']}", file=sys.stderr)
        for garment, asset in ASSETS.items():
            if isinstance(asset, dict):
                print(f"- {asset['reference_image']} (for {garment} selection)", file=sys.stderr)
        # Non-zero status so whatever launched the script sees the failure.
        sys.exit(1)
    else:
        demo.launch()