chenguittiMaroua committed on
Commit
8ea794b
·
verified ·
1 Parent(s): 051e65c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +537 -129
main.py CHANGED
@@ -1,28 +1,45 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 
2
  from fastapi.responses import JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
3
  from slowapi import Limiter
4
  from slowapi.util import get_remote_address
5
  from slowapi.errors import RateLimitExceeded
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from starlette.requests import Request
8
-
9
- import pytesseract
10
- from PIL import Image
11
- import fitz # PyMuPDF
12
- import docx
13
- import pptx
14
- import pandas as pd
15
- import io
16
-
17
- from transformers import pipeline
18
  import matplotlib.pyplot as plt
19
  import seaborn as sns
20
- import uuid
21
- import os
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  app = FastAPI()
24
 
25
- # CORS (optional, for frontend access)
 
 
 
 
26
  app.add_middleware(
27
  CORSMiddleware,
28
  allow_origins=["*"],
@@ -30,124 +47,515 @@ app.add_middleware(
30
  allow_headers=["*"],
31
  )
32
 
33
- # Rate Limiting
34
- limiter = Limiter(key_func=get_remote_address)
35
- app.state.limiter = limiter
 
 
36
 
37
- @app.exception_handler(RateLimitExceeded)
38
- async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
39
- return JSONResponse(
40
- status_code=429,
41
- content={"error": "Rate limit exceeded. Please try again later."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
 
44
- # Hugging Face Pipelines
45
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
46
- qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
47
- image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
48
-
49
- # Utility: Save image and return path
50
- def save_temp_image(upload: UploadFile):
51
- image_path = f"temp/{uuid.uuid4().hex}_{upload.filename}"
52
- with open(image_path, "wb") as f:
53
- f.write(upload.file.read())
54
- return image_path
55
-
56
- # --- File Parsing Utilities ---
57
- def extract_text_from_pdf(file_bytes: bytes) -> str:
58
- doc = fitz.open(stream=file_bytes, filetype="pdf")
59
- return "\n".join(page.get_text() for page in doc)
60
-
61
- def extract_text_from_docx(file_bytes: bytes) -> str:
62
- doc = docx.Document(io.BytesIO(file_bytes))
63
- return "\n".join(p.text for p in doc.paragraphs)
64
-
65
- def extract_text_from_pptx(file_bytes: bytes) -> str:
66
- prs = pptx.Presentation(io.BytesIO(file_bytes))
67
- text = ""
68
- for slide in prs.slides:
69
- for shape in slide.shapes:
70
- if hasattr(shape, "text"):
71
- text += shape.text + "\n"
72
- return text
73
-
74
- def extract_text_from_image(file_bytes: bytes) -> str:
75
- img = Image.open(io.BytesIO(file_bytes))
76
- return pytesseract.image_to_string(img)
77
-
78
- def extract_data_from_excel(file_bytes: bytes) -> pd.DataFrame:
79
- return pd.read_excel(io.BytesIO(file_bytes))
80
-
81
- # --- API Endpoints ---
82
- @app.post("/process/")
83
- @limiter.limit("10/minute")
84
- async def process_file(
85
  request: Request,
86
  file: UploadFile = File(...),
87
- task: str = Form(...),
88
- question: str = Form(None)
89
  ):
90
- content_type = file.content_type
91
- file_bytes = await file.read()
92
-
93
- # --- Task: Summarization or QA ---
94
- if task in ["summarization", "question_answering"]:
95
- if content_type == "application/pdf":
96
- text = extract_text_from_pdf(file_bytes)
97
- elif content_type in ["application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
98
- text = extract_text_from_docx(file_bytes)
99
- elif content_type in ["application/vnd.openxmlformats-officedocument.presentationml.presentation"]:
100
- text = extract_text_from_pptx(file_bytes)
101
- elif content_type in ["image/png", "image/jpeg"]:
102
- text = extract_text_from_image(file_bytes)
103
- else:
104
- raise HTTPException(status_code=400, detail="Unsupported file format for this task.")
105
-
106
- if task == "summarization":
107
- summary = summarizer(text[:3000])[0]["summary_text"] # truncate long text
108
- return {"summary": summary}
109
-
110
- if task == "question_answering":
111
- if not question:
112
- raise HTTPException(status_code=400, detail="Question is required for QA.")
113
- answer = qa_pipeline(question=question, context=text)
114
- return {"answer": answer["answer"]}
115
-
116
- # --- Task: Image Captioning ---
117
- elif task == "captioning":
118
- if content_type not in ["image/png", "image/jpeg"]:
119
- raise HTTPException(status_code=400, detail="Only image files supported for captioning.")
120
- image_path = save_temp_image(file)
121
- caption = image_captioner(image_path)[0]["generated_text"]
122
- os.remove(image_path)
123
- return {"caption": caption}
124
-
125
- # --- Task: Data Visualization ---
126
- elif task == "visualization":
127
- if content_type != "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
128
- raise HTTPException(status_code=400, detail="Only Excel files supported for visualization.")
129
- df = extract_data_from_excel(file_bytes)
130
-
131
- if df.empty:
132
- raise HTTPException(status_code=400, detail="No data found in Excel file.")
133
-
134
- # Example visualization: correlation heatmap
135
- numeric_df = df.select_dtypes(include="number")
136
- if numeric_df.empty:
137
- raise HTTPException(status_code=400, detail="No numeric data available for visualization.")
138
-
139
- plt.figure(figsize=(10, 6))
140
- sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm")
141
- viz_path = f"temp/viz_{uuid.uuid4().hex}.png"
142
- plt.savefig(viz_path)
143
- plt.close()
144
-
145
- with open(viz_path, "rb") as img_file:
146
- img_bytes = img_file.read()
147
- os.remove(viz_path)
148
-
149
- return JSONResponse(content={"image_bytes": list(img_bytes)})
150
 
151
- else:
152
- raise HTTPException(status_code=400, detail="Unsupported task.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
+ from transformers import pipeline
5
+ from typing import Tuple, Optional
6
+ import io
7
+ import fitz # PyMuPDF
8
+ from PIL import Image
9
+ import pandas as pd
10
+ import uvicorn
11
+ from docx import Document
12
+ from pptx import Presentation
13
+ import pytesseract
14
+ import logging
15
+ import re
16
  from slowapi import Limiter
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
+ from slowapi.middleware import SlowAPIMiddleware
 
 
 
 
 
 
 
 
 
 
 
20
  import matplotlib.pyplot as plt
21
  import seaborn as sns
22
+ import tempfile
23
+ import base64
24
+ from io import BytesIO
25
+ from pydantic import BaseModel
26
+ import traceback
27
+ import ast
28
+
29
# Initialize rate limiter
# Request buckets are keyed by whatever slowapi's get_remote_address returns
# for the incoming request.
limiter = Limiter(key_func=get_remote_address)

# Configure logging
# INFO level on the root logger; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Apply rate limiting middleware
# The limiter is attached to app.state and SlowAPIMiddleware is installed;
# per-route limits are declared with @limiter.limit(...) on the endpoints.
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
41
+
42
+ # CORS Configuration
43
  app.add_middleware(
44
  CORSMiddleware,
45
  allow_origins=["*"],
 
47
  allow_headers=["*"],
48
  )
49
 
50
# Constants
# Upper bound (in bytes) enforced by process_uploaded_file on every upload.
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
# Lower-case filename extensions accepted by process_uploaded_file.
# NOTE(review): "xls" is not listed here although several endpoints test for
# {"xlsx", "xls"} - confirm whether legacy .xls uploads are meant to be accepted.
SUPPORTED_FILE_TYPES = {
    "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
}

# Model caching
# Each slot starts as None and is filled on first use by the matching
# get_summarizer / get_qa_model / get_image_captioner accessor.
summarizer = None
qa_model = None
image_captioner = None
60
+
61
def get_summarizer():
    """Return the shared summarization pipeline, building it on first use.

    The pipeline is kept in the module-level ``summarizer`` variable so the
    ``facebook/bart-large-cnn`` model is constructed at most once per process.
    """
    global summarizer
    if summarizer is not None:
        return summarizer
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer
66
+
67
def get_qa_model():
    """Return the shared question-answering pipeline, building it on first use.

    The pipeline is kept in the module-level ``qa_model`` variable so the
    ``deepset/roberta-base-squad2`` model is constructed at most once per process.
    """
    global qa_model
    if qa_model is not None:
        return qa_model
    qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return qa_model
72
+
73
def get_image_captioner():
    """Return the shared image-captioning pipeline, building it on first use.

    The pipeline is kept in the module-level ``image_captioner`` variable so the
    ``Salesforce/blip-image-captioning-large`` model is constructed at most once
    per process.
    """
    global image_captioner
    if image_captioner is not None:
        return image_captioner
    image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    return image_captioner
78
+
79
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an uploaded file and return its extension and raw bytes.

    Args:
        file: The multipart upload received by an endpoint.

    Returns:
        A ``(file_ext, content)`` tuple: the lower-case extension without the
        dot, and the complete file content.

    Raises:
        HTTPException: 400 when the filename is missing or its extension is not
            in ``SUPPORTED_FILE_TYPES``; 413 when the content exceeds
            ``MAX_FILE_SIZE``; 422 when a PDF cannot be opened, is encrypted
            with a non-empty password, or has more than 50 pages.
    """
    if not file.filename:
        raise HTTPException(400, "No filename provided")

    # rpartition instead of split('.')[-1]: a name with no dot at all (e.g. a
    # file literally called "pdf") must not be mistaken for its own extension.
    _, dot, suffix = file.filename.rpartition(".")
    file_ext = suffix.lower() if dot else ""
    if file_ext not in SUPPORTED_FILE_TYPES:
        # Sorted so the message is stable (set iteration order is not).
        supported = ", ".join(sorted(SUPPORTED_FILE_TYPES))
        raise HTTPException(400, f"Unsupported file type. Supported: {supported}")

    # Read one byte past the limit at most: enough to detect an oversized
    # upload without copying an arbitrarily large body into memory.
    content = await file.read(MAX_FILE_SIZE + 1)
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

    # Special validation for PDFs
    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                # An empty password unlocks PDFs that are only owner-protected.
                if doc.is_encrypted and not doc.authenticate(""):
                    raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}") from e

    await file.seek(0)  # Reset file pointer for processing
    return file_ext, content
107
+
108
def extract_text(content: bytes, file_ext: str) -> str:
    """Extract plain text from the bytes of a supported document or image.

    Args:
        content: Raw file content.
        file_ext: Lower-case extension without the dot ("docx", "xlsx", "xls",
            "pptx", "pdf", "jpg", "jpeg" or "png").

    Returns:
        The extracted text. For images this is the OCR output, or a generated
        caption when OCR yields only whitespace.

    Raises:
        HTTPException: 422 when extraction fails for any reason, including an
            extension this function does not handle.
    """
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())

        if file_ext in {"xlsx", "xls"}:
            # sheet_name=None loads every sheet as {sheet name: DataFrame}.
            sheets = pd.read_excel(io.BytesIO(content), sheet_name=None)
            all_text = []
            for sheet_name, sheet_data in sheets.items():
                sheet_text = []
                for column in sheet_data.columns:
                    sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)

        if file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            return "\n".join(
                shape.text
                for slide in ppt.slides
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text.strip()
            )

        if file_ext == "pdf":
            # Context manager so the document is closed once the pages are read,
            # matching how process_uploaded_file opens PDFs.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return "\n".join(page.get_text("text") for page in pdf)

        if file_ext in {"jpg", "jpeg", "png"}:
            try:
                with Image.open(io.BytesIO(content)) as image:
                    # First try OCR
                    text = pytesseract.image_to_string(image, config='--psm 6')
                    if text.strip():
                        return text

                    # OCR found nothing readable: describe the image instead.
                    captioner = get_image_captioner()
                    result = captioner(image)
                    return result[0]['generated_text']
            except Exception as img_e:
                logger.error(f"Image processing failed: {str(img_e)}")
                raise ValueError("Could not extract text or caption from image") from img_e

        # Previously an unknown extension fell through and returned None,
        # which crashed callers on text.strip(); report it explicitly instead.
        raise ValueError(f"Unsupported file type: {file_ext}")

    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
        raise HTTPException(422, f"Failed to extract text from {file_ext} file") from e
157
+
158
+ # Visualization Models
159
class VisualizationRequest(BaseModel):
    """Parameters describing one chart to generate from a DataFrame.

    Consumed by generate_visualization_code, which turns these fields into
    matplotlib/seaborn source code.
    """

    # Chart kind. generate_visualization_code handles "line", "bar", "scatter",
    # "histogram", "boxplot" and "heatmap" and raises ValueError otherwise.
    chart_type: str
    # Column names used for the x axis, y axis and colour grouping.
    x_column: Optional[str] = None
    y_column: Optional[str] = None
    hue_column: Optional[str] = None
    # Optional figure title and axis labels; omitted from the code when unset.
    title: Optional[str] = None
    x_label: Optional[str] = None
    y_label: Optional[str] = None
    # Name passed to plt.style.use in the generated code.
    # NOTE(review): recent matplotlib releases renamed the bundled "seaborn"
    # style - confirm this default is accepted by the installed version.
    style: str = "seaborn"
    # Maps a column name to either a scalar (equality filter) or a dict holding
    # 'min' and 'max' (range filter) or 'values' (membership filter).
    filters: Optional[dict] = None
169
+
170
class NaturalLanguageRequest(BaseModel):
    """A free-text chart request, as received by the /visualize/natural endpoint."""

    # Free-text description of the desired chart; interpreted by
    # interpret_natural_language using keyword and column-name matching.
    prompt: str
    # Plot style name requested by the caller.
    style: str = "seaborn"
173
+
174
def generate_visualization_code(df: pd.DataFrame, request: "VisualizationRequest") -> str:
    """Generate Python source that plots ``df`` as described by ``request``.

    The returned script is self-contained: it rebuilds the DataFrame from a
    literal, applies the optional filters, draws the chart and calls
    ``plt.show()``. Callers in this file also ``exec`` it, so every
    user-supplied value (column names, labels, style, filter values) is
    embedded with ``repr()``. A quote character therefore cannot terminate the
    literal and turn the rest of the value into code.

    Args:
        df: Data to embed in the generated script.
        request: Chart parameters (``chart_type``, ``x_column``, ``y_column``,
            ``hue_column``, ``title``, ``x_label``, ``y_label``, ``style``,
            ``filters``).

    Returns:
        The generated source code as a single newline-joined string.

    Raises:
        ValueError: If the chart type is unsupported, a column the chart needs
            is not set, or a referenced column does not exist in ``df``.
    """
    chart_type = request.chart_type
    x_col, y_col, hue_col = request.x_column, request.y_column, request.hue_column

    # Request fields each chart type cannot be drawn without.
    required_fields = {
        "line": ("x_column", "y_column"),
        "bar": ("x_column", "y_column"),
        "scatter": ("x_column", "y_column"),
        "histogram": ("x_column",),
        "boxplot": ("x_column", "y_column"),
        "heatmap": (),
    }
    if chart_type not in required_fields:
        raise ValueError(f"Unsupported chart type: {chart_type}")

    # Fail early with a clear message instead of generating code that indexes
    # a column literally named 'None'.
    used_columns = []
    for field in required_fields[chart_type]:
        column = getattr(request, field)
        if column is None:
            raise ValueError(f"{field} is required for a {chart_type} chart")
        used_columns.append(column)
    if hue_col and chart_type in {"line", "bar", "scatter", "boxplot"}:
        used_columns.append(hue_col)
    for column in used_columns:
        if column not in df.columns:
            raise ValueError(f"Column not found in data: {column!r}")

    code_lines = [
        "import matplotlib.pyplot as plt",
        "import seaborn as sns",
        "import pandas as pd",
        "from pandas import NaT, Timestamp",
        "",
        "# Data preparation",
        # The literal below prints missing, infinite and datetime cells as bare
        # names, so those names must exist when the script runs.
        "nan, inf = float('nan'), float('inf')",
        f"df = pd.DataFrame({df.to_dict(orient='list')})",
    ]

    # Apply filters if specified
    filter_conditions = []
    for column, condition in (request.filters or {}).items():
        col_ref = f"df[{column!r}]"
        if isinstance(condition, dict):
            if 'min' in condition and 'max' in condition:
                filter_conditions.append(
                    f"(({col_ref} >= {condition['min']!r}) & ({col_ref} <= {condition['max']!r}))"
                )
            elif 'values' in condition:
                filter_conditions.append(f"({col_ref}.isin({list(condition['values'])!r}))")
        else:
            # Parenthesised because & binds tighter than ==.
            filter_conditions.append(f"({col_ref} == {condition!r})")

    if filter_conditions:
        code_lines.extend([
            "",
            "# Apply filters",
            f"df = df[{' & '.join(filter_conditions)}]",
        ])

    code_lines.extend([
        "",
        "# Visualization",
        f"plt.style.use({request.style!r})",
        "plt.figure(figsize=(10, 6))",
    ])

    # Chart type specific code
    x, y, hue = repr(x_col), repr(y_col), repr(hue_col)
    seaborn_plots = {"line": "lineplot", "bar": "barplot", "scatter": "scatterplot"}
    pyplot_plots = {"line": "plot", "bar": "bar", "scatter": "scatter"}
    if chart_type in seaborn_plots:
        # seaborn is only needed when grouping by hue; plain pyplot otherwise.
        if hue_col:
            code_lines.append(f"sns.{seaborn_plots[chart_type]}(data=df, x={x}, y={y}, hue={hue})")
        else:
            code_lines.append(f"plt.{pyplot_plots[chart_type]}(df[{x}], df[{y}])")
    elif chart_type == "histogram":
        code_lines.append(f"plt.hist(df[{x}], bins=20)")
    elif chart_type == "boxplot":
        hue_arg = f", hue={hue}" if hue_col else ""
        code_lines.append(f"sns.boxplot(data=df, x={x}, y={y}{hue_arg})")
    else:  # heatmap
        # Restrict to numeric columns: correlating text columns raises.
        code_lines.append("corr = df.select_dtypes(include='number').corr()")
        code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")

    # Add labels and title
    if request.title:
        code_lines.append(f"plt.title({request.title!r})")
    if request.x_label:
        code_lines.append(f"plt.xlabel({request.x_label!r})")
    if request.y_label:
        code_lines.append(f"plt.ylabel({request.y_label!r})")

    code_lines.extend([
        "plt.tight_layout()",
        "plt.show()",
    ])

    return "\n".join(code_lines)
255
+
256
def interpret_natural_language(prompt: str, df_columns: list, style: str = "seaborn") -> VisualizationRequest:
    """Convert a free-text prompt into visualization parameters.

    Interpretation is purely keyword based: the chart type is chosen from the
    first matching keyword, and columns are picked by checking whether each
    column name occurs in the prompt (case-insensitively).

    Args:
        prompt: The user's description of the desired chart.
        df_columns: Column labels of the data, in DataFrame order.
        style: Plot style to put on the returned request.

    Returns:
        A ``VisualizationRequest`` with the inferred chart type and columns.
        The first mentioned column becomes x, the second y, and the last
        further mention becomes hue. Unmentioned x/y fall back to the leading
        columns of the data.
    """
    # Match on a lower-cased copy; the original prompt is kept for the title.
    lowered = prompt.lower()

    # Determine chart type: first matching keyword wins, "bar" is the default.
    keyword_table = (
        (("line",), "line"),
        (("scatter",), "scatter"),
        (("histogram",), "histogram"),
        (("box",), "boxplot"),
        (("heatmap", "correlation"), "heatmap"),
    )
    chart_type = "bar"
    for keywords, candidate in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            chart_type = candidate
            break

    # Try to detect columns
    x_col = None
    y_col = None
    hue_col = None
    for col in df_columns:
        # str(): Excel headers may be numbers or dates, which have no .lower().
        if str(col).lower() not in lowered:
            continue
        # "is None" rather than truthiness so a falsy label (0, "") still counts.
        if x_col is None:
            x_col = col
        elif y_col is None:
            y_col = col
        else:
            hue_col = col

    # Default to the leading columns if not detected
    if x_col is None and len(df_columns) > 0:
        x_col = df_columns[0]
    if y_col is None and len(df_columns) > 1:
        # Prefer a column other than x so the chart is not a column against itself.
        y_col = next((col for col in df_columns if col != x_col), None)

    return VisualizationRequest(
        chart_type=chart_type,
        x_column=x_col,
        y_column=y_col,
        hue_column=hue_col,
        title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
        style=style
    )
302
 
303
@app.post("/summarize")
@limiter.limit("5/minute")
async def summarize_document(request: Request, file: UploadFile = File(...)):
    """Summarize an uploaded document.

    The extracted text is whitespace-normalized, cut into 1000-character
    pieces, each piece is summarized separately, and the partial summaries are
    joined with spaces.

    Returns:
        ``{"summary": <text>}``.

    Raises:
        HTTPException: Validation/extraction errors are passed through; 400 when
            no text could be extracted; 500 for any other failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        raw_text = extract_text(content, file_ext)

        if not raw_text.strip():
            raise HTTPException(400, "No extractable text found")

        # Collapse all whitespace runs, then summarize fixed-size slices.
        normalized = re.sub(r'\s+', ' ', raw_text).strip()
        model = get_summarizer()
        partial_summaries = [
            model(
                normalized[start:start + 1000],
                max_length=150,
                min_length=50,
                do_sample=False,
            )[0]["summary_text"]
            for start in range(0, len(normalized), 1000)
        ]

        return {"summary": " ".join(partial_summaries)}

    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        raise HTTPException(500, "Document summarization failed")
331
+
332
@app.post("/qa")
@limiter.limit("5/minute")
async def question_answering(
    request: Request,
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form("fr")
):
    """Answer a question about an uploaded document.

    Questions asking for the document's theme/main topic are answered with a
    short summary of the text; all other questions go to the extractive QA
    model using the first 3000 characters as context.

    Args:
        request: Incoming request (required by the rate limiter).
        file: The document to analyse.
        question: The user's question.
        language: Echoed back in the response; it does not currently change
            the language of the generated answer.

    Returns:
        A dict with ``question``, ``answer``, ``confidence`` and ``language``;
        the theme fallback additionally carries a ``warning`` key.

    Raises:
        HTTPException: Validation/extraction errors are passed through; 400 when
            no text could be extracted; 500 for any other failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)

        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and truncate text
        text = re.sub(r'\s+', ' ', text).strip()[:5000]

        # Theme detection
        theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
        lowered_question = question.lower()
        if any(kw in lowered_question for kw in theme_keywords):
            try:
                min_len = 30
                # Never let max_length drop below min_length: for texts under
                # 120 characters len(text)//4 is smaller than 30.
                max_len = max(min_len, min(100, len(text) // 4))
                summary_output = get_summarizer()(
                    text,
                    max_length=max_len,
                    min_length=min_len,
                    do_sample=False,
                    truncation=True
                )

                theme = summary_output[0].get("summary_text", text[:200] + "...")
                return {
                    "question": question,
                    "answer": f"Le document traite principalement de : {theme}",
                    "confidence": 0.95,
                    "language": language
                }
            except Exception as theme_err:
                # Best effort: fall back to an excerpt, but leave a trace of why.
                logger.warning(f"Theme summarization failed, using excerpt fallback: {str(theme_err)}")
                theme = text[:200] + ("..." if len(text) > 200 else "")
                return {
                    "question": question,
                    "answer": f"D'après le document : {theme}",
                    "confidence": 0.7,
                    "language": language,
                    "warning": "theme_summary_fallback"
                }

        # Standard QA
        qa = get_qa_model()
        result = qa(question=question, context=text[:3000])

        return {
            "question": question,
            "answer": result["answer"],
            "confidence": result["score"],
            "language": language
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
396
+
397
def _render_plot_data_uri(visualization_code: str) -> str:
    """Run generated plotting code and return the current figure as a PNG data URI."""
    buffer = io.BytesIO()
    try:
        # SECURITY: exec() runs code assembled from request parameters. It is
        # given a fresh namespace rather than this module's globals so the
        # script cannot read or rebind server state by name, but exec is not a
        # sandbox - the generator must keep quoting every user-supplied value.
        exec(visualization_code, {})
        # Render straight into memory; no temporary file is left on disk.
        plt.savefig(buffer, format='png', dpi=300)
    finally:
        # Close every figure the script opened, even when it failed midway.
        plt.close('all')
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


def _preview_records(df: pd.DataFrame) -> list:
    """Return the first rows as dicts, with missing cells as None so they encode as JSON null."""
    head = df.head()
    return head.astype(object).where(head.notna(), None).to_dict(orient='records')


@app.post("/visualize/code")
@limiter.limit("5/minute")
async def visualize_with_code(
    request: Request,
    file: UploadFile = File(...),
    chart_type: str = Form(...),
    x_column: Optional[str] = Form(None),
    y_column: Optional[str] = Form(None),
    hue_column: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    x_label: Optional[str] = Form(None),
    y_label: Optional[str] = Form(None),
    style: str = Form("seaborn"),
    filters: Optional[str] = Form(None)
):
    """Render a chart from an Excel upload using explicit chart parameters.

    Args:
        request: Incoming request (required by the rate limiter).
        file: Excel workbook; the first sheet is used.
        chart_type, x_column, y_column, hue_column, title, x_label, y_label,
            style: Forwarded to ``VisualizationRequest``.
        filters: Optional Python-literal dict (parsed with
            ``ast.literal_eval``); anything unparsable or not a dict is ignored.

    Returns:
        A dict with ``status``, the PNG as a base64 data URI in ``image``, the
        generated ``code`` and a ``data_preview`` of the first rows.

    Raises:
        HTTPException: Validation errors are passed through; 400 for non-Excel
            uploads; 500 for any other failure.
    """
    try:
        # Validate file
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported for visualization")

        # Read Excel file
        df = pd.read_excel(io.BytesIO(content))

        # Parse filters if provided; invalid input is deliberately non-fatal.
        filter_dict = {}
        if filters:
            try:
                parsed_filters = ast.literal_eval(filters)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                logger.warning("Ignoring filters value that is not a valid Python literal")
            else:
                if isinstance(parsed_filters, dict):
                    filter_dict = parsed_filters

        # Create visualization request
        vis_request = VisualizationRequest(
            chart_type=chart_type,
            x_column=x_column,
            y_column=y_column,
            hue_column=hue_column,
            title=title,
            x_label=x_label,
            y_label=y_label,
            style=style,
            filters=filter_dict
        )

        # Generate the plotting script, then run it to produce the image.
        visualization_code = generate_visualization_code(df, vis_request)
        image_uri = _render_plot_data_uri(visualization_code)

        return {
            "status": "success",
            "image": image_uri,
            "code": visualization_code,
            "data_preview": _preview_records(df)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(500, detail=f"Visualization failed: {str(e)}")


@app.post("/visualize/natural")
@limiter.limit("5/minute")
async def visualize_with_natural_language(
    request: Request,
    file: UploadFile = File(...),
    prompt: str = Form(...),
    style: str = Form("seaborn")
):
    """Render a chart from an Excel upload described by a free-text prompt.

    Args:
        request: Incoming request (required by the rate limiter).
        file: Excel workbook; the first sheet is used.
        prompt: Free-text description, interpreted by
            ``interpret_natural_language``.
        style: Plot style applied to the interpreted request.

    Returns:
        A dict with ``status``, the PNG as a base64 data URI in ``image``, the
        generated ``code``, the ``interpreted_parameters`` and a
        ``data_preview`` of the first rows.

    Raises:
        HTTPException: Validation errors are passed through; 400 for non-Excel
            uploads; 500 for any other failure.
    """
    try:
        # Validate file
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported for visualization")

        # Read Excel file
        df = pd.read_excel(io.BytesIO(content))

        # Convert natural language to visualization parameters
        nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
        vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
        # The interpreter fills in a fixed style; honour the one the caller sent.
        vis_request.style = nl_request.style

        # Generate the plotting script, then run it to produce the image.
        visualization_code = generate_visualization_code(df, vis_request)
        image_uri = _render_plot_data_uri(visualization_code)

        return {
            "status": "success",
            "image": image_uri,
            "code": visualization_code,
            "interpreted_parameters": vis_request.dict(),
            "data_preview": _preview_records(df)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
531
+
532
@app.post("/get_columns")
@limiter.limit("10/minute")
async def get_excel_columns(
    request: Request,
    file: UploadFile = File(...)
):
    """Describe the columns of an uploaded Excel workbook (first sheet).

    Returns:
        A dict with ``columns`` (the header labels), ``sample_data`` (the first
        rows as records) and ``statistics`` (``DataFrame.describe()`` output, or
        ``None`` when the sheet has no numeric column).

    Raises:
        HTTPException: Validation errors are passed through; 400 for non-Excel
            uploads; 500 when the workbook cannot be read.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported")

        df = pd.read_excel(io.BytesIO(content))

        # NaN is not valid JSON, so missing cells are converted to None before
        # the frames are turned into plain dicts.
        sample = df.head()
        sample_data = sample.astype(object).where(sample.notna(), None).to_dict(orient='records')

        statistics = None
        if len(df.select_dtypes(include=['number']).columns) > 0:
            described = df.describe()
            statistics = described.astype(object).where(described.notna(), None).to_dict()

        return {
            "columns": list(df.columns),
            "sample_data": sample_data,
            "statistics": statistics
        }
    except HTTPException:
        # Keep the 400/413/422 raised above instead of masking it as a 500,
        # as the other endpoints do.
        raise
    except Exception as e:
        logger.error(f"Column extraction failed: {str(e)}")
        raise HTTPException(500, detail="Failed to extract columns from Excel file")
552
+
553
@app.exception_handler(RateLimitExceeded)
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Translate a slowapi ``RateLimitExceeded`` error into an HTTP 429 JSON reply."""
    payload = {"detail": "Too many requests. Please try again later."}
    return JSONResponse(content=payload, status_code=429)
559
 
560
# Script entry point: when the module is executed directly, serve the app with
# uvicorn on all network interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)