chenguittiMaroua commited on
Commit
1cdc794
·
verified ·
1 Parent(s): c6e8137

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +163 -0
main.py CHANGED
@@ -17,6 +17,13 @@ from slowapi import Limiter
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
  from slowapi.middleware import SlowAPIMiddleware
 
 
 
 
 
 
 
20
 
21
  # Initialize rate limiter
22
  limiter = Limiter(key_func=get_remote_address)
@@ -248,5 +255,161 @@ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
248
  content={"detail": "Too many requests. Please try again later."}
249
  )
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  if __name__ == "__main__":
252
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
  from slowapi.middleware import SlowAPIMiddleware
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ import tempfile
23
+ import base64
24
+ from io import BytesIO
25
+ from typing import Optional
26
+ from pydantic import BaseModel
27
 
28
  # Initialize rate limiter
29
  limiter = Limiter(key_func=get_remote_address)
 
255
  content={"detail": "Too many requests. Please try again later."}
256
  )
257
 
258
+
259
+
260
+
261
+
262
+ # Add this new Pydantic model for visualization requests
263
+ class VisualizationRequest(BaseModel):
264
+ chart_type: str
265
+ x_column: Optional[str] = None
266
+ y_column: Optional[str] = None
267
+ hue_column: Optional[str] = None
268
+ title: Optional[str] = None
269
+ x_label: Optional[str] = None
270
+ y_label: Optional[str] = None
271
+ style: str = "seaborn" # seaborn or matplotlib
272
+
273
+ # Add this new function for visualization code generation
274
+ def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
275
+ """Generate and execute visualization code based on request"""
276
+ plt.style.use(request.style)
277
+
278
+ code_lines = [
279
+ "import matplotlib.pyplot as plt",
280
+ "import seaborn as sns",
281
+ "import pandas as pd",
282
+ "",
283
+ "# Data preparation",
284
+ f"df = pd.DataFrame({df.head().to_dict()})", # Simplified for demo
285
+ "",
286
+ "# Visualization code"
287
+ ]
288
+
289
+ if request.chart_type == "line":
290
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
291
+ if request.hue_column:
292
+ code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
293
+ else:
294
+ code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
295
+ elif request.chart_type == "bar":
296
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
297
+ if request.hue_column:
298
+ code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
299
+ else:
300
+ code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
301
+ elif request.chart_type == "scatter":
302
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
303
+ if request.hue_column:
304
+ code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
305
+ else:
306
+ code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
307
+ elif request.chart_type == "histogram":
308
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
309
+ code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
310
+ else:
311
+ raise ValueError("Unsupported chart type")
312
+
313
+ # Add labels and title
314
+ if request.title:
315
+ code_lines.append(f"plt.title('{request.title}')")
316
+ if request.x_label:
317
+ code_lines.append(f"plt.xlabel('{request.x_label}')")
318
+ if request.y_label:
319
+ code_lines.append(f"plt.ylabel('{request.y_label}')")
320
+
321
+ code_lines.append("plt.tight_layout()")
322
+ code_lines.append("plt.show()")
323
+
324
+ return "\n".join(code_lines)
325
+
326
+ # Add this new endpoint for visualization
327
+ @app.post("/visualize")
328
+ @limiter.limit("5/minute")
329
+ async def generate_visualization_from_excel(
330
+ request: Request,
331
+ file: UploadFile = File(...),
332
+ chart_type: str = Form(...),
333
+ x_column: Optional[str] = Form(None),
334
+ y_column: Optional[str] = Form(None),
335
+ hue_column: Optional[str] = Form(None),
336
+ title: Optional[str] = Form(None),
337
+ x_label: Optional[str] = Form(None),
338
+ y_label: Optional[str] = Form(None),
339
+ style: str = Form("seaborn")
340
+ ):
341
+ try:
342
+ # Validate file
343
+ file_ext, content = await validate_file(file)
344
+ if file_ext not in {"xlsx", "xls"}:
345
+ raise HTTPException(400, "Only Excel files are supported for visualization")
346
+
347
+ # Read Excel file
348
+ df = pd.read_excel(io.BytesIO(content))
349
+
350
+ # Generate visualization request
351
+ vis_request = VisualizationRequest(
352
+ chart_type=chart_type,
353
+ x_column=x_column,
354
+ y_column=y_column,
355
+ hue_column=hue_column,
356
+ title=title,
357
+ x_label=x_label,
358
+ y_label=y_label,
359
+ style=style
360
+ )
361
+
362
+ # Generate and execute the visualization code
363
+ plt.figure()
364
+ exec(generate_visualization(df, vis_request), globals(), locals())
365
+
366
+ # Save the plot to a temporary file
367
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
368
+ plt.savefig(tmpfile.name, format='png', dpi=300)
369
+ plt.close()
370
+
371
+ # Read the image back as bytes
372
+ with open(tmpfile.name, "rb") as f:
373
+ image_bytes = f.read()
374
+
375
+ # Encode image as base64
376
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
377
+
378
+ return {
379
+ "status": "success",
380
+ "image": f"data:image/png;base64,{image_base64}",
381
+ "code": generate_visualization(df, vis_request)
382
+ }
383
+
384
+ except HTTPException:
385
+ raise
386
+ except Exception as e:
387
+ logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
388
+ raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
389
+
390
+ # Add this new endpoint for getting column names
391
+ @app.post("/get_columns")
392
+ @limiter.limit("10/minute")
393
+ async def get_excel_columns(
394
+ request: Request,
395
+ file: UploadFile = File(...)
396
+ ):
397
+ try:
398
+ file_ext, content = await validate_file(file)
399
+ if file_ext not in {"xlsx", "xls"}:
400
+ raise HTTPException(400, "Only Excel files are supported")
401
+
402
+ df = pd.read_excel(io.BytesIO(content))
403
+ return {
404
+ "columns": list(df.columns),
405
+ "sample_data": df.head().to_dict(orient='records')
406
+ }
407
+ except Exception as e:
408
+ logger.error(f"Column extraction failed: {str(e)}")
409
+ raise HTTPException(500, detail="Failed to extract columns from Excel file")
410
+
411
+
412
+
413
+
414
  if __name__ == "__main__":
415
  uvicorn.run(app, host="0.0.0.0", port=7860)