chenguittiMaroua commited on
Commit
56b4bf4
·
verified ·
1 Parent(s): 043cd21

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +238 -92
main.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  from transformers import pipeline
5
- from typing import Tuple
6
  import io
7
  import fitz # PyMuPDF
8
  from PIL import Image
@@ -22,8 +22,9 @@ import seaborn as sns
22
  import tempfile
23
  import base64
24
  from io import BytesIO
25
- from typing import Optional
26
  from pydantic import BaseModel
 
 
27
 
28
  # Initialize rate limiter
29
  limiter = Limiter(key_func=get_remote_address)
@@ -154,6 +155,151 @@ def extract_text(content: bytes, file_ext: str) -> str:
154
  logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
155
  raise HTTPException(422, f"Failed to extract text from {file_ext} file")
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  @app.post("/summarize")
158
  @limiter.limit("5/minute")
159
  async def summarize_document(request: Request, file: UploadFile = File(...)):
@@ -248,85 +394,9 @@ async def question_answering(
248
  logger.error(f"QA processing failed: {str(e)}")
249
  raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
250
 
251
- @app.exception_handler(RateLimitExceeded)
252
- async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
253
- return JSONResponse(
254
- status_code=429,
255
- content={"detail": "Too many requests. Please try again later."}
256
- )
257
-
258
-
259
-
260
-
261
-
262
- # Add this new Pydantic model for visualization requests
263
- class VisualizationRequest(BaseModel):
264
- chart_type: str
265
- x_column: Optional[str] = None
266
- y_column: Optional[str] = None
267
- hue_column: Optional[str] = None
268
- title: Optional[str] = None
269
- x_label: Optional[str] = None
270
- y_label: Optional[str] = None
271
- style: str = "seaborn" # seaborn or matplotlib
272
-
273
- # Add this new function for visualization code generation
274
- def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
275
- """Generate and execute visualization code based on request"""
276
- plt.style.use(request.style)
277
-
278
- code_lines = [
279
- "import matplotlib.pyplot as plt",
280
- "import seaborn as sns",
281
- "import pandas as pd",
282
- "",
283
- "# Data preparation",
284
- f"df = pd.DataFrame({df.head().to_dict()})", # Simplified for demo
285
- "",
286
- "# Visualization code"
287
- ]
288
-
289
- if request.chart_type == "line":
290
- code_lines.append(f"plt.figure(figsize=(10, 6))")
291
- if request.hue_column:
292
- code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
293
- else:
294
- code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
295
- elif request.chart_type == "bar":
296
- code_lines.append(f"plt.figure(figsize=(10, 6))")
297
- if request.hue_column:
298
- code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
299
- else:
300
- code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
301
- elif request.chart_type == "scatter":
302
- code_lines.append(f"plt.figure(figsize=(10, 6))")
303
- if request.hue_column:
304
- code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
305
- else:
306
- code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
307
- elif request.chart_type == "histogram":
308
- code_lines.append(f"plt.figure(figsize=(10, 6))")
309
- code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
310
- else:
311
- raise ValueError("Unsupported chart type")
312
-
313
- # Add labels and title
314
- if request.title:
315
- code_lines.append(f"plt.title('{request.title}')")
316
- if request.x_label:
317
- code_lines.append(f"plt.xlabel('{request.x_label}')")
318
- if request.y_label:
319
- code_lines.append(f"plt.ylabel('{request.y_label}')")
320
-
321
- code_lines.append("plt.tight_layout()")
322
- code_lines.append("plt.show()")
323
-
324
- return "\n".join(code_lines)
325
-
326
- # Add this new endpoint for visualization
327
- @app.post("/visualize")
328
  @limiter.limit("5/minute")
329
- async def generate_visualization_from_excel(
330
  request: Request,
331
  file: UploadFile = File(...),
332
  chart_type: str = Form(...),
@@ -336,18 +406,29 @@ async def generate_visualization_from_excel(
336
  title: Optional[str] = Form(None),
337
  x_label: Optional[str] = Form(None),
338
  y_label: Optional[str] = Form(None),
339
- style: str = Form("seaborn")
 
340
  ):
341
  try:
342
  # Validate file
343
- file_ext, content = await validate_file(file)
344
  if file_ext not in {"xlsx", "xls"}:
345
  raise HTTPException(400, "Only Excel files are supported for visualization")
346
 
347
  # Read Excel file
348
  df = pd.read_excel(io.BytesIO(content))
349
 
350
- # Generate visualization request
 
 
 
 
 
 
 
 
 
 
351
  vis_request = VisualizationRequest(
352
  chart_type=chart_type,
353
  x_column=x_column,
@@ -356,12 +437,17 @@ async def generate_visualization_from_excel(
356
  title=title,
357
  x_label=x_label,
358
  y_label=y_label,
359
- style=style
 
360
  )
361
 
362
- # Generate and execute the visualization code
 
 
 
363
  plt.figure()
364
- exec(generate_visualization(df, vis_request), globals(), locals())
 
365
 
366
  # Save the plot to a temporary file
367
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
@@ -378,7 +464,8 @@ async def generate_visualization_from_excel(
378
  return {
379
  "status": "success",
380
  "image": f"data:image/png;base64,{image_base64}",
381
- "code": generate_visualization(df, vis_request)
 
382
  }
383
 
384
  except HTTPException:
@@ -387,7 +474,61 @@ async def generate_visualization_from_excel(
387
  logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
388
  raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
389
 
390
- # Add this new endpoint for getting column names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  @app.post("/get_columns")
392
  @limiter.limit("10/minute")
393
  async def get_excel_columns(
@@ -395,21 +536,26 @@ async def get_excel_columns(
395
  file: UploadFile = File(...)
396
  ):
397
  try:
398
- file_ext, content = await validate_file(file)
399
  if file_ext not in {"xlsx", "xls"}:
400
  raise HTTPException(400, "Only Excel files are supported")
401
 
402
  df = pd.read_excel(io.BytesIO(content))
403
  return {
404
  "columns": list(df.columns),
405
- "sample_data": df.head().to_dict(orient='records')
 
406
  }
407
  except Exception as e:
408
  logger.error(f"Column extraction failed: {str(e)}")
409
  raise HTTPException(500, detail="Failed to extract columns from Excel file")
410
 
411
-
412
-
 
 
 
 
413
 
414
  if __name__ == "__main__":
415
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  from transformers import pipeline
5
+ from typing import Tuple, Optional
6
  import io
7
  import fitz # PyMuPDF
8
  from PIL import Image
 
22
  import tempfile
23
  import base64
24
  from io import BytesIO
 
25
  from pydantic import BaseModel
26
+ import traceback
27
+ import ast
28
 
29
  # Initialize rate limiter
30
  limiter = Limiter(key_func=get_remote_address)
 
155
  logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
156
  raise HTTPException(422, f"Failed to extract text from {file_ext} file")
157
 
158
+ # Visualization Models
159
+ class VisualizationRequest(BaseModel):
160
+ chart_type: str
161
+ x_column: Optional[str] = None
162
+ y_column: Optional[str] = None
163
+ hue_column: Optional[str] = None
164
+ title: Optional[str] = None
165
+ x_label: Optional[str] = None
166
+ y_label: Optional[str] = None
167
+ style: str = "seaborn"
168
+ filters: Optional[dict] = None
169
+
170
+ class NaturalLanguageRequest(BaseModel):
171
+ prompt: str
172
+ style: str = "seaborn"
173
+
174
+ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
175
+ """Generate Python code for visualization based on request parameters"""
176
+ code_lines = [
177
+ "import matplotlib.pyplot as plt",
178
+ "import seaborn as sns",
179
+ "import pandas as pd",
180
+ "",
181
+ "# Data preparation",
182
+ f"df = pd.DataFrame({df.to_dict(orient='list')})",
183
+ ]
184
+
185
+ # Apply filters if specified
186
+ if request.filters:
187
+ filter_conditions = []
188
+ for column, condition in request.filters.items():
189
+ if isinstance(condition, dict):
190
+ if 'min' in condition and 'max' in condition:
191
+ filter_conditions.append(f"(df['{column}'] >= {condition['min']}) & (df['{column}'] <= {condition['max']})")
192
+ elif 'values' in condition:
193
+ values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
194
+ filter_conditions.append(f"df['{column}'].isin([{values}])")
195
+ else:
196
+ filter_conditions.append(f"df['{column}'] == {repr(condition)}")
197
+
198
+ if filter_conditions:
199
+ code_lines.extend([
200
+ "",
201
+ "# Apply filters",
202
+ f"df = df[{' & '.join(filter_conditions)}]"
203
+ ])
204
+
205
+ code_lines.extend([
206
+ "",
207
+ "# Visualization",
208
+ f"plt.style.use('{request.style}')",
209
+ f"plt.figure(figsize=(10, 6))"
210
+ ])
211
+
212
+ # Chart type specific code
213
+ if request.chart_type == "line":
214
+ if request.hue_column:
215
+ code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
216
+ else:
217
+ code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
218
+ elif request.chart_type == "bar":
219
+ if request.hue_column:
220
+ code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
221
+ else:
222
+ code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
223
+ elif request.chart_type == "scatter":
224
+ if request.hue_column:
225
+ code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
226
+ else:
227
+ code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
228
+ elif request.chart_type == "histogram":
229
+ code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
230
+ elif request.chart_type == "boxplot":
231
+ if request.hue_column:
232
+ code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
233
+ else:
234
+ code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}')")
235
+ elif request.chart_type == "heatmap":
236
+ code_lines.append(f"corr = df.corr()")
237
+ code_lines.append(f"sns.heatmap(corr, annot=True, cmap='coolwarm')")
238
+ else:
239
+ raise ValueError(f"Unsupported chart type: {request.chart_type}")
240
+
241
+ # Add labels and title
242
+ if request.title:
243
+ code_lines.append(f"plt.title('{request.title}')")
244
+ if request.x_label:
245
+ code_lines.append(f"plt.xlabel('{request.x_label}')")
246
+ if request.y_label:
247
+ code_lines.append(f"plt.ylabel('{request.y_label}')")
248
+
249
+ code_lines.extend([
250
+ "plt.tight_layout()",
251
+ "plt.show()"
252
+ ])
253
+
254
+ return "\n".join(code_lines)
255
+
256
+ def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
257
+ """Convert natural language prompt to visualization parameters"""
258
+ # Simple keyword-based interpretation (could be enhanced with NLP)
259
+ prompt = prompt.lower()
260
+
261
+ # Determine chart type
262
+ chart_type = "bar"
263
+ if "line" in prompt:
264
+ chart_type = "line"
265
+ elif "scatter" in prompt:
266
+ chart_type = "scatter"
267
+ elif "histogram" in prompt:
268
+ chart_type = "histogram"
269
+ elif "box" in prompt:
270
+ chart_type = "boxplot"
271
+ elif "heatmap" in prompt or "correlation" in prompt:
272
+ chart_type = "heatmap"
273
+
274
+ # Try to detect columns
275
+ x_col = None
276
+ y_col = None
277
+ hue_col = None
278
+
279
+ for col in df_columns:
280
+ if col.lower() in prompt:
281
+ if not x_col:
282
+ x_col = col
283
+ elif not y_col:
284
+ y_col = col
285
+ else:
286
+ hue_col = col
287
+
288
+ # Default to first columns if not detected
289
+ if not x_col and len(df_columns) > 0:
290
+ x_col = df_columns[0]
291
+ if not y_col and len(df_columns) > 1:
292
+ y_col = df_columns[1]
293
+
294
+ return VisualizationRequest(
295
+ chart_type=chart_type,
296
+ x_column=x_col,
297
+ y_column=y_col,
298
+ hue_column=hue_col,
299
+ title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
300
+ style="seaborn"
301
+ )
302
+
303
  @app.post("/summarize")
304
  @limiter.limit("5/minute")
305
  async def summarize_document(request: Request, file: UploadFile = File(...)):
 
394
  logger.error(f"QA processing failed: {str(e)}")
395
  raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
396
 
397
+ @app.post("/visualize/code")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  @limiter.limit("5/minute")
399
+ async def visualize_with_code(
400
  request: Request,
401
  file: UploadFile = File(...),
402
  chart_type: str = Form(...),
 
406
  title: Optional[str] = Form(None),
407
  x_label: Optional[str] = Form(None),
408
  y_label: Optional[str] = Form(None),
409
+ style: str = Form("seaborn"),
410
+ filters: Optional[str] = Form(None)
411
  ):
412
  try:
413
  # Validate file
414
+ file_ext, content = await process_uploaded_file(file)
415
  if file_ext not in {"xlsx", "xls"}:
416
  raise HTTPException(400, "Only Excel files are supported for visualization")
417
 
418
  # Read Excel file
419
  df = pd.read_excel(io.BytesIO(content))
420
 
421
+ # Parse filters if provided
422
+ filter_dict = {}
423
+ if filters:
424
+ try:
425
+ filter_dict = ast.literal_eval(filters)
426
+ if not isinstance(filter_dict, dict):
427
+ filter_dict = {}
428
+ except:
429
+ filter_dict = {}
430
+
431
+ # Create visualization request
432
  vis_request = VisualizationRequest(
433
  chart_type=chart_type,
434
  x_column=x_column,
 
437
  title=title,
438
  x_label=x_label,
439
  y_label=y_label,
440
+ style=style,
441
+ filters=filter_dict
442
  )
443
 
444
+ # Generate visualization code
445
+ visualization_code = generate_visualization_code(df, vis_request)
446
+
447
+ # Execute the code to generate the plot
448
  plt.figure()
449
+ local_vars = {}
450
+ exec(visualization_code, globals(), local_vars)
451
 
452
  # Save the plot to a temporary file
453
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
 
464
  return {
465
  "status": "success",
466
  "image": f"data:image/png;base64,{image_base64}",
467
+ "code": visualization_code,
468
+ "data_preview": df.head().to_dict(orient='records')
469
  }
470
 
471
  except HTTPException:
 
474
  logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
475
  raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
476
 
477
+ @app.post("/visualize/natural")
478
+ @limiter.limit("5/minute")
479
+ async def visualize_with_natural_language(
480
+ request: Request,
481
+ file: UploadFile = File(...),
482
+ prompt: str = Form(...),
483
+ style: str = Form("seaborn")
484
+ ):
485
+ try:
486
+ # Validate file
487
+ file_ext, content = await process_uploaded_file(file)
488
+ if file_ext not in {"xlsx", "xls"}:
489
+ raise HTTPException(400, "Only Excel files are supported for visualization")
490
+
491
+ # Read Excel file
492
+ df = pd.read_excel(io.BytesIO(content))
493
+
494
+ # Convert natural language to visualization parameters
495
+ nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
496
+ vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
497
+
498
+ # Generate visualization code
499
+ visualization_code = generate_visualization_code(df, vis_request)
500
+
501
+ # Execute the code to generate the plot
502
+ plt.figure()
503
+ local_vars = {}
504
+ exec(visualization_code, globals(), local_vars)
505
+
506
+ # Save the plot to a temporary file
507
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
508
+ plt.savefig(tmpfile.name, format='png', dpi=300)
509
+ plt.close()
510
+
511
+ # Read the image back as bytes
512
+ with open(tmpfile.name, "rb") as f:
513
+ image_bytes = f.read()
514
+
515
+ # Encode image as base64
516
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
517
+
518
+ return {
519
+ "status": "success",
520
+ "image": f"data:image/png;base64,{image_base64}",
521
+ "code": visualization_code,
522
+ "interpreted_parameters": vis_request.dict(),
523
+ "data_preview": df.head().to_dict(orient='records')
524
+ }
525
+
526
+ except HTTPException:
527
+ raise
528
+ except Exception as e:
529
+ logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
530
+ raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
531
+
532
  @app.post("/get_columns")
533
  @limiter.limit("10/minute")
534
  async def get_excel_columns(
 
536
  file: UploadFile = File(...)
537
  ):
538
  try:
539
+ file_ext, content = await process_uploaded_file(file)
540
  if file_ext not in {"xlsx", "xls"}:
541
  raise HTTPException(400, "Only Excel files are supported")
542
 
543
  df = pd.read_excel(io.BytesIO(content))
544
  return {
545
  "columns": list(df.columns),
546
+ "sample_data": df.head().to_dict(orient='records'),
547
+ "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
548
  }
549
  except Exception as e:
550
  logger.error(f"Column extraction failed: {str(e)}")
551
  raise HTTPException(500, detail="Failed to extract columns from Excel file")
552
 
553
+ @app.exception_handler(RateLimitExceeded)
554
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
555
+ return JSONResponse(
556
+ status_code=429,
557
+ content={"detail": "Too many requests. Please try again later."}
558
+ )
559
 
560
  if __name__ == "__main__":
561
  uvicorn.run(app, host="0.0.0.0", port=7860)