Spaces:
Running
Running
File size: 5,907 Bytes
fdd10b2 cf9b762 fdd10b2 cf9b762 fdd10b2 1281704 d5758be 06cb664 d5758be 851f0d3 06cb664 1281704 fdd10b2 1281704 fdd10b2 4d8ffe1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
from gradio_client import Client, handle_file
import re
import time
import os
import traceback
from dotenv import load_dotenv
# Merge any variables declared in a local .env file into the environment.
load_dotenv()

# Hugging Face access token; None when the variable is not set.
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# gradio_client handle for the remote OOTDiffusion Space, created with the
# token read above.
client = Client("levihsu/OOTDiffusion", hf_token=hf_token)
def _wait_time_to_seconds(wait_time):
    """Convert a colon-separated duration ("H:MM:SS" or "MM:SS") to seconds."""
    return sum(
        int(part) * 60 ** power
        for power, part in enumerate(reversed(wait_time.split(':')))
    )


def generate_outfit(model_image, garment_image, n_samples=1, n_steps=20, image_scale=2, seed=-1):
    """Request a try-on image from the remote ``/process_hd`` endpoint.

    Args:
        model_image: File path or URL of the model (person) image, or None.
        garment_image: File path or URL of the garment image, or None.
        n_samples: Forwarded to the endpoint as ``n_samples``.
        n_steps: Forwarded to the endpoint as ``n_steps``.
        image_scale: Forwarded to the endpoint as ``image_scale``.
        seed: Forwarded to the endpoint as ``seed``.

    Returns:
        A ``(image, message)`` tuple. On success ``image`` is the value
        returned by the endpoint (the ``'image'`` entry when the first result
        is a dict) and ``message`` is None. On failure ``image`` is None and
        ``message`` is a human-readable explanation.
    """
    if model_image is None or garment_image is None:
        return None, "Please upload both model and garment images"

    max_retries = 3
    for attempt in range(max_retries):
        try:
            # handle_file stays inside the try so that an unreadable path is
            # reported through the same error channel as a failed request.
            result = client.predict(
                vton_img=handle_file(model_image),
                garm_img=handle_file(garment_image),
                n_samples=n_samples,
                n_steps=n_steps,
                image_scale=image_scale,
                seed=seed,
                api_name="/process_hd"
            )
        except Exception as e:  # UI boundary: never let the handler crash
            error_msg = str(e)
            if "exceeded your GPU quota" not in error_msg:
                # Log the full traceback for debugging
                traceback.print_exc()
                return None, f"Error: {error_msg}"

            wait_time_match = re.search(r'retry in (\d+:\d+:\d+)', error_msg)
            wait_time = wait_time_match.group(1) if wait_time_match else "60:00"  # Default to 1 hour
            if attempt < max_retries - 1:
                # NOTE: this blocks the request for the advertised wait, which
                # can be long; the point of sleeping is to actually try again.
                time.sleep(_wait_time_to_seconds(wait_time))
                continue
            # Retries exhausted: report the quota problem instead of sleeping.
            return None, f"GPU quota exceeded. Please wait {wait_time} before trying again."

        # A sequence result carries one entry per generated sample; use the first.
        if isinstance(result, (list, tuple)):
            if not result:
                return None, "API returned unexpected format"
            result = result[0]

        # A dict entry is expected to hold the image under the 'image' key.
        if isinstance(result, dict):
            if 'image' in result:
                return result['image'], None
            return None, "API returned unexpected format"

        return result, None
# Create Gradio interface: inputs and examples on the left, the generated
# image and any error message on the right, tuning controls underneath.
with gr.Blocks() as demo:
    # The last bullet names the button by its actual label ("Generate Outfit").
    gr.Markdown("""
    ## Zuri Africa - Try On Virtual Outfits
    ⚠️ **Note**: You need to select or upload your model:
    - Followed by selecting or uploading the garment you want to try-on.
    - Next click the Generate Outfit button and wait
    """)
    with gr.Row():
        with gr.Column():
            # type="filepath" hands generate_outfit a path string, which is
            # what it passes on to handle_file.
            model_image = gr.Image(
                label="Upload Model Image (person wearing clothes)",
                type="filepath",
                height=300
            )
            model_examples = [
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/ba5ba7978e7302e8ab5eb733cc7221394c4e6faf/model_5.png",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/40dade4a04a827c0fdf63c6c70b42ef26480f391/01861_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/3c4639c5fab3cdcd3239609dca5afee7b0677286/model_6.png",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/0089171df270f4532eec3d80a8f36cc8218c6840/01008_00.jpg"
            ]
            gr.Examples(examples=model_examples, inputs=model_image)
            garment_image = gr.Image(
                label="Upload Garment Image (clothing item)",
                type="filepath",
                height=300
            )
            garment_examples = [
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/180d4e2a1139071a8685a5edee7ab24bcf1639f5/03244_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/584dda2c5ee1d8271a6cd06225c07db89c79ca03/04825_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/a51938ec99f13e548d365a9ca6d794b6fe7462af/049949_1.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/2d64241101189251ce415df84dc9205cda9a36ca/03032_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/44aee6b576cae51eeb979311306375b56b7e0d8b/02305_00.jpg",
                "https://levihsu-ootdiffusion.hf.space/file=/tmp/gradio/578dfa869dedb649e91eccbe566fc76435bb6bbe/049920_1.jpg"
            ]
            gr.Examples(examples=garment_examples, inputs=garment_image)
        with gr.Column():
            output_image = gr.Image(label="Generated Output")
            error_text = gr.Markdown()  # Add error display
    with gr.Row():
        with gr.Column():
            n_samples = gr.Slider(
                label="Number of Samples",
                minimum=1,
                maximum=5,
                step=1,
                value=1
            )
            n_steps = gr.Slider(
                label="Steps (lower = faster, try 10-15)",
                minimum=1,
                maximum=50,
                step=1,
                value=10  # Reduced default
            )
            image_scale = gr.Slider(
                label="Scale (lower = faster, try 1-2)",
                minimum=1,
                maximum=5,
                step=1,
                value=1  # Reduced default
            )
            seed = gr.Number(
                label="Random Seed (-1 for random)",
                value=-1,
                precision=0  # seeds are whole numbers; deliver an int, not a float
            )
    generate_button = gr.Button("Generate Outfit")
    # Set up the action for the button: the two outputs line up with the
    # (image, message) pair returned by generate_outfit.
    generate_button.click(
        fn=generate_outfit,
        inputs=[model_image, garment_image, n_samples, n_steps, image_scale, seed],
        outputs=[output_image, error_text]
    )

# Launch the app
demo.launch(share=True)