chenguittiMaroua commited on
Commit
118cebd
·
verified ·
1 Parent(s): 74fd655

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +462 -151
main.py CHANGED
@@ -1,31 +1,45 @@
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
- from pydantic import BaseModel
5
- from typing import Optional, Dict, List
6
- import pandas as pd
7
- import matplotlib.pyplot as plt
8
- import seaborn as sns
9
- import base64
10
- import tempfile
11
  import io
12
- import logging
13
- import traceback
 
14
  import uvicorn
 
 
 
 
 
15
  from slowapi import Limiter
16
  from slowapi.util import get_remote_address
17
  from slowapi.errors import RateLimitExceeded
18
  from slowapi.middleware import SlowAPIMiddleware
 
 
 
 
 
 
 
 
19
 
20
- # Initialize FastAPI app
21
- app = FastAPI(title="Data Visualization API", version="1.0")
22
-
23
- # Rate limiting setup
24
  limiter = Limiter(key_func=get_remote_address)
 
 
 
 
 
 
 
 
25
  app.state.limiter = limiter
26
  app.add_middleware(SlowAPIMiddleware)
27
 
28
- # CORS configuration
29
  app.add_middleware(
30
  CORSMiddleware,
31
  allow_origins=["*"],
@@ -35,152 +49,386 @@ app.add_middleware(
35
 
36
  # Constants
37
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
38
- SUPPORTED_EXCEL_TYPES = {"xlsx": "Excel Workbook", "xls": "Excel 97-2003 Workbook"}
 
 
39
 
40
- # Configure logging
41
- logging.basicConfig(level=logging.INFO)
42
- logger = logging.getLogger(__name__)
 
43
 
44
- class VisualizationRequest(BaseModel):
45
- chart_type: str
46
- x_column: str
47
- y_column: Optional[str] = None
48
- hue_column: Optional[str] = None
49
- title: Optional[str] = None
50
- x_label: Optional[str] = None
51
- y_label: Optional[str] = None
52
- style: str = "seaborn"
53
- filters: Optional[Dict] = None
 
54
 
55
- async def validate_excel_file(file: UploadFile) -> bytes:
56
- """Validate and process uploaded Excel file with clear error messages"""
 
 
 
 
 
 
57
  if not file.filename:
58
  raise HTTPException(400, "No filename provided")
59
 
60
  file_ext = file.filename.split('.')[-1].lower()
61
- if file_ext not in SUPPORTED_EXCEL_TYPES:
62
- supported = ", ".join([f"{ext} ({desc})" for ext, desc in SUPPORTED_EXCEL_TYPES.items()])
63
- raise HTTPException(
64
- 400,
65
- f"Unsupported file type '{file.filename}'. Please upload: {supported}"
66
- )
67
 
68
  content = await file.read()
69
  if len(content) > MAX_FILE_SIZE:
70
- raise HTTPException(413, f"File too large ({len(content)/1024/1024:.1f}MB). Max size: {MAX_FILE_SIZE//1024//1024}MB")
71
 
72
- await file.seek(0)
73
- return content
74
-
75
- def read_excel_with_fallback(content: bytes) -> pd.DataFrame:
76
- """Read Excel file with engine fallback and better error handling"""
77
- try:
78
- # Try openpyxl first (for .xlsx)
79
- return pd.read_excel(io.BytesIO(content), engine='openpyxl')
80
- except Exception as e:
81
- logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
82
  try:
83
- # Fallback to xlrd (for .xls)
84
- return pd.read_excel(io.BytesIO(content), engine='xlrd')
 
 
 
 
85
  except Exception as e:
86
- raise ValueError(f"Failed to read Excel file with either engine. Error: {str(e)}")
 
 
 
 
87
 
88
- def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
89
- """Generate and save visualization with proper resource cleanup"""
90
  try:
91
- plt.style.use(request.style)
92
- fig, ax = plt.subplots(figsize=(10, 6))
93
-
94
- # Apply filters if specified
95
- if request.filters:
96
- for col, condition in request.filters.items():
97
- if isinstance(condition, dict):
98
- if 'min' in condition and 'max' in condition:
99
- df = df[(df[col] >= condition['min']) & (df[col] <= condition['max'])]
100
- elif 'values' in condition:
101
- df = df[df[col].isin(condition['values'])]
102
- else:
103
- df = df[df[col] == condition]
104
-
105
- # Generate chart based on type
106
- if request.chart_type == "line":
107
- if request.hue_column:
108
- sns.lineplot(data=df, x=request.x_column, y=request.y_column,
109
- hue=request.hue_column, ax=ax)
110
- else:
111
- ax.plot(df[request.x_column], df[request.y_column])
112
- elif request.chart_type == "bar":
113
- if request.hue_column:
114
- sns.barplot(data=df, x=request.x_column, y=request.y_column,
115
- hue=request.hue_column, ax=ax)
116
- else:
117
- ax.bar(df[request.x_column], df[request.y_column])
118
- elif request.chart_type == "scatter":
119
- if request.hue_column:
120
- sns.scatterplot(data=df, x=request.x_column, y=request.y_column,
121
- hue=request.hue_column, ax=ax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  else:
123
- ax.scatter(df[request.x_column], df[request.y_column])
124
- elif request.chart_type == "histogram":
125
- ax.hist(df[request.x_column], bins=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  else:
127
- raise ValueError(f"Unsupported chart type: {request.chart_type}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- # Add labels and title
130
- if request.title:
131
- ax.set_title(request.title)
132
- if request.x_label:
133
- ax.set_xlabel(request.x_label)
134
- if request.y_label:
135
- ax.set_ylabel(request.y_label)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- plt.tight_layout()
 
 
 
 
 
138
 
139
- # Save to temporary file
140
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
141
- fig.savefig(tmpfile.name, format='png', dpi=300)
142
- plt.close(fig)
143
- with open(tmpfile.name, "rb") as f:
144
- return base64.b64encode(f.read()).decode('utf-8')
145
-
 
 
 
 
 
 
 
 
 
 
 
146
  except Exception as e:
147
- plt.close()
148
- raise ValueError(f"Chart generation failed: {str(e)}")
149
 
150
- @app.post("/visualize")
151
  @limiter.limit("5/minute")
152
- async def create_visualization(
153
  request: Request,
154
- file: UploadFile = File(..., description="Excel file to visualize"),
155
- chart_type: str = Form(..., description="Type of chart (line, bar, scatter, histogram)"),
156
- x_column: str = Form(..., description="Column for x-axis"),
157
- y_column: Optional[str] = Form(None, description="Column for y-axis"),
158
- hue_column: Optional[str] = Form(None, description="Column for color grouping"),
159
- title: Optional[str] = Form(None, description="Chart title"),
160
- x_label: Optional[str] = Form(None, description="X-axis label"),
161
- y_label: Optional[str] = Form(None, description="Y-axis label"),
162
- style: str = Form("seaborn", description="Plot style (seaborn, ggplot, etc.)"),
163
- filters: Optional[str] = Form(None, description="JSON string of filters to apply")
164
  ):
165
  try:
166
- # Validate and read file
167
- content = await validate_excel_file(file)
168
- df = read_excel_with_fallback(content)
169
 
170
- if df.empty:
171
- raise ValueError("Excel file contains no data")
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # Parse filters if provided
174
  filter_dict = {}
175
  if filters:
176
  try:
177
- filter_dict = eval(filters) if filters else {}
178
  if not isinstance(filter_dict, dict):
179
  filter_dict = {}
180
  except:
181
  filter_dict = {}
182
-
183
- # Create visualization
184
  vis_request = VisualizationRequest(
185
  chart_type=chart_type,
186
  x_column=x_column,
@@ -193,54 +441,117 @@ async def create_visualization(
193
  filters=filter_dict
194
  )
195
 
196
- image_base64 = generate_visualization(df, vis_request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  return {
199
  "status": "success",
200
  "image": f"data:image/png;base64,{image_base64}",
201
- "columns": list(df.columns),
202
- "filtered_data": df.to_dict(orient='records')
203
  }
 
 
 
 
 
 
204
 
205
- except HTTPException as he:
206
- raise he
207
- except ValueError as ve:
208
- logger.error(f"Validation error: {str(ve)}")
209
- raise HTTPException(422, detail=str(ve))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  except Exception as e:
211
- logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
212
- raise HTTPException(500, detail="Failed to generate visualization")
213
 
214
  @app.post("/get_columns")
215
  @limiter.limit("10/minute")
216
- async def get_columns(
217
  request: Request,
218
- file: UploadFile = File(..., description="Excel file to analyze")
219
  ):
220
  try:
221
- content = await validate_excel_file(file)
222
- df = read_excel_with_fallback(content)
 
223
 
224
- if df.empty:
225
- raise ValueError("Excel file contains no data")
226
-
227
  return {
228
  "columns": list(df.columns),
229
- "sample_data": df.head().replace({float('nan'): None}).to_dict(orient='records'),
230
  "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
231
  }
232
-
233
- except HTTPException as he:
234
- raise he
235
- except ValueError as ve:
236
- logger.error(f"Validation error: {str(ve)}")
237
- raise HTTPException(422, detail=str(ve))
238
  except Exception as e:
239
- logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
240
- raise HTTPException(500, detail="Failed to process Excel file")
241
 
242
  @app.exception_handler(RateLimitExceeded)
243
- async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
244
  return JSONResponse(
245
  status_code=429,
246
  content={"detail": "Too many requests. Please try again later."}
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
+ from transformers import pipeline
5
+ from typing import Tuple, Optional
 
 
 
 
 
6
  import io
7
+ import fitz # PyMuPDF
8
+ from PIL import Image
9
+ import pandas as pd
10
  import uvicorn
11
+ from docx import Document
12
+ from pptx import Presentation
13
+ import pytesseract
14
+ import logging
15
+ import re
16
  from slowapi import Limiter
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
  from slowapi.middleware import SlowAPIMiddleware
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ import tempfile
23
+ import base64
24
+ from io import BytesIO
25
+ from pydantic import BaseModel
26
+ import traceback
27
+ import ast
28
 
29
+ # Initialize rate limiter
 
 
 
30
  limiter = Limiter(key_func=get_remote_address)
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+ app = FastAPI()
37
+
38
+ # Apply rate limiting middleware
39
  app.state.limiter = limiter
40
  app.add_middleware(SlowAPIMiddleware)
41
 
42
+ # CORS Configuration
43
  app.add_middleware(
44
  CORSMiddleware,
45
  allow_origins=["*"],
 
49
 
50
  # Constants
51
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
52
+ SUPPORTED_FILE_TYPES = {
53
+ "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
54
+ }
55
 
56
+ # Model caching
57
+ summarizer = None
58
+ qa_model = None
59
+ image_captioner = None
60
 
61
+ def get_summarizer():
62
+ global summarizer
63
+ if summarizer is None:
64
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
65
+ return summarizer
66
+
67
+ def get_qa_model():
68
+ global qa_model
69
+ if qa_model is None:
70
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
71
+ return qa_model
72
 
73
+ def get_image_captioner():
74
+ global image_captioner
75
+ if image_captioner is None:
76
+ image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
77
+ return image_captioner
78
+
79
+ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
80
+ """Validate and process uploaded file with special handling for each type"""
81
  if not file.filename:
82
  raise HTTPException(400, "No filename provided")
83
 
84
  file_ext = file.filename.split('.')[-1].lower()
85
+ if file_ext not in SUPPORTED_FILE_TYPES:
86
+ raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
 
 
 
 
87
 
88
  content = await file.read()
89
  if len(content) > MAX_FILE_SIZE:
90
+ raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
91
 
92
+ # Special validation for PDFs
93
+ if file_ext == "pdf":
 
 
 
 
 
 
 
 
94
  try:
95
+ with fitz.open(stream=content, filetype="pdf") as doc:
96
+ if doc.is_encrypted:
97
+ if not doc.authenticate(""):
98
+ raise ValueError("Encrypted PDF - cannot extract text")
99
+ if len(doc) > 50:
100
+ raise ValueError("PDF too large (max 50 pages)")
101
  except Exception as e:
102
+ logger.error(f"PDF validation failed: {str(e)}")
103
+ raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
104
+
105
+ await file.seek(0) # Reset file pointer for processing
106
+ return file_ext, content
107
 
108
+ def extract_text(content: bytes, file_ext: str) -> str:
109
+ """Extract text from various file formats with enhanced support"""
110
  try:
111
+ if file_ext == "docx":
112
+ doc = Document(io.BytesIO(content))
113
+ return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
114
+
115
+ elif file_ext in {"xlsx", "xls"}:
116
+ df = pd.read_excel(io.BytesIO(content), sheet_name=None)
117
+ all_text = []
118
+ for sheet_name, sheet_data in df.items():
119
+ sheet_text = []
120
+ for column in sheet_data.columns:
121
+ sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
122
+ all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
123
+ return "\n\n".join(all_text)
124
+
125
+ elif file_ext == "pptx":
126
+ ppt = Presentation(io.BytesIO(content))
127
+ text = []
128
+ for slide in ppt.slides:
129
+ for shape in slide.shapes:
130
+ if hasattr(shape, "text") and shape.text.strip():
131
+ text.append(shape.text)
132
+ return "\n".join(text)
133
+
134
+ elif file_ext == "pdf":
135
+ pdf = fitz.open(stream=content, filetype="pdf")
136
+ return "\n".join(page.get_text("text") for page in pdf)
137
+
138
+ elif file_ext in {"jpg", "jpeg", "png"}:
139
+ # First try OCR
140
+ try:
141
+ image = Image.open(io.BytesIO(content))
142
+ text = pytesseract.image_to_string(image, config='--psm 6')
143
+ if text.strip():
144
+ return text
145
+
146
+ # If OCR fails, try image captioning
147
+ captioner = get_image_captioner()
148
+ result = captioner(image)
149
+ return result[0]['generated_text']
150
+ except Exception as img_e:
151
+ logger.error(f"Image processing failed: {str(img_e)}")
152
+ raise ValueError("Could not extract text or caption from image")
153
+
154
+ except Exception as e:
155
+ logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
156
+ raise HTTPException(422, f"Failed to extract text from {file_ext} file")
157
+
158
+ # Visualization Models
159
+ class VisualizationRequest(BaseModel):
160
+ chart_type: str
161
+ x_column: Optional[str] = None
162
+ y_column: Optional[str] = None
163
+ hue_column: Optional[str] = None
164
+ title: Optional[str] = None
165
+ x_label: Optional[str] = None
166
+ y_label: Optional[str] = None
167
+ style: str = "seaborn"
168
+ filters: Optional[dict] = None
169
+
170
+ class NaturalLanguageRequest(BaseModel):
171
+ prompt: str
172
+ style: str = "seaborn"
173
+
174
+ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
175
+ """Generate Python code for visualization based on request parameters"""
176
+ code_lines = [
177
+ "import matplotlib.pyplot as plt",
178
+ "import seaborn as sns",
179
+ "import pandas as pd",
180
+ "",
181
+ "# Data preparation",
182
+ f"df = pd.DataFrame({df.to_dict(orient='list')})",
183
+ ]
184
+
185
+ # Apply filters if specified
186
+ if request.filters:
187
+ filter_conditions = []
188
+ for column, condition in request.filters.items():
189
+ if isinstance(condition, dict):
190
+ if 'min' in condition and 'max' in condition:
191
+ filter_conditions.append(f"(df['{column}'] >= {condition['min']}) & (df['{column}'] <= {condition['max']})")
192
+ elif 'values' in condition:
193
+ values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
194
+ filter_conditions.append(f"df['{column}'].isin([{values}])")
195
  else:
196
+ filter_conditions.append(f"df['{column}'] == {repr(condition)}")
197
+
198
+ if filter_conditions:
199
+ code_lines.extend([
200
+ "",
201
+ "# Apply filters",
202
+ f"df = df[{' & '.join(filter_conditions)}]"
203
+ ])
204
+
205
+ code_lines.extend([
206
+ "",
207
+ "# Visualization",
208
+ f"plt.style.use('{request.style}')",
209
+ f"plt.figure(figsize=(10, 6))"
210
+ ])
211
+
212
+ # Chart type specific code
213
+ if request.chart_type == "line":
214
+ if request.hue_column:
215
+ code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
216
+ else:
217
+ code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
218
+ elif request.chart_type == "bar":
219
+ if request.hue_column:
220
+ code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
221
+ else:
222
+ code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
223
+ elif request.chart_type == "scatter":
224
+ if request.hue_column:
225
+ code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
226
  else:
227
+ code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
228
+ elif request.chart_type == "histogram":
229
+ code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
230
+ elif request.chart_type == "boxplot":
231
+ if request.hue_column:
232
+ code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
233
+ else:
234
+ code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}')")
235
+ elif request.chart_type == "heatmap":
236
+ code_lines.append(f"corr = df.corr()")
237
+ code_lines.append(f"sns.heatmap(corr, annot=True, cmap='coolwarm')")
238
+ else:
239
+ raise ValueError(f"Unsupported chart type: {request.chart_type}")
240
+
241
+ # Add labels and title
242
+ if request.title:
243
+ code_lines.append(f"plt.title('{request.title}')")
244
+ if request.x_label:
245
+ code_lines.append(f"plt.xlabel('{request.x_label}')")
246
+ if request.y_label:
247
+ code_lines.append(f"plt.ylabel('{request.y_label}')")
248
+
249
+ code_lines.extend([
250
+ "plt.tight_layout()",
251
+ "plt.show()"
252
+ ])
253
+
254
+ return "\n".join(code_lines)
255
 
256
+ def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
257
+ """Convert natural language prompt to visualization parameters"""
258
+ # Simple keyword-based interpretation (could be enhanced with NLP)
259
+ prompt = prompt.lower()
260
+
261
+ # Determine chart type
262
+ chart_type = "bar"
263
+ if "line" in prompt:
264
+ chart_type = "line"
265
+ elif "scatter" in prompt:
266
+ chart_type = "scatter"
267
+ elif "histogram" in prompt:
268
+ chart_type = "histogram"
269
+ elif "box" in prompt:
270
+ chart_type = "boxplot"
271
+ elif "heatmap" in prompt or "correlation" in prompt:
272
+ chart_type = "heatmap"
273
+
274
+ # Try to detect columns
275
+ x_col = None
276
+ y_col = None
277
+ hue_col = None
278
+
279
+ for col in df_columns:
280
+ if col.lower() in prompt:
281
+ if not x_col:
282
+ x_col = col
283
+ elif not y_col:
284
+ y_col = col
285
+ else:
286
+ hue_col = col
287
+
288
+ # Default to first columns if not detected
289
+ if not x_col and len(df_columns) > 0:
290
+ x_col = df_columns[0]
291
+ if not y_col and len(df_columns) > 1:
292
+ y_col = df_columns[1]
293
+
294
+ return VisualizationRequest(
295
+ chart_type=chart_type,
296
+ x_column=x_col,
297
+ y_column=y_col,
298
+ hue_column=hue_col,
299
+ title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
300
+ style="seaborn"
301
+ )
302
 
303
+ @app.post("/summarize")
304
+ @limiter.limit("5/minute")
305
+ async def summarize_document(request: Request, file: UploadFile = File(...)):
306
+ try:
307
+ file_ext, content = await process_uploaded_file(file)
308
+ text = extract_text(content, file_ext)
309
 
310
+ if not text.strip():
311
+ raise HTTPException(400, "No extractable text found")
312
+
313
+ # Clean and chunk text
314
+ text = re.sub(r'\s+', ' ', text).strip()
315
+ chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
316
+
317
+ # Summarize each chunk
318
+ summarizer = get_summarizer()
319
+ summaries = []
320
+ for chunk in chunks:
321
+ summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
322
+ summaries.append(summary)
323
+
324
+ return {"summary": " ".join(summaries)}
325
+
326
+ except HTTPException:
327
+ raise
328
  except Exception as e:
329
+ logger.error(f"Summarization failed: {str(e)}")
330
+ raise HTTPException(500, "Document summarization failed")
331
 
332
+ @app.post("/qa")
333
  @limiter.limit("5/minute")
334
+ async def question_answering(
335
  request: Request,
336
+ file: UploadFile = File(...),
337
+ question: str = Form(...),
338
+ language: str = Form("fr")
 
 
 
 
 
 
 
339
  ):
340
  try:
341
+ file_ext, content = await process_uploaded_file(file)
342
+ text = extract_text(content, file_ext)
 
343
 
344
+ if not text.strip():
345
+ raise HTTPException(400, "No extractable text found")
346
 
347
+ # Clean and truncate text
348
+ text = re.sub(r'\s+', ' ', text).strip()[:5000]
349
+
350
+ # Theme detection
351
+ theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
352
+ if any(kw in question.lower() for kw in theme_keywords):
353
+ try:
354
+ summarizer = get_summarizer()
355
+ summary_output = summarizer(
356
+ text,
357
+ max_length=min(100, len(text)//4),
358
+ min_length=30,
359
+ do_sample=False,
360
+ truncation=True
361
+ )
362
+
363
+ theme = summary_output[0].get("summary_text", text[:200] + "...")
364
+ return {
365
+ "question": question,
366
+ "answer": f"Le document traite principalement de : {theme}",
367
+ "confidence": 0.95,
368
+ "language": language
369
+ }
370
+ except Exception:
371
+ theme = text[:200] + ("..." if len(text) > 200 else "")
372
+ return {
373
+ "question": question,
374
+ "answer": f"D'après le document : {theme}",
375
+ "confidence": 0.7,
376
+ "language": language,
377
+ "warning": "theme_summary_fallback"
378
+ }
379
+
380
+ # Standard QA
381
+ qa = get_qa_model()
382
+ result = qa(question=question, context=text[:3000])
383
+
384
+ return {
385
+ "question": question,
386
+ "answer": result["answer"],
387
+ "confidence": result["score"],
388
+ "language": language
389
+ }
390
+
391
+ except HTTPException:
392
+ raise
393
+ except Exception as e:
394
+ logger.error(f"QA processing failed: {str(e)}")
395
+ raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
396
+
397
+ @app.post("/visualize/code")
398
+ @limiter.limit("5/minute")
399
+ async def visualize_with_code(
400
+ request: Request,
401
+ file: UploadFile = File(...),
402
+ chart_type: str = Form(...),
403
+ x_column: Optional[str] = Form(None),
404
+ y_column: Optional[str] = Form(None),
405
+ hue_column: Optional[str] = Form(None),
406
+ title: Optional[str] = Form(None),
407
+ x_label: Optional[str] = Form(None),
408
+ y_label: Optional[str] = Form(None),
409
+ style: str = Form("seaborn"),
410
+ filters: Optional[str] = Form(None)
411
+ ):
412
+ try:
413
+ # Validate file
414
+ file_ext, content = await process_uploaded_file(file)
415
+ if file_ext not in {"xlsx", "xls"}:
416
+ raise HTTPException(400, "Only Excel files are supported for visualization")
417
+
418
+ # Read Excel file
419
+ df = pd.read_excel(io.BytesIO(content))
420
+
421
  # Parse filters if provided
422
  filter_dict = {}
423
  if filters:
424
  try:
425
+ filter_dict = ast.literal_eval(filters)
426
  if not isinstance(filter_dict, dict):
427
  filter_dict = {}
428
  except:
429
  filter_dict = {}
430
+
431
+ # Create visualization request
432
  vis_request = VisualizationRequest(
433
  chart_type=chart_type,
434
  x_column=x_column,
 
441
  filters=filter_dict
442
  )
443
 
444
+ # Generate visualization code
445
+ visualization_code = generate_visualization_code(df, vis_request)
446
+
447
+ # Execute the code to generate the plot
448
+ plt.figure()
449
+ local_vars = {}
450
+ exec(visualization_code, globals(), local_vars)
451
+
452
+ # Save the plot to a temporary file
453
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
454
+ plt.savefig(tmpfile.name, format='png', dpi=300)
455
+ plt.close()
456
+
457
+ # Read the image back as bytes
458
+ with open(tmpfile.name, "rb") as f:
459
+ image_bytes = f.read()
460
+
461
+ # Encode image as base64
462
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
463
 
464
  return {
465
  "status": "success",
466
  "image": f"data:image/png;base64,{image_base64}",
467
+ "code": visualization_code,
468
+ "data_preview": df.head().to_dict(orient='records')
469
  }
470
+
471
+ except HTTPException:
472
+ raise
473
+ except Exception as e:
474
+ logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
475
+ raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
476
 
477
+ @app.post("/visualize/natural")
478
+ @limiter.limit("5/minute")
479
+ async def visualize_with_natural_language(
480
+ request: Request,
481
+ file: UploadFile = File(...),
482
+ prompt: str = Form(...),
483
+ style: str = Form("seaborn")
484
+ ):
485
+ try:
486
+ # Validate file
487
+ file_ext, content = await process_uploaded_file(file)
488
+ if file_ext not in {"xlsx", "xls"}:
489
+ raise HTTPException(400, "Only Excel files are supported for visualization")
490
+
491
+ # Read Excel file
492
+ df = pd.read_excel(io.BytesIO(content))
493
+
494
+ # Convert natural language to visualization parameters
495
+ nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
496
+ vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
497
+
498
+ # Generate visualization code
499
+ visualization_code = generate_visualization_code(df, vis_request)
500
+
501
+ # Execute the code to generate the plot
502
+ plt.figure()
503
+ local_vars = {}
504
+ exec(visualization_code, globals(), local_vars)
505
+
506
+ # Save the plot to a temporary file
507
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
508
+ plt.savefig(tmpfile.name, format='png', dpi=300)
509
+ plt.close()
510
+
511
+ # Read the image back as bytes
512
+ with open(tmpfile.name, "rb") as f:
513
+ image_bytes = f.read()
514
+
515
+ # Encode image as base64
516
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
517
+
518
+ return {
519
+ "status": "success",
520
+ "image": f"data:image/png;base64,{image_base64}",
521
+ "code": visualization_code,
522
+ "interpreted_parameters": vis_request.dict(),
523
+ "data_preview": df.head().to_dict(orient='records')
524
+ }
525
+
526
+ except HTTPException:
527
+ raise
528
  except Exception as e:
529
+ logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
530
+ raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
531
 
532
  @app.post("/get_columns")
533
  @limiter.limit("10/minute")
534
+ async def get_excel_columns(
535
  request: Request,
536
+ file: UploadFile = File(...)
537
  ):
538
  try:
539
+ file_ext, content = await process_uploaded_file(file)
540
+ if file_ext not in {"xlsx", "xls"}:
541
+ raise HTTPException(400, "Only Excel files are supported")
542
 
543
+ df = pd.read_excel(io.BytesIO(content))
 
 
544
  return {
545
  "columns": list(df.columns),
546
+ "sample_data": df.head().to_dict(orient='records'),
547
  "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
548
  }
 
 
 
 
 
 
549
  except Exception as e:
550
+ logger.error(f"Column extraction failed: {str(e)}")
551
+ raise HTTPException(500, detail="Failed to extract columns from Excel file")
552
 
553
  @app.exception_handler(RateLimitExceeded)
554
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
555
  return JSONResponse(
556
  status_code=429,
557
  content={"detail": "Too many requests. Please try again later."}