ParthSadaria commited on
Commit
029405b
·
verified ·
1 Parent(s): 2d6ed81

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +84 -1
main.py CHANGED
@@ -17,11 +17,12 @@ load_dotenv()
17
  app = FastAPI()
18
 
19
  # Get API keys and secret endpoint from environment variables
20
- api_keys_str = os.getenv('API_KEYS')
21
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
22
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
23
  secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
24
  secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3') # New endpoint for searchgpt
 
25
 
26
  # Validate if the main secret API endpoints are set
27
  if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
@@ -145,7 +146,89 @@ available_model_ids = [
145
  "llama-3.1-8b", "gemini-1.5-flash", "mixtral-8x7b" , "command-r","gemini-pro",
146
  "gpt-3.5-turbo", "command"
147
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  @app.post("/chat/completions")
150
  @app.post("/v1/chat/completions")
151
  async def get_completion(payload: Payload,request: Request):
 
17
  app = FastAPI()
18
 
19
  # Get API keys and secret endpoint from environment variables
20
+ api_keys_str = os.getenv('API_KEYS') #deprecated -_-
21
  valid_api_keys = api_keys_str.split(',') if api_keys_str else []
22
  secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
23
  secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
24
  secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3') # New endpoint for searchgpt
25
+ image_endpoint = os.getenv("IMAGE_ENDPOINT")
26
 
27
  # Validate if the main secret API endpoints are set
28
  if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
 
146
  "llama-3.1-8b", "gemini-1.5-flash", "mixtral-8x7b" , "command-r","gemini-pro",
147
  "gpt-3.5-turbo", "command"
148
  ]
149
+ @app.get("/images/generations")
150
+ async def generate_image(
151
+ prompt: str,
152
+ model: str = "flux", # Default model
153
+ seed: Optional[int] = None,
154
+ width: Optional[int] = None,
155
+ height: Optional[int] = None,
156
+ nologo: Optional[bool] = True,
157
+ private: Optional[bool] = None,
158
+ enhance: Optional[bool] = None,
159
+ ):
160
+ """
161
+ Generate an image using the Image Generation API.
162
 
163
+ Example usage:
164
+ - /images/generations?prompt=beautiful+sunset&model=flux&width=1024&height=768
165
+ """
166
+ # Ensure the IMAGE_ENDPOINT is configured
167
+ if not image_endpoint:
168
+ raise HTTPException(
169
+ status_code=500,
170
+ detail="Image endpoint not configured in environment variables."
171
+ )
172
+
173
+ # Check if prompt is valid
174
+ if not prompt or len(prompt.strip()) == 0:
175
+ raise HTTPException(
176
+ status_code=400,
177
+ detail="Prompt is required and cannot be empty. Example: prompt=beautiful+sunset"
178
+ )
179
+
180
+ # Construct the URL with the prompt
181
+ url = f"{image_endpoint}/{prompt.strip()}"
182
+
183
+ # Prepare query parameters
184
+ params = {
185
+ "model": model,
186
+ "seed": seed,
187
+ "width": width,
188
+ "height": height,
189
+ "nologo": nologo,
190
+ "private": private,
191
+ "enhance": enhance,
192
+ }
193
+ # Remove keys with `None` values to avoid invalid params
194
+ params = {k: v for k, v in params.items() if v is not None}
195
+
196
+ try:
197
+ # Send GET request to the image generation endpoint
198
+ async with httpx.AsyncClient() as client:
199
+ response = await client.get(url, params=params)
200
+
201
+ # Handle non-successful HTTP status codes
202
+ if response.status_code == 400:
203
+ raise HTTPException(
204
+ status_code=400,
205
+ detail="Invalid request to the image generation API. Please check your prompt and parameters."
206
+ )
207
+ elif response.status_code == 500:
208
+ raise HTTPException(
209
+ status_code=500,
210
+ detail="Server error occurred while generating the image. Please try again later."
211
+ )
212
+ elif response.status_code != 200:
213
+ raise HTTPException(
214
+ status_code=response.status_code,
215
+ detail=f"Error generating image: {response.text}"
216
+ )
217
+
218
+ # Return the image as a file response
219
+ return StreamingResponse(response.aiter_bytes(), media_type="image/jpeg")
220
+
221
+ except httpx.TimeoutException:
222
+ raise HTTPException(
223
+ status_code=504,
224
+ detail="The request to the image generation API timed out. Please try again later."
225
+ )
226
+ except httpx.RequestError as e:
227
+ # Handle other request errors
228
+ raise HTTPException(
229
+ status_code=500,
230
+ detail=f"Error contacting the image endpoint: {str(e)}"
231
+ )
232
  @app.post("/chat/completions")
233
  @app.post("/v1/chat/completions")
234
  async def get_completion(payload: Payload,request: Request):