saidennis commited on
Commit
b321854
·
verified ·
1 Parent(s): 4f94a43

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +2 -0
  2. app.py +221 -0
  3. logo.png +0 -0
  4. pullover.png +3 -0
  5. requirements.txt +2 -0
  6. sweatpants.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ pullover.png filter=lfs diff=lfs merge=lfs -text
37
+ sweatpants.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import time
4
+ import io
5
+ import os
6
+ from PIL import Image
7
+
8
+ # Global variable to track active requests
9
+ active_requests = {}
10
+
11
+ # Predefined assets
12
+ ASSETS = {
13
+ "top": {
14
+ "reference_image": "pullover.png",
15
+ "prompt": "black pullover with the logo on the chest"
16
+ },
17
+ "bottom": {
18
+ "reference_image": "sweatpants.png",
19
+ "prompt": "black sweatpants with the silver logo on it"
20
+ },
21
+ "logo": "logo.png"
22
+ }
23
+
24
+ def validate_assets():
25
+ """Check if all required asset files exist"""
26
+ missing = []
27
+ for asset_type in ["top", "bottom"]:
28
+ if not os.path.exists(ASSETS[asset_type]["reference_image"]):
29
+ missing.append(ASSETS[asset_type]["reference_image"])
30
+ if not os.path.exists(ASSETS["logo"]):
31
+ missing.append(ASSETS["logo"])
32
+ if missing:
33
+ raise FileNotFoundError(f"Missing required asset files: {', '.join(missing)}")
34
+
35
+ def generate_and_wait_for_image(
36
+ api_key: str,
37
+ input_image: str,
38
+ garment_type: str,
39
+ progress=gr.Progress()
40
+ ):
41
+ """Make POST request and automatically poll for the result image"""
42
+ # Validate assets first
43
+ try:
44
+ validate_assets()
45
+ except FileNotFoundError as e:
46
+ return None, str(e)
47
+
48
+ # Create a unique ID for this request
49
+ request_id = str(time.time())
50
+ active_requests[request_id] = True
51
+
52
+ try:
53
+ # Start the job
54
+ post_url = "https://api.stability.ai/private/alo/v1/vto-acolade"
55
+ headers = {
56
+ "Authorization": f"Bearer {api_key}",
57
+ "Accept": "application/json"
58
+ }
59
+
60
+ files = {}
61
+
62
+ try:
63
+ progress(0, desc="Starting image generation...")
64
+
65
+ # Prepare all required files
66
+ files = {
67
+ 'input_image': open(input_image, 'rb'),
68
+ 'logo_image': open(ASSETS["logo"], 'rb'),
69
+ 'reference_image': open(ASSETS[garment_type]["reference_image"], 'rb')
70
+ }
71
+ data = {
72
+ 'reference_image_type': (f'{garment_type}'),
73
+ 'output_format': (None, "png") # Hardcoded to PNG
74
+ }
75
+
76
+ # Submit the job with timeout
77
+ print(headers)
78
+ print(files)
79
+ response = requests.post(post_url, headers=headers, files=files, data=data, timeout=10)
80
+ response.raise_for_status()
81
+ job_data = response.json()
82
+ job_id = job_data.get('id')
83
+
84
+ if not job_id:
85
+ return None, "Error: No job ID received in response"
86
+
87
+ # Now poll for results with optimized timing
88
+ get_url = f"https://api.stability.ai/private/alo/v1/results/{job_id}"
89
+ headers = {
90
+ "authorization": f"{api_key}",
91
+ "accept": "*/*"
92
+ }
93
+
94
+ progress(0.3, desc="Processing your image...")
95
+
96
+ # Optimized polling strategy
97
+ max_attempts = 20
98
+ initial_delay = 1.0
99
+ max_delay = 5.0
100
+ current_delay = initial_delay
101
+
102
+ for attempt in range(max_attempts):
103
+ if not active_requests.get(request_id, False):
104
+ return None, "Request cancelled by user"
105
+
106
+ time.sleep(current_delay)
107
+ progress(0.3 + (0.7 * attempt/max_attempts),
108
+ desc=f"Checking status (attempt {attempt + 1}/{max_attempts})")
109
+
110
+ try:
111
+ with requests.Session() as session:
112
+ response = session.get(get_url, headers=headers, timeout=10)
113
+
114
+ if response.status_code == 200:
115
+ if 'image' in response.headers.get('Content-Type', ''):
116
+ img = Image.open(io.BytesIO(response.content))
117
+ progress(1.0, desc="Done!")
118
+ return img, f"Success! Job ID: {job_id}"
119
+ else:
120
+ json_response = response.json()
121
+ if json_response.get('status') == 'processing':
122
+ current_delay = min(current_delay * 1.5, max_delay)
123
+ continue
124
+ return None, f"API response: {json_response}"
125
+
126
+ elif response.status_code == 202:
127
+ current_delay = min(current_delay * 1.5, max_delay)
128
+ continue
129
+
130
+ else:
131
+ response.raise_for_status()
132
+
133
+ except requests.exceptions.RequestException:
134
+ current_delay = min(current_delay * 1.5, max_delay)
135
+ continue
136
+
137
+ return None, f"Timeout after {max_attempts} attempts. Job ID: {job_id}"
138
+
139
+ except Exception as e:
140
+ return None, f"Error: {str(e)}"
141
+ finally:
142
+ # Clean up file handles
143
+ for key in files:
144
+ if hasattr(files[key], 'close'):
145
+ files[key].close()
146
+
147
+ finally:
148
+ # Clean up the request tracking
149
+ active_requests.pop(request_id, None)
150
+
151
+ def cancel_request():
152
+ """Function to cancel active requests"""
153
+ for req_id in list(active_requests.keys()):
154
+ active_requests[req_id] = False
155
+
156
+ with gr.Blocks(title="Virtual Try-On Demo") as demo:
157
+ gr.Markdown("""
158
+ # Virtual Try-On Demo v1
159
+ Upload your photo and select garment type to generate your VTon image.
160
+ """)
161
+
162
+ with gr.Row():
163
+ with gr.Column():
164
+ api_key = gr.Textbox(
165
+ label="API Key",
166
+ value="",
167
+ type="password"
168
+ )
169
+
170
+ input_image = gr.Image(
171
+ label="Upload Your Photo",
172
+ type="filepath",
173
+ sources=["upload"],
174
+ height=300
175
+ )
176
+
177
+ garment_type = gr.Dropdown(
178
+ label="Garment Type",
179
+ choices=["top", "bottom"],
180
+ value="top"
181
+ )
182
+
183
+ with gr.Row():
184
+ submit_btn = gr.Button("Generate Image", variant="primary")
185
+ cancel_btn = gr.Button("Cancel Request")
186
+
187
+ with gr.Column():
188
+ output_image = gr.Image(
189
+ label="Generated Result",
190
+ interactive=False,
191
+ height=400
192
+ )
193
+ status_output = gr.Textbox(
194
+ label="Status",
195
+ interactive=False
196
+ )
197
+
198
+ submit_btn.click(
199
+ fn=generate_and_wait_for_image,
200
+ inputs=[api_key, input_image, garment_type],
201
+ outputs=[output_image, status_output]
202
+ )
203
+
204
+ cancel_btn.click(
205
+ fn=cancel_request,
206
+ inputs=None,
207
+ outputs=None,
208
+ queue=False
209
+ )
210
+
211
+ if __name__ == "__main__":
212
+ # Verify required files exist before launching
213
+ try:
214
+ validate_assets()
215
+ demo.launch()
216
+ except FileNotFoundError as e:
217
+ print(f"Error: {str(e)}")
218
+ print("Please make sure these files exist in the same directory:")
219
+ print("- logo.png")
220
+ print("- pullover.png (for top selection)")
221
+ print("- sweatpants.png (for bottom selection)")
logo.png ADDED
pullover.png ADDED

Git LFS Details

  • SHA256: 801dcb4715c6fac9b803705954e5aed38419c2422d097b6f45f402ea3d1af019
  • Pointer size: 131 Bytes
  • Size of remote file: 236 kB
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ requests
sweatpants.png ADDED

Git LFS Details

  • SHA256: 7e530ca6942124b7627de4bb5750317ace26debcd439c242ce4d8e003f19abd7
  • Pointer size: 131 Bytes
  • Size of remote file: 681 kB