chenguittiMaroua committed on
Commit 74fd655 · verified · 1 Parent(s): 8ea794b

Update main.py

Files changed (1)
  1. main.py +151 -462
main.py CHANGED
@@ -1,45 +1,31 @@
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
- from transformers import pipeline
- from typing import Tuple, Optional
- import io
- import fitz  # PyMuPDF
- from PIL import Image
  import pandas as pd
- import uvicorn
- from docx import Document
- from pptx import Presentation
- import pytesseract
  import logging
- import re
  from slowapi import Limiter
  from slowapi.util import get_remote_address
  from slowapi.errors import RateLimitExceeded
  from slowapi.middleware import SlowAPIMiddleware
- import matplotlib.pyplot as plt
- import seaborn as sns
- import tempfile
- import base64
- from io import BytesIO
- from pydantic import BaseModel
- import traceback
- import ast

- # Initialize rate limiter
- limiter = Limiter(key_func=get_remote_address)
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- app = FastAPI()

- # Apply rate limiting middleware
  app.state.limiter = limiter
  app.add_middleware(SlowAPIMiddleware)

- # CORS Configuration
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -49,386 +35,152 @@ app.add_middleware(

  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
- SUPPORTED_FILE_TYPES = {
-     "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
- }
-
- # Model caching
- summarizer = None
- qa_model = None
- image_captioner = None
-
- def get_summarizer():
-     global summarizer
-     if summarizer is None:
-         summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-     return summarizer
-
- def get_qa_model():
-     global qa_model
-     if qa_model is None:
-         qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
-     return qa_model
-
- def get_image_captioner():
-     global image_captioner
-     if image_captioner is None:
-         image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
-     return image_captioner
-
- async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
-     """Validate and process uploaded file with special handling for each type"""
-     if not file.filename:
-         raise HTTPException(400, "No filename provided")
-
-     file_ext = file.filename.split('.')[-1].lower()
-     if file_ext not in SUPPORTED_FILE_TYPES:
-         raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
-
-     content = await file.read()
-     if len(content) > MAX_FILE_SIZE:
-         raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
-
-     # Special validation for PDFs
-     if file_ext == "pdf":
-         try:
-             with fitz.open(stream=content, filetype="pdf") as doc:
-                 if doc.is_encrypted:
-                     if not doc.authenticate(""):
-                         raise ValueError("Encrypted PDF - cannot extract text")
-                 if len(doc) > 50:
-                     raise ValueError("PDF too large (max 50 pages)")
-         except Exception as e:
-             logger.error(f"PDF validation failed: {str(e)}")
-             raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
-
-     await file.seek(0)  # Reset file pointer for processing
-     return file_ext, content

- def extract_text(content: bytes, file_ext: str) -> str:
-     """Extract text from various file formats with enhanced support"""
-     try:
-         if file_ext == "docx":
-             doc = Document(io.BytesIO(content))
-             return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
-
-         elif file_ext in {"xlsx", "xls"}:
-             df = pd.read_excel(io.BytesIO(content), sheet_name=None)
-             all_text = []
-             for sheet_name, sheet_data in df.items():
-                 sheet_text = []
-                 for column in sheet_data.columns:
-                     sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
-                 all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
-             return "\n\n".join(all_text)
-
-         elif file_ext == "pptx":
-             ppt = Presentation(io.BytesIO(content))
-             text = []
-             for slide in ppt.slides:
-                 for shape in slide.shapes:
-                     if hasattr(shape, "text") and shape.text.strip():
-                         text.append(shape.text)
-             return "\n".join(text)
-
-         elif file_ext == "pdf":
-             pdf = fitz.open(stream=content, filetype="pdf")
-             return "\n".join(page.get_text("text") for page in pdf)
-
-         elif file_ext in {"jpg", "jpeg", "png"}:
-             # First try OCR
-             try:
-                 image = Image.open(io.BytesIO(content))
-                 text = pytesseract.image_to_string(image, config='--psm 6')
-                 if text.strip():
-                     return text
-
-                 # If OCR fails, try image captioning
-                 captioner = get_image_captioner()
-                 result = captioner(image)
-                 return result[0]['generated_text']
-             except Exception as img_e:
-                 logger.error(f"Image processing failed: {str(img_e)}")
-                 raise ValueError("Could not extract text or caption from image")
-
-     except Exception as e:
-         logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
-         raise HTTPException(422, f"Failed to extract text from {file_ext} file")

- # Visualization Models
  class VisualizationRequest(BaseModel):
      chart_type: str
-     x_column: Optional[str] = None
      y_column: Optional[str] = None
      hue_column: Optional[str] = None
      title: Optional[str] = None
      x_label: Optional[str] = None
      y_label: Optional[str] = None
      style: str = "seaborn"
-     filters: Optional[dict] = None

- class NaturalLanguageRequest(BaseModel):
-     prompt: str
-     style: str = "seaborn"

- def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
-     """Generate Python code for visualization based on request parameters"""
-     code_lines = [
-         "import matplotlib.pyplot as plt",
-         "import seaborn as sns",
-         "import pandas as pd",
-         "",
-         "# Data preparation",
-         f"df = pd.DataFrame({df.to_dict(orient='list')})",
-     ]
-
-     # Apply filters if specified
-     if request.filters:
-         filter_conditions = []
-         for column, condition in request.filters.items():
-             if isinstance(condition, dict):
-                 if 'min' in condition and 'max' in condition:
-                     filter_conditions.append(f"(df['{column}'] >= {condition['min']}) & (df['{column}'] <= {condition['max']})")
-                 elif 'values' in condition:
-                     values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
-                     filter_conditions.append(f"df['{column}'].isin([{values}])")
-             else:
-                 filter_conditions.append(f"df['{column}'] == {repr(condition)}")
-
-         if filter_conditions:
-             code_lines.extend([
-                 "",
-                 "# Apply filters",
-                 f"df = df[{' & '.join(filter_conditions)}]"
-             ])
-
-     code_lines.extend([
-         "",
-         "# Visualization",
-         f"plt.style.use('{request.style}')",
-         f"plt.figure(figsize=(10, 6))"
-     ])
-
-     # Chart type specific code
-     if request.chart_type == "line":
-         if request.hue_column:
-             code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
-         else:
-             code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
-     elif request.chart_type == "bar":
-         if request.hue_column:
-             code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
-         else:
-             code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
-     elif request.chart_type == "scatter":
-         if request.hue_column:
-             code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
-         else:
-             code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
-     elif request.chart_type == "histogram":
-         code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
-     elif request.chart_type == "boxplot":
-         if request.hue_column:
-             code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
-         else:
-             code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}')")
-     elif request.chart_type == "heatmap":
-         code_lines.append(f"corr = df.corr()")
-         code_lines.append(f"sns.heatmap(corr, annot=True, cmap='coolwarm')")
-     else:
-         raise ValueError(f"Unsupported chart type: {request.chart_type}")
-
-     # Add labels and title
-     if request.title:
-         code_lines.append(f"plt.title('{request.title}')")
-     if request.x_label:
-         code_lines.append(f"plt.xlabel('{request.x_label}')")
-     if request.y_label:
-         code_lines.append(f"plt.ylabel('{request.y_label}')")
-
-     code_lines.extend([
-         "plt.tight_layout()",
-         "plt.show()"
-     ])
-
-     return "\n".join(code_lines)

- def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
-     """Convert natural language prompt to visualization parameters"""
-     # Simple keyword-based interpretation (could be enhanced with NLP)
-     prompt = prompt.lower()
-
-     # Determine chart type
-     chart_type = "bar"
-     if "line" in prompt:
-         chart_type = "line"
-     elif "scatter" in prompt:
-         chart_type = "scatter"
-     elif "histogram" in prompt:
-         chart_type = "histogram"
-     elif "box" in prompt:
-         chart_type = "boxplot"
-     elif "heatmap" in prompt or "correlation" in prompt:
-         chart_type = "heatmap"
-
-     # Try to detect columns
-     x_col = None
-     y_col = None
-     hue_col = None
-
-     for col in df_columns:
-         if col.lower() in prompt:
-             if not x_col:
-                 x_col = col
-             elif not y_col:
-                 y_col = col
-             else:
-                 hue_col = col
-
-     # Default to first columns if not detected
-     if not x_col and len(df_columns) > 0:
-         x_col = df_columns[0]
-     if not y_col and len(df_columns) > 1:
-         y_col = df_columns[1]
-
-     return VisualizationRequest(
-         chart_type=chart_type,
-         x_column=x_col,
-         y_column=y_col,
-         hue_column=hue_col,
-         title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
-         style="seaborn"
-     )

- @app.post("/summarize")
- @limiter.limit("5/minute")
- async def summarize_document(request: Request, file: UploadFile = File(...)):
      try:
-         file_ext, content = await process_uploaded_file(file)
-         text = extract_text(content, file_ext)
-
-         if not text.strip():
-             raise HTTPException(400, "No extractable text found")
-
-         # Clean and chunk text
-         text = re.sub(r'\s+', ' ', text).strip()
-         chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
-
-         # Summarize each chunk
-         summarizer = get_summarizer()
-         summaries = []
-         for chunk in chunks:
-             summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
-             summaries.append(summary)
-
-         return {"summary": " ".join(summaries)}
-
-     except HTTPException:
-         raise
      except Exception as e:
-         logger.error(f"Summarization failed: {str(e)}")
-         raise HTTPException(500, "Document summarization failed")

- @app.post("/qa")
- @limiter.limit("5/minute")
- async def question_answering(
-     request: Request,
-     file: UploadFile = File(...),
-     question: str = Form(...),
-     language: str = Form("fr")
- ):
      try:
-         file_ext, content = await process_uploaded_file(file)
-         text = extract_text(content, file_ext)
-
-         if not text.strip():
-             raise HTTPException(400, "No extractable text found")
-
-         # Clean and truncate text
-         text = re.sub(r'\s+', ' ', text).strip()[:5000]

-         # Theme detection
-         theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
-         if any(kw in question.lower() for kw in theme_keywords):
-             try:
-                 summarizer = get_summarizer()
-                 summary_output = summarizer(
-                     text,
-                     max_length=min(100, len(text)//4),
-                     min_length=30,
-                     do_sample=False,
-                     truncation=True
-                 )
-
-                 theme = summary_output[0].get("summary_text", text[:200] + "...")
-                 return {
-                     "question": question,
-                     "answer": f"Le document traite principalement de : {theme}",
-                     "confidence": 0.95,
-                     "language": language
-                 }
-             except Exception:
-                 theme = text[:200] + ("..." if len(text) > 200 else "")
-                 return {
-                     "question": question,
-                     "answer": f"D'après le document : {theme}",
-                     "confidence": 0.7,
-                     "language": language,
-                     "warning": "theme_summary_fallback"
-                 }

-         # Standard QA
-         qa = get_qa_model()
-         result = qa(question=question, context=text[:3000])

-         return {
-             "question": question,
-             "answer": result["answer"],
-             "confidence": result["score"],
-             "language": language
-         }
-
-     except HTTPException:
-         raise
      except Exception as e:
-         logger.error(f"QA processing failed: {str(e)}")
-         raise HTTPException(500, detail=f"Analysis failed: {str(e)}")

- @app.post("/visualize/code")
  @limiter.limit("5/minute")
- async def visualize_with_code(
      request: Request,
-     file: UploadFile = File(...),
-     chart_type: str = Form(...),
-     x_column: Optional[str] = Form(None),
-     y_column: Optional[str] = Form(None),
-     hue_column: Optional[str] = Form(None),
-     title: Optional[str] = Form(None),
-     x_label: Optional[str] = Form(None),
-     y_label: Optional[str] = Form(None),
-     style: str = Form("seaborn"),
-     filters: Optional[str] = Form(None)
  ):
      try:
-         # Validate file
-         file_ext, content = await process_uploaded_file(file)
-         if file_ext not in {"xlsx", "xls"}:
-             raise HTTPException(400, "Only Excel files are supported for visualization")
-
-         # Read Excel file
-         df = pd.read_excel(io.BytesIO(content))

          # Parse filters if provided
          filter_dict = {}
          if filters:
              try:
-                 filter_dict = ast.literal_eval(filters)
                  if not isinstance(filter_dict, dict):
                      filter_dict = {}
              except:
                  filter_dict = {}
-
-         # Create visualization request
          vis_request = VisualizationRequest(
              chart_type=chart_type,
              x_column=x_column,
@@ -441,117 +193,54 @@ async def visualize_with_code(
              filters=filter_dict
          )

-         # Generate visualization code
-         visualization_code = generate_visualization_code(df, vis_request)
-
-         # Execute the code to generate the plot
-         plt.figure()
-         local_vars = {}
-         exec(visualization_code, globals(), local_vars)
-
-         # Save the plot to a temporary file
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
-             plt.savefig(tmpfile.name, format='png', dpi=300)
-             plt.close()
-
-         # Read the image back as bytes
-         with open(tmpfile.name, "rb") as f:
-             image_bytes = f.read()
-
-         # Encode image as base64
-         image_base64 = base64.b64encode(image_bytes).decode('utf-8')

          return {
              "status": "success",
              "image": f"data:image/png;base64,{image_base64}",
-             "code": visualization_code,
-             "data_preview": df.head().to_dict(orient='records')
          }
-
-     except HTTPException:
-         raise
-     except Exception as e:
-         logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
-         raise HTTPException(500, detail=f"Visualization failed: {str(e)}")

- @app.post("/visualize/natural")
- @limiter.limit("5/minute")
- async def visualize_with_natural_language(
-     request: Request,
-     file: UploadFile = File(...),
-     prompt: str = Form(...),
-     style: str = Form("seaborn")
- ):
-     try:
-         # Validate file
-         file_ext, content = await process_uploaded_file(file)
-         if file_ext not in {"xlsx", "xls"}:
-             raise HTTPException(400, "Only Excel files are supported for visualization")
-
-         # Read Excel file
-         df = pd.read_excel(io.BytesIO(content))
-
-         # Convert natural language to visualization parameters
-         nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
-         vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
-
-         # Generate visualization code
-         visualization_code = generate_visualization_code(df, vis_request)
-
-         # Execute the code to generate the plot
-         plt.figure()
-         local_vars = {}
-         exec(visualization_code, globals(), local_vars)
-
-         # Save the plot to a temporary file
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
-             plt.savefig(tmpfile.name, format='png', dpi=300)
-             plt.close()
-
-         # Read the image back as bytes
-         with open(tmpfile.name, "rb") as f:
-             image_bytes = f.read()
-
-         # Encode image as base64
-         image_base64 = base64.b64encode(image_bytes).decode('utf-8')
-
-         return {
-             "status": "success",
-             "image": f"data:image/png;base64,{image_base64}",
-             "code": visualization_code,
-             "interpreted_parameters": vis_request.dict(),
-             "data_preview": df.head().to_dict(orient='records')
-         }
-
-     except HTTPException:
-         raise
      except Exception as e:
-         logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
-         raise HTTPException(500, detail=f"Visualization failed: {str(e)}")

  @app.post("/get_columns")
  @limiter.limit("10/minute")
- async def get_excel_columns(
      request: Request,
-     file: UploadFile = File(...)
  ):
      try:
-         file_ext, content = await process_uploaded_file(file)
-         if file_ext not in {"xlsx", "xls"}:
-             raise HTTPException(400, "Only Excel files are supported")

-         df = pd.read_excel(io.BytesIO(content))

          return {
              "columns": list(df.columns),
-             "sample_data": df.head().to_dict(orient='records'),
              "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
          }
      except Exception as e:
-         logger.error(f"Column extraction failed: {str(e)}")
-         raise HTTPException(500, detail="Failed to extract columns from Excel file")

  @app.exception_handler(RateLimitExceeded)
- async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
      return JSONResponse(
          status_code=429,
          content={"detail": "Too many requests. Please try again later."}
 
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from typing import Optional, Dict, List
  import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import base64
+ import tempfile
+ import io
  import logging
+ import traceback
+ import uvicorn
  from slowapi import Limiter
  from slowapi.util import get_remote_address
  from slowapi.errors import RateLimitExceeded
  from slowapi.middleware import SlowAPIMiddleware

+ # Initialize FastAPI app
+ app = FastAPI(title="Data Visualization API", version="1.0")

+ # Rate limiting setup
+ limiter = Limiter(key_func=get_remote_address)
  app.state.limiter = limiter
  app.add_middleware(SlowAPIMiddleware)

+ # CORS configuration
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],

  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
+ SUPPORTED_EXCEL_TYPES = {"xlsx": "Excel Workbook", "xls": "Excel 97-2003 Workbook"}

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

  class VisualizationRequest(BaseModel):
      chart_type: str
+     x_column: str
      y_column: Optional[str] = None
      hue_column: Optional[str] = None
      title: Optional[str] = None
      x_label: Optional[str] = None
      y_label: Optional[str] = None
      style: str = "seaborn"
+     filters: Optional[Dict] = None

+ async def validate_excel_file(file: UploadFile) -> bytes:
+     """Validate and process uploaded Excel file with clear error messages"""
+     if not file.filename:
+         raise HTTPException(400, "No filename provided")

+     file_ext = file.filename.split('.')[-1].lower()
+     if file_ext not in SUPPORTED_EXCEL_TYPES:
+         supported = ", ".join([f"{ext} ({desc})" for ext, desc in SUPPORTED_EXCEL_TYPES.items()])
+         raise HTTPException(
+             400,
+             f"Unsupported file type '{file.filename}'. Please upload: {supported}"
+         )

+     content = await file.read()
+     if len(content) > MAX_FILE_SIZE:
+         raise HTTPException(413, f"File too large ({len(content)/1024/1024:.1f}MB). Max size: {MAX_FILE_SIZE//1024//1024}MB")

+     await file.seek(0)
+     return content
+
+ def read_excel_with_fallback(content: bytes) -> pd.DataFrame:
+     """Read Excel file with engine fallback and better error handling"""
      try:
+         # Try openpyxl first (for .xlsx)
+         return pd.read_excel(io.BytesIO(content), engine='openpyxl')
      except Exception as e:
+         logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
+         try:
+             # Fallback to xlrd (for .xls)
+             return pd.read_excel(io.BytesIO(content), engine='xlrd')
+         except Exception as e:
+             raise ValueError(f"Failed to read Excel file with either engine. Error: {str(e)}")

+ def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
+     """Generate and save visualization with proper resource cleanup"""
      try:
+         plt.style.use(request.style)
+         fig, ax = plt.subplots(figsize=(10, 6))
+
+         # Apply filters if specified
+         if request.filters:
+             for col, condition in request.filters.items():
+                 if isinstance(condition, dict):
+                     if 'min' in condition and 'max' in condition:
+                         df = df[(df[col] >= condition['min']) & (df[col] <= condition['max'])]
+                     elif 'values' in condition:
+                         df = df[df[col].isin(condition['values'])]
+                 else:
+                     df = df[df[col] == condition]
+
+         # Generate chart based on type
+         if request.chart_type == "line":
+             if request.hue_column:
+                 sns.lineplot(data=df, x=request.x_column, y=request.y_column,
+                              hue=request.hue_column, ax=ax)
+             else:
+                 ax.plot(df[request.x_column], df[request.y_column])
+         elif request.chart_type == "bar":
+             if request.hue_column:
+                 sns.barplot(data=df, x=request.x_column, y=request.y_column,
+                             hue=request.hue_column, ax=ax)
+             else:
+                 ax.bar(df[request.x_column], df[request.y_column])
+         elif request.chart_type == "scatter":
+             if request.hue_column:
+                 sns.scatterplot(data=df, x=request.x_column, y=request.y_column,
+                                 hue=request.hue_column, ax=ax)
+             else:
+                 ax.scatter(df[request.x_column], df[request.y_column])
+         elif request.chart_type == "histogram":
+             ax.hist(df[request.x_column], bins=20)
+         else:
+             raise ValueError(f"Unsupported chart type: {request.chart_type}")

+         # Add labels and title
+         if request.title:
+             ax.set_title(request.title)
+         if request.x_label:
+             ax.set_xlabel(request.x_label)
+         if request.y_label:
+             ax.set_ylabel(request.y_label)

+         plt.tight_layout()

+         # Save to temporary file
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
+             fig.savefig(tmpfile.name, format='png', dpi=300)
+             plt.close(fig)
+             with open(tmpfile.name, "rb") as f:
+                 return base64.b64encode(f.read()).decode('utf-8')
+
      except Exception as e:
+         plt.close()
+         raise ValueError(f"Chart generation failed: {str(e)}")

+ @app.post("/visualize")
  @limiter.limit("5/minute")
+ async def create_visualization(
      request: Request,
+     file: UploadFile = File(..., description="Excel file to visualize"),
+     chart_type: str = Form(..., description="Type of chart (line, bar, scatter, histogram)"),
+     x_column: str = Form(..., description="Column for x-axis"),
+     y_column: Optional[str] = Form(None, description="Column for y-axis"),
+     hue_column: Optional[str] = Form(None, description="Column for color grouping"),
+     title: Optional[str] = Form(None, description="Chart title"),
+     x_label: Optional[str] = Form(None, description="X-axis label"),
+     y_label: Optional[str] = Form(None, description="Y-axis label"),
+     style: str = Form("seaborn", description="Plot style (seaborn, ggplot, etc.)"),
+     filters: Optional[str] = Form(None, description="JSON string of filters to apply")
  ):
      try:
+         # Validate and read file
+         content = await validate_excel_file(file)
+         df = read_excel_with_fallback(content)

+         if df.empty:
+             raise ValueError("Excel file contains no data")
+
          # Parse filters if provided
          filter_dict = {}
          if filters:
              try:
+                 filter_dict = eval(filters) if filters else {}
                  if not isinstance(filter_dict, dict):
                      filter_dict = {}
              except:
                  filter_dict = {}
+
+         # Create visualization
          vis_request = VisualizationRequest(
              chart_type=chart_type,
              x_column=x_column,

              filters=filter_dict
          )

+         image_base64 = generate_visualization(df, vis_request)

          return {
              "status": "success",
              "image": f"data:image/png;base64,{image_base64}",
+             "columns": list(df.columns),
+             "filtered_data": df.to_dict(orient='records')
          }

+     except HTTPException as he:
+         raise he
+     except ValueError as ve:
+         logger.error(f"Validation error: {str(ve)}")
+         raise HTTPException(422, detail=str(ve))
      except Exception as e:
+         logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
+         raise HTTPException(500, detail="Failed to generate visualization")

  @app.post("/get_columns")
  @limiter.limit("10/minute")
+ async def get_columns(
      request: Request,
+     file: UploadFile = File(..., description="Excel file to analyze")
  ):
      try:
+         content = await validate_excel_file(file)
+         df = read_excel_with_fallback(content)

+         if df.empty:
+             raise ValueError("Excel file contains no data")
+
          return {
              "columns": list(df.columns),
+             "sample_data": df.head().replace({float('nan'): None}).to_dict(orient='records'),
              "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
          }
+
+     except HTTPException as he:
+         raise he
+     except ValueError as ve:
+         logger.error(f"Validation error: {str(ve)}")
+         raise HTTPException(422, detail=str(ve))
      except Exception as e:
+         logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
+         raise HTTPException(500, detail="Failed to process Excel file")

  @app.exception_handler(RateLimitExceeded)
+ async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
      return JSONResponse(
          status_code=429,
          content={"detail": "Too many requests. Please try again later."}