ParthSadaria commited on
Commit
757b439
·
verified ·
1 Parent(s): e589919

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +27 -17
main.py CHANGED
@@ -206,9 +206,9 @@ async def get_completion(payload: Payload, request: Request):
206
 
207
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
208
  # Remove the duplicated endpoint and combine the functionality
209
- @app.get("/images/generations") #pollinations.ai thanks to them :)
210
  async def generate_image(
211
- prompt: str,
212
  model: str = "flux", # Default model
213
  seed: Optional[int] = None,
214
  width: Optional[int] = None,
@@ -216,6 +216,7 @@ async def generate_image(
216
  nologo: Optional[bool] = True,
217
  private: Optional[bool] = None,
218
  enhance: Optional[bool] = None,
 
219
  ):
220
  """
221
  Generate an image using the Image Generation API.
@@ -223,22 +224,31 @@ async def generate_image(
223
  # Validate the image endpoint
224
  if not image_endpoint:
225
  raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
226
-
227
- # Validate prompt
228
- if not prompt or not prompt.strip():
229
- raise HTTPException(status_code=400, detail="Prompt cannot be empty")
230
-
 
 
 
 
 
 
 
 
 
 
231
  # Sanitize and encode the prompt
232
- sanitized_prompt = prompt.strip()
233
- encoded_prompt = httpx.QueryParams({'prompt': sanitized_prompt}).get('prompt')
234
-
235
  # Construct the URL with the encoded prompt
236
  base_url = image_endpoint.rstrip('/') # Remove trailing slash if present
237
  url = f"{base_url}/{encoded_prompt}"
238
-
239
  # Prepare query parameters with validation
240
  params = {}
241
-
242
  if model and isinstance(model, str):
243
  params['model'] = model
244
  if seed is not None and isinstance(seed, int):
@@ -253,12 +263,12 @@ async def generate_image(
253
  params['private'] = str(private).lower()
254
  if enhance is not None:
255
  params['enhance'] = str(enhance).lower()
256
-
257
  try:
258
  timeout = httpx.Timeout(60.0) # Set a reasonable timeout
259
  async with httpx.AsyncClient(timeout=timeout) as client:
260
  response = await client.get(url, params=params, follow_redirects=True)
261
-
262
  # Check for various error conditions
263
  if response.status_code == 404:
264
  raise HTTPException(status_code=404, detail="Image generation service not found")
@@ -271,7 +281,7 @@ async def generate_image(
271
  status_code=response.status_code,
272
  detail=f"Image generation failed with status code {response.status_code}"
273
  )
274
-
275
  # Verify content type
276
  content_type = response.headers.get('content-type', '')
277
  if not content_type.startswith('image/'):
@@ -279,7 +289,7 @@ async def generate_image(
279
  status_code=500,
280
  detail=f"Unexpected content type received: {content_type}"
281
  )
282
-
283
  return StreamingResponse(
284
  response.iter_bytes(),
285
  media_type=content_type,
@@ -288,7 +298,7 @@ async def generate_image(
288
  'Pragma': 'no-cache'
289
  }
290
  )
291
-
292
  except httpx.TimeoutException:
293
  raise HTTPException(status_code=504, detail="Image generation request timed out")
294
  except httpx.RequestError as e:
 
206
 
207
  return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
208
  # Remove the duplicated endpoint and combine the functionality
209
+ @app.api_route("/images/generations", methods=["GET", "POST"]) # Support both GET and POST
210
  async def generate_image(
211
+ prompt: Optional[str] = None,
212
  model: str = "flux", # Default model
213
  seed: Optional[int] = None,
214
  width: Optional[int] = None,
 
216
  nologo: Optional[bool] = True,
217
  private: Optional[bool] = None,
218
  enhance: Optional[bool] = None,
219
+ request: Request = None, # Access raw POST data
220
  ):
221
  """
222
  Generate an image using the Image Generation API.
 
224
  # Validate the image endpoint
225
  if not image_endpoint:
226
  raise HTTPException(status_code=500, detail="Image endpoint not configured in environment variables.")
227
+
228
+ # Handle GET and POST prompts
229
+ if request.method == "POST":
230
+ try:
231
+ body = await request.json() # Parse JSON body
232
+ prompt = body.get("prompt", "").strip()
233
+ if not prompt:
234
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
235
+ except Exception:
236
+ raise HTTPException(status_code=400, detail="Invalid JSON payload")
237
+ elif request.method == "GET":
238
+ if not prompt or not prompt.strip():
239
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
240
+ prompt = prompt.strip()
241
+
242
  # Sanitize and encode the prompt
243
+ encoded_prompt = httpx.QueryParams({'prompt': prompt}).get('prompt')
244
+
 
245
  # Construct the URL with the encoded prompt
246
  base_url = image_endpoint.rstrip('/') # Remove trailing slash if present
247
  url = f"{base_url}/{encoded_prompt}"
248
+
249
  # Prepare query parameters with validation
250
  params = {}
251
+
252
  if model and isinstance(model, str):
253
  params['model'] = model
254
  if seed is not None and isinstance(seed, int):
 
263
  params['private'] = str(private).lower()
264
  if enhance is not None:
265
  params['enhance'] = str(enhance).lower()
266
+
267
  try:
268
  timeout = httpx.Timeout(60.0) # Set a reasonable timeout
269
  async with httpx.AsyncClient(timeout=timeout) as client:
270
  response = await client.get(url, params=params, follow_redirects=True)
271
+
272
  # Check for various error conditions
273
  if response.status_code == 404:
274
  raise HTTPException(status_code=404, detail="Image generation service not found")
 
281
  status_code=response.status_code,
282
  detail=f"Image generation failed with status code {response.status_code}"
283
  )
284
+
285
  # Verify content type
286
  content_type = response.headers.get('content-type', '')
287
  if not content_type.startswith('image/'):
 
289
  status_code=500,
290
  detail=f"Unexpected content type received: {content_type}"
291
  )
292
+
293
  return StreamingResponse(
294
  response.iter_bytes(),
295
  media_type=content_type,
 
298
  'Pragma': 'no-cache'
299
  }
300
  )
301
+
302
  except httpx.TimeoutException:
303
  raise HTTPException(status_code=504, detail="Image generation request timed out")
304
  except httpx.RequestError as e: