Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -17,11 +17,12 @@ load_dotenv()
|
|
17 |
app = FastAPI()
|
18 |
|
19 |
# Get API keys and secret endpoint from environment variables
|
20 |
-
api_keys_str = os.getenv('API_KEYS')
|
21 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
22 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
23 |
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
24 |
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3') # New endpoint for searchgpt
|
|
|
25 |
|
26 |
# Validate if the main secret API endpoints are set
|
27 |
if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
|
@@ -145,7 +146,89 @@ available_model_ids = [
|
|
145 |
"llama-3.1-8b", "gemini-1.5-flash", "mixtral-8x7b" , "command-r","gemini-pro",
|
146 |
"gpt-3.5-turbo", "command"
|
147 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
@app.post("/chat/completions")
|
150 |
@app.post("/v1/chat/completions")
|
151 |
async def get_completion(payload: Payload,request: Request):
|
|
|
17 |
app = FastAPI()

# Configuration is read once, at import time, from environment variables.
# API_KEYS is a comma-separated list of client keys (deprecated).
api_keys_str = os.getenv('API_KEYS')  # deprecated -_-
# Strip surrounding whitespace and drop empty entries so values such as
# "key1, key2," still yield keys that compare equal to what clients send.
valid_api_keys = (
    [key.strip() for key in api_keys_str.split(',') if key.strip()]
    if api_keys_str else []
)
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
secret_api_endpoint_3 = os.getenv('SECRET_API_ENDPOINT_3')  # New endpoint for searchgpt
# Upstream base URL for /images/generations; may be unset (checked per-request).
image_endpoint = os.getenv("IMAGE_ENDPOINT")
|
26 |
|
27 |
# Validate if the main secret API endpoints are set
|
28 |
if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoint_3:
|
|
|
146 |
"llama-3.1-8b", "gemini-1.5-flash", "mixtral-8x7b" , "command-r","gemini-pro",
|
147 |
"gpt-3.5-turbo", "command"
|
148 |
]
|
149 |
+
@app.get("/images/generations")
|
150 |
+
async def generate_image(
|
151 |
+
prompt: str,
|
152 |
+
model: str = "flux", # Default model
|
153 |
+
seed: Optional[int] = None,
|
154 |
+
width: Optional[int] = None,
|
155 |
+
height: Optional[int] = None,
|
156 |
+
nologo: Optional[bool] = True,
|
157 |
+
private: Optional[bool] = None,
|
158 |
+
enhance: Optional[bool] = None,
|
159 |
+
):
|
160 |
+
"""
|
161 |
+
Generate an image using the Image Generation API.
|
162 |
|
163 |
+
Example usage:
|
164 |
+
- /images/generations?prompt=beautiful+sunset&model=flux&width=1024&height=768
|
165 |
+
"""
|
166 |
+
# Ensure the IMAGE_ENDPOINT is configured
|
167 |
+
if not image_endpoint:
|
168 |
+
raise HTTPException(
|
169 |
+
status_code=500,
|
170 |
+
detail="Image endpoint not configured in environment variables."
|
171 |
+
)
|
172 |
+
|
173 |
+
# Check if prompt is valid
|
174 |
+
if not prompt or len(prompt.strip()) == 0:
|
175 |
+
raise HTTPException(
|
176 |
+
status_code=400,
|
177 |
+
detail="Prompt is required and cannot be empty. Example: prompt=beautiful+sunset"
|
178 |
+
)
|
179 |
+
|
180 |
+
# Construct the URL with the prompt
|
181 |
+
url = f"{image_endpoint}/{prompt.strip()}"
|
182 |
+
|
183 |
+
# Prepare query parameters
|
184 |
+
params = {
|
185 |
+
"model": model,
|
186 |
+
"seed": seed,
|
187 |
+
"width": width,
|
188 |
+
"height": height,
|
189 |
+
"nologo": nologo,
|
190 |
+
"private": private,
|
191 |
+
"enhance": enhance,
|
192 |
+
}
|
193 |
+
# Remove keys with `None` values to avoid invalid params
|
194 |
+
params = {k: v for k, v in params.items() if v is not None}
|
195 |
+
|
196 |
+
try:
|
197 |
+
# Send GET request to the image generation endpoint
|
198 |
+
async with httpx.AsyncClient() as client:
|
199 |
+
response = await client.get(url, params=params)
|
200 |
+
|
201 |
+
# Handle non-successful HTTP status codes
|
202 |
+
if response.status_code == 400:
|
203 |
+
raise HTTPException(
|
204 |
+
status_code=400,
|
205 |
+
detail="Invalid request to the image generation API. Please check your prompt and parameters."
|
206 |
+
)
|
207 |
+
elif response.status_code == 500:
|
208 |
+
raise HTTPException(
|
209 |
+
status_code=500,
|
210 |
+
detail="Server error occurred while generating the image. Please try again later."
|
211 |
+
)
|
212 |
+
elif response.status_code != 200:
|
213 |
+
raise HTTPException(
|
214 |
+
status_code=response.status_code,
|
215 |
+
detail=f"Error generating image: {response.text}"
|
216 |
+
)
|
217 |
+
|
218 |
+
# Return the image as a file response
|
219 |
+
return StreamingResponse(response.aiter_bytes(), media_type="image/jpeg")
|
220 |
+
|
221 |
+
except httpx.TimeoutException:
|
222 |
+
raise HTTPException(
|
223 |
+
status_code=504,
|
224 |
+
detail="The request to the image generation API timed out. Please try again later."
|
225 |
+
)
|
226 |
+
except httpx.RequestError as e:
|
227 |
+
# Handle other request errors
|
228 |
+
raise HTTPException(
|
229 |
+
status_code=500,
|
230 |
+
detail=f"Error contacting the image endpoint: {str(e)}"
|
231 |
+
)
|
232 |
@app.post("/chat/completions")
|
233 |
@app.post("/v1/chat/completions")
|
234 |
async def get_completion(payload: Payload,request: Request):
|