Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -206,9 +206,9 @@ async def get_completion(payload: Payload, request: Request):
|
|
206 |
|
207 |
return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
|
208 |
# Remove the duplicated endpoint and combine the functionality
|
209 |
-
@app.
|
210 |
async def generate_image(
|
211 |
-
prompt: str,
|
212 |
model: str = "flux", # Default model
|
213 |
seed: Optional[int] = None,
|
214 |
width: Optional[int] = None,
|
@@ -216,6 +216,7 @@ async def generate_image(
|
|
216 |
nologo: Optional[bool] = True,
|
217 |
private: Optional[bool] = None,
|
218 |
enhance: Optional[bool] = None,
|
|
|
219 |
):
|
220 |
"""
|
221 |
Generate an image using the Image Generation API.
|
@@ -223,22 +224,31 @@ async def generate_image(
|
|
223 |
# Validate the image endpoint
|
224 |
if not image_endpoint:
|
225 |
raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
|
226 |
-
|
227 |
-
#
|
228 |
-
if
|
229 |
-
|
230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
# Sanitize and encode the prompt
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
# Construct the URL with the encoded prompt
|
236 |
base_url = image_endpoint.rstrip('/') # Remove trailing slash if present
|
237 |
url = f"{base_url}/{encoded_prompt}"
|
238 |
-
|
239 |
# Prepare query parameters with validation
|
240 |
params = {}
|
241 |
-
|
242 |
if model and isinstance(model, str):
|
243 |
params['model'] = model
|
244 |
if seed is not None and isinstance(seed, int):
|
@@ -253,12 +263,12 @@ async def generate_image(
|
|
253 |
params['private'] = str(private).lower()
|
254 |
if enhance is not None:
|
255 |
params['enhance'] = str(enhance).lower()
|
256 |
-
|
257 |
try:
|
258 |
timeout = httpx.Timeout(60.0) # Set a reasonable timeout
|
259 |
async with httpx.AsyncClient(timeout=timeout) as client:
|
260 |
response = await client.get(url, params=params, follow_redirects=True)
|
261 |
-
|
262 |
# Check for various error conditions
|
263 |
if response.status_code == 404:
|
264 |
raise HTTPException(status_code=404, detail="Image generation service not found")
|
@@ -271,7 +281,7 @@ async def generate_image(
|
|
271 |
status_code=response.status_code,
|
272 |
detail=f"Image generation failed with status code {response.status_code}"
|
273 |
)
|
274 |
-
|
275 |
# Verify content type
|
276 |
content_type = response.headers.get('content-type', '')
|
277 |
if not content_type.startswith('image/'):
|
@@ -279,7 +289,7 @@ async def generate_image(
|
|
279 |
status_code=500,
|
280 |
detail=f"Unexpected content type received: {content_type}"
|
281 |
)
|
282 |
-
|
283 |
return StreamingResponse(
|
284 |
response.iter_bytes(),
|
285 |
media_type=content_type,
|
@@ -288,7 +298,7 @@ async def generate_image(
|
|
288 |
'Pragma': 'no-cache'
|
289 |
}
|
290 |
)
|
291 |
-
|
292 |
except httpx.TimeoutException:
|
293 |
raise HTTPException(status_code=504, detail="Image generation request timed out")
|
294 |
except httpx.RequestError as e:
|
|
|
206 |
|
207 |
return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
|
208 |
# Remove the duplicated endpoint and combine the functionality
|
209 |
+
@app.api_route("/images/generations", methods=["GET", "POST"]) # Support both GET and POST
|
210 |
async def generate_image(
|
211 |
+
prompt: Optional[str] = None,
|
212 |
model: str = "flux", # Default model
|
213 |
seed: Optional[int] = None,
|
214 |
width: Optional[int] = None,
|
|
|
216 |
nologo: Optional[bool] = True,
|
217 |
private: Optional[bool] = None,
|
218 |
enhance: Optional[bool] = None,
|
219 |
+
request: Request = None, # Access raw POST data
|
220 |
):
|
221 |
"""
|
222 |
Generate an image using the Image Generation API.
|
|
|
224 |
# Validate the image endpoint
|
225 |
if not image_endpoint:
|
226 |
raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
|
227 |
+
|
228 |
+
# Handle GET and POST prompts
|
229 |
+
if request.method == "POST":
|
230 |
+
try:
|
231 |
+
body = await request.json() # Parse JSON body
|
232 |
+
prompt = body.get("prompt", "").strip()
|
233 |
+
if not prompt:
|
234 |
+
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
235 |
+
except Exception:
|
236 |
+
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
237 |
+
elif request.method == "GET":
|
238 |
+
if not prompt or not prompt.strip():
|
239 |
+
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
240 |
+
prompt = prompt.strip()
|
241 |
+
|
242 |
# Sanitize and encode the prompt
|
243 |
+
encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
|
244 |
+
|
|
|
245 |
# Construct the URL with the encoded prompt
|
246 |
base_url = image_endpoint.rstrip('/') # Remove trailing slash if present
|
247 |
url = f"{base_url}/{encoded_prompt}"
|
248 |
+
|
249 |
# Prepare query parameters with validation
|
250 |
params = {}
|
251 |
+
|
252 |
if model and isinstance(model, str):
|
253 |
params['model'] = model
|
254 |
if seed is not None and isinstance(seed, int):
|
|
|
263 |
params['private'] = str(private).lower()
|
264 |
if enhance is not None:
|
265 |
params['enhance'] = str(enhance).lower()
|
266 |
+
|
267 |
try:
|
268 |
timeout = httpx.Timeout(60.0) # Set a reasonable timeout
|
269 |
async with httpx.AsyncClient(timeout=timeout) as client:
|
270 |
response = await client.get(url, params=params, follow_redirects=True)
|
271 |
+
|
272 |
# Check for various error conditions
|
273 |
if response.status_code == 404:
|
274 |
raise HTTPException(status_code=404, detail="Image generation service not found")
|
|
|
281 |
status_code=response.status_code,
|
282 |
detail=f"Image generation failed with status code {response.status_code}"
|
283 |
)
|
284 |
+
|
285 |
# Verify content type
|
286 |
content_type = response.headers.get('content-type', '')
|
287 |
if not content_type.startswith('image/'):
|
|
|
289 |
status_code=500,
|
290 |
detail=f"Unexpected content type received: {content_type}"
|
291 |
)
|
292 |
+
|
293 |
return StreamingResponse(
|
294 |
response.iter_bytes(),
|
295 |
media_type=content_type,
|
|
|
298 |
'Pragma': 'no-cache'
|
299 |
}
|
300 |
)
|
301 |
+
|
302 |
except httpx.TimeoutException:
|
303 |
raise HTTPException(status_code=504, detail="Image generation request timed out")
|
304 |
except httpx.RequestError as e:
|