chenguittiMaroua commited on
Commit
043cd21
·
verified ·
1 Parent(s): 4e8d5a1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +326 -130
main.py CHANGED
@@ -1,31 +1,44 @@
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
- from pydantic import BaseModel
5
- from typing import Optional, Tuple, List, Dict
6
  import io
 
 
7
  import pandas as pd
8
- import matplotlib.pyplot as plt
9
- import seaborn as sns
10
- import base64
11
- import tempfile
12
- import logging
13
- import traceback
14
  import uvicorn
 
 
 
 
 
15
  from slowapi import Limiter
16
  from slowapi.util import get_remote_address
17
  from slowapi.errors import RateLimitExceeded
18
  from slowapi.middleware import SlowAPIMiddleware
 
 
 
 
 
 
 
19
 
20
- # Initialize FastAPI app
21
- app = FastAPI(title="Data Visualization API", version="1.0")
22
-
23
- # Rate limiting setup
24
  limiter = Limiter(key_func=get_remote_address)
 
 
 
 
 
 
 
 
25
  app.state.limiter = limiter
26
  app.add_middleware(SlowAPIMiddleware)
27
 
28
- # CORS configuration
29
  app.add_middleware(
30
  CORSMiddleware,
31
  allow_origins=["*"],
@@ -35,123 +48,306 @@ app.add_middleware(
35
 
36
  # Constants
37
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
38
- SUPPORTED_EXCEL_TYPES = {"xlsx", "xls"}
 
 
39
 
40
- # Configure logging
41
- logging.basicConfig(level=logging.INFO)
42
- logger = logging.getLogger(__name__)
 
43
 
44
- class VisualizationRequest(BaseModel):
45
- chart_type: str
46
- x_column: str
47
- y_column: Optional[str] = None
48
- hue_column: Optional[str] = None
49
- title: Optional[str] = None
50
- x_label: Optional[str] = None
51
- y_label: Optional[str] = None
52
- style: str = "seaborn"
53
- width: int = 10
54
- height: int = 6
55
 
56
- async def validate_excel_file(file: UploadFile) -> Tuple[str, bytes]:
57
- """Validate and process uploaded Excel file"""
 
 
 
 
 
 
 
 
 
 
 
 
58
  if not file.filename:
59
  raise HTTPException(400, "No filename provided")
60
 
61
  file_ext = file.filename.split('.')[-1].lower()
62
- if file_ext not in SUPPORTED_EXCEL_TYPES:
63
- raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_EXCEL_TYPES)}")
64
 
65
  content = await file.read()
66
  if len(content) > MAX_FILE_SIZE:
67
  raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
68
 
69
- await file.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return file_ext, content
71
 
72
- def read_excel_with_fallback(content: bytes) -> pd.DataFrame:
73
- """Read Excel file with engine fallback"""
74
  try:
75
- return pd.read_excel(io.BytesIO(content), engine='openpyxl')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  except Exception as e:
77
- logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
78
- try:
79
- return pd.read_excel(io.BytesIO(content), engine='xlrd')
80
- except Exception as e:
81
- raise ValueError(f"Failed to read Excel file: {str(e)}")
82
 
83
- def generate_chart(df: pd.DataFrame, request: VisualizationRequest) -> str:
84
- """Generate matplotlib/seaborn chart based on request"""
85
- plt.style.use(request.style)
86
- plt.figure(figsize=(request.width, request.height))
87
-
88
- chart_type = request.chart_type.lower()
89
  try:
90
- if chart_type == "line":
91
- if request.hue_column:
92
- sns.lineplot(data=df, x=request.x_column, y=request.y_column, hue=request.hue_column)
93
- else:
94
- plt.plot(df[request.x_column], df[request.y_column])
95
- elif chart_type == "bar":
96
- if request.hue_column:
97
- sns.barplot(data=df, x=request.x_column, y=request.y_column, hue=request.hue_column)
98
- else:
99
- plt.bar(df[request.x_column], df[request.y_column])
100
- elif chart_type == "scatter":
101
- if request.hue_column:
102
- sns.scatterplot(data=df, x=request.x_column, y=request.y_column, hue=request.hue_column)
103
- else:
104
- plt.scatter(df[request.x_column], df[request.y_column])
105
- elif chart_type == "histogram":
106
- plt.hist(df[request.x_column], bins=20)
107
- else:
108
- raise ValueError(f"Unsupported chart type: {chart_type}")
109
-
110
- if request.title:
111
- plt.title(request.title)
112
- if request.x_label:
113
- plt.xlabel(request.x_label)
114
- if request.y_label:
115
- plt.ylabel(request.y_label)
116
 
117
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
118
 
119
- # Save to temporary file
120
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
121
- plt.savefig(tmpfile.name, format='png', dpi=300)
122
- plt.close()
123
- with open(tmpfile.name, "rb") as f:
124
- return base64.b64encode(f.read()).decode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  except Exception as e:
127
- plt.close()
128
- raise ValueError(f"Chart generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- @app.post("/visualize", response_model=Dict[str, str])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  @limiter.limit("5/minute")
132
- async def create_visualization(
133
  request: Request,
134
  file: UploadFile = File(...),
135
  chart_type: str = Form(...),
136
- x_column: str = Form(...),
137
  y_column: Optional[str] = Form(None),
138
  hue_column: Optional[str] = Form(None),
139
  title: Optional[str] = Form(None),
140
  x_label: Optional[str] = Form(None),
141
  y_label: Optional[str] = Form(None),
142
- style: str = Form("seaborn"),
143
- width: int = Form(10),
144
- height: int = Form(6)
145
  ):
146
  try:
147
- # Validate and read file
148
- file_ext, content = await validate_excel_file(file)
149
- df = read_excel_with_fallback(content)
 
150
 
151
- if df.empty:
152
- raise ValueError("Excel file contains no data")
153
-
154
- # Create visualization
155
  vis_request = VisualizationRequest(
156
  chart_type=chart_type,
157
  x_column=x_column,
@@ -160,60 +356,60 @@ async def create_visualization(
160
  title=title,
161
  x_label=x_label,
162
  y_label=y_label,
163
- style=style,
164
- width=width,
165
- height=height
166
  )
167
 
168
- image_base64 = generate_chart(df, vis_request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  return {
171
  "status": "success",
172
- "image": f"data:image/png;base64,{image_base64}"
 
173
  }
174
-
175
- except HTTPException as he:
176
- raise he
177
- except ValueError as ve:
178
- logger.error(f"Validation error: {str(ve)}")
179
- raise HTTPException(422, detail=str(ve))
180
  except Exception as e:
181
- logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
182
- raise HTTPException(500, detail="Failed to generate visualization")
183
 
184
- @app.post("/get_columns", response_model=Dict[str, List])
 
185
  @limiter.limit("10/minute")
186
- async def get_columns(
187
  request: Request,
188
  file: UploadFile = File(...)
189
  ):
190
  try:
191
- file_ext, content = await validate_excel_file(file)
192
- df = read_excel_with_fallback(content)
 
193
 
194
- if df.empty:
195
- raise ValueError("Excel file contains no data")
196
-
197
  return {
198
  "columns": list(df.columns),
199
- "sample_data": df.head().replace({float('nan'): None}).to_dict(orient='records')
200
  }
201
-
202
- except HTTPException as he:
203
- raise he
204
- except ValueError as ve:
205
- logger.error(f"Validation error: {str(ve)}")
206
- raise HTTPException(422, detail=str(ve))
207
  except Exception as e:
208
- logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
209
- raise HTTPException(500, detail="Failed to process Excel file")
210
 
211
- @app.exception_handler(RateLimitExceeded)
212
- async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
213
- return JSONResponse(
214
- status_code=429,
215
- content={"detail": "Too many requests. Please try again later."}
216
- )
217
 
218
  if __name__ == "__main__":
219
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
+ from transformers import pipeline
5
+ from typing import Tuple
6
  import io
7
+ import fitz # PyMuPDF
8
+ from PIL import Image
9
  import pandas as pd
 
 
 
 
 
 
10
  import uvicorn
11
+ from docx import Document
12
+ from pptx import Presentation
13
+ import pytesseract
14
+ import logging
15
+ import re
16
  from slowapi import Limiter
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
  from slowapi.middleware import SlowAPIMiddleware
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ import tempfile
23
+ import base64
24
+ from io import BytesIO
25
+ from typing import Optional
26
+ from pydantic import BaseModel
27
 
28
+ # Initialize rate limiter
 
 
 
29
  limiter = Limiter(key_func=get_remote_address)
30
+
31
+ # Configure logging
32
+ logging.basicConfig(level=logging.INFO)
33
+ logger = logging.getLogger(__name__)
34
+
35
+ app = FastAPI()
36
+
37
+ # Apply rate limiting middleware
38
  app.state.limiter = limiter
39
  app.add_middleware(SlowAPIMiddleware)
40
 
41
+ # CORS Configuration
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=["*"],
 
48
 
49
  # Constants
50
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
51
+ SUPPORTED_FILE_TYPES = {
52
+ "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
53
+ }
54
 
55
+ # Model caching
56
+ summarizer = None
57
+ qa_model = None
58
+ image_captioner = None
59
 
60
+ def get_summarizer():
61
+ global summarizer
62
+ if summarizer is None:
63
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
64
+ return summarizer
 
 
 
 
 
 
65
 
66
+ def get_qa_model():
67
+ global qa_model
68
+ if qa_model is None:
69
+ qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
70
+ return qa_model
71
+
72
+ def get_image_captioner():
73
+ global image_captioner
74
+ if image_captioner is None:
75
+ image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
76
+ return image_captioner
77
+
78
+ async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
79
+ """Validate and process uploaded file with special handling for each type"""
80
  if not file.filename:
81
  raise HTTPException(400, "No filename provided")
82
 
83
  file_ext = file.filename.split('.')[-1].lower()
84
+ if file_ext not in SUPPORTED_FILE_TYPES:
85
+ raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
86
 
87
  content = await file.read()
88
  if len(content) > MAX_FILE_SIZE:
89
  raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
90
 
91
+ # Special validation for PDFs
92
+ if file_ext == "pdf":
93
+ try:
94
+ with fitz.open(stream=content, filetype="pdf") as doc:
95
+ if doc.is_encrypted:
96
+ if not doc.authenticate(""):
97
+ raise ValueError("Encrypted PDF - cannot extract text")
98
+ if len(doc) > 50:
99
+ raise ValueError("PDF too large (max 50 pages)")
100
+ except Exception as e:
101
+ logger.error(f"PDF validation failed: {str(e)}")
102
+ raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
103
+
104
+ await file.seek(0) # Reset file pointer for processing
105
  return file_ext, content
106
 
107
+ def extract_text(content: bytes, file_ext: str) -> str:
108
+ """Extract text from various file formats with enhanced support"""
109
  try:
110
+ if file_ext == "docx":
111
+ doc = Document(io.BytesIO(content))
112
+ return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
113
+
114
+ elif file_ext in {"xlsx", "xls"}:
115
+ df = pd.read_excel(io.BytesIO(content), sheet_name=None)
116
+ all_text = []
117
+ for sheet_name, sheet_data in df.items():
118
+ sheet_text = []
119
+ for column in sheet_data.columns:
120
+ sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
121
+ all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
122
+ return "\n\n".join(all_text)
123
+
124
+ elif file_ext == "pptx":
125
+ ppt = Presentation(io.BytesIO(content))
126
+ text = []
127
+ for slide in ppt.slides:
128
+ for shape in slide.shapes:
129
+ if hasattr(shape, "text") and shape.text.strip():
130
+ text.append(shape.text)
131
+ return "\n".join(text)
132
+
133
+ elif file_ext == "pdf":
134
+ pdf = fitz.open(stream=content, filetype="pdf")
135
+ return "\n".join(page.get_text("text") for page in pdf)
136
+
137
+ elif file_ext in {"jpg", "jpeg", "png"}:
138
+ # First try OCR
139
+ try:
140
+ image = Image.open(io.BytesIO(content))
141
+ text = pytesseract.image_to_string(image, config='--psm 6')
142
+ if text.strip():
143
+ return text
144
+
145
+ # If OCR fails, try image captioning
146
+ captioner = get_image_captioner()
147
+ result = captioner(image)
148
+ return result[0]['generated_text']
149
+ except Exception as img_e:
150
+ logger.error(f"Image processing failed: {str(img_e)}")
151
+ raise ValueError("Could not extract text or caption from image")
152
+
153
  except Exception as e:
154
+ logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
155
+ raise HTTPException(422, f"Failed to extract text from {file_ext} file")
 
 
 
156
 
157
+ @app.post("/summarize")
158
+ @limiter.limit("5/minute")
159
+ async def summarize_document(request: Request, file: UploadFile = File(...)):
 
 
 
160
  try:
161
+ file_ext, content = await process_uploaded_file(file)
162
+ text = extract_text(content, file_ext)
163
+
164
+ if not text.strip():
165
+ raise HTTPException(400, "No extractable text found")
166
+
167
+ # Clean and chunk text
168
+ text = re.sub(r'\s+', ' ', text).strip()
169
+ chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
170
+
171
+ # Summarize each chunk
172
+ summarizer = get_summarizer()
173
+ summaries = []
174
+ for chunk in chunks:
175
+ summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
176
+ summaries.append(summary)
177
+
178
+ return {"summary": " ".join(summaries)}
179
+
180
+ except HTTPException:
181
+ raise
182
+ except Exception as e:
183
+ logger.error(f"Summarization failed: {str(e)}")
184
+ raise HTTPException(500, "Document summarization failed")
 
 
185
 
186
+ @app.post("/qa")
187
+ @limiter.limit("5/minute")
188
+ async def question_answering(
189
+ request: Request,
190
+ file: UploadFile = File(...),
191
+ question: str = Form(...),
192
+ language: str = Form("fr")
193
+ ):
194
+ try:
195
+ file_ext, content = await process_uploaded_file(file)
196
+ text = extract_text(content, file_ext)
197
 
198
+ if not text.strip():
199
+ raise HTTPException(400, "No extractable text found")
200
+
201
+ # Clean and truncate text
202
+ text = re.sub(r'\s+', ' ', text).strip()[:5000]
203
+
204
+ # Theme detection
205
+ theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
206
+ if any(kw in question.lower() for kw in theme_keywords):
207
+ try:
208
+ summarizer = get_summarizer()
209
+ summary_output = summarizer(
210
+ text,
211
+ max_length=min(100, len(text)//4),
212
+ min_length=30,
213
+ do_sample=False,
214
+ truncation=True
215
+ )
216
 
217
+ theme = summary_output[0].get("summary_text", text[:200] + "...")
218
+ return {
219
+ "question": question,
220
+ "answer": f"Le document traite principalement de : {theme}",
221
+ "confidence": 0.95,
222
+ "language": language
223
+ }
224
+ except Exception:
225
+ theme = text[:200] + ("..." if len(text) > 200 else "")
226
+ return {
227
+ "question": question,
228
+ "answer": f"D'après le document : {theme}",
229
+ "confidence": 0.7,
230
+ "language": language,
231
+ "warning": "theme_summary_fallback"
232
+ }
233
+
234
+ # Standard QA
235
+ qa = get_qa_model()
236
+ result = qa(question=question, context=text[:3000])
237
+
238
+ return {
239
+ "question": question,
240
+ "answer": result["answer"],
241
+ "confidence": result["score"],
242
+ "language": language
243
+ }
244
+
245
+ except HTTPException:
246
+ raise
247
  except Exception as e:
248
+ logger.error(f"QA processing failed: {str(e)}")
249
+ raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
250
+
251
+ @app.exception_handler(RateLimitExceeded)
252
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
253
+ return JSONResponse(
254
+ status_code=429,
255
+ content={"detail": "Too many requests. Please try again later."}
256
+ )
257
+
258
+
259
+
260
+
261
+
262
+ # Add this new Pydantic model for visualization requests
263
+ class VisualizationRequest(BaseModel):
264
+ chart_type: str
265
+ x_column: Optional[str] = None
266
+ y_column: Optional[str] = None
267
+ hue_column: Optional[str] = None
268
+ title: Optional[str] = None
269
+ x_label: Optional[str] = None
270
+ y_label: Optional[str] = None
271
+ style: str = "seaborn" # seaborn or matplotlib
272
 
273
+ # Add this new function for visualization code generation
274
+ def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
275
+ """Generate and execute visualization code based on request"""
276
+ plt.style.use(request.style)
277
+
278
+ code_lines = [
279
+ "import matplotlib.pyplot as plt",
280
+ "import seaborn as sns",
281
+ "import pandas as pd",
282
+ "",
283
+ "# Data preparation",
284
+ f"df = pd.DataFrame({df.head().to_dict()})", # Simplified for demo
285
+ "",
286
+ "# Visualization code"
287
+ ]
288
+
289
+ if request.chart_type == "line":
290
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
291
+ if request.hue_column:
292
+ code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
293
+ else:
294
+ code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
295
+ elif request.chart_type == "bar":
296
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
297
+ if request.hue_column:
298
+ code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
299
+ else:
300
+ code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
301
+ elif request.chart_type == "scatter":
302
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
303
+ if request.hue_column:
304
+ code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
305
+ else:
306
+ code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
307
+ elif request.chart_type == "histogram":
308
+ code_lines.append(f"plt.figure(figsize=(10, 6))")
309
+ code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
310
+ else:
311
+ raise ValueError("Unsupported chart type")
312
+
313
+ # Add labels and title
314
+ if request.title:
315
+ code_lines.append(f"plt.title('{request.title}')")
316
+ if request.x_label:
317
+ code_lines.append(f"plt.xlabel('{request.x_label}')")
318
+ if request.y_label:
319
+ code_lines.append(f"plt.ylabel('{request.y_label}')")
320
+
321
+ code_lines.append("plt.tight_layout()")
322
+ code_lines.append("plt.show()")
323
+
324
+ return "\n".join(code_lines)
325
+
326
+ # Add this new endpoint for visualization
327
+ @app.post("/visualize")
328
  @limiter.limit("5/minute")
329
+ async def generate_visualization_from_excel(
330
  request: Request,
331
  file: UploadFile = File(...),
332
  chart_type: str = Form(...),
333
+ x_column: Optional[str] = Form(None),
334
  y_column: Optional[str] = Form(None),
335
  hue_column: Optional[str] = Form(None),
336
  title: Optional[str] = Form(None),
337
  x_label: Optional[str] = Form(None),
338
  y_label: Optional[str] = Form(None),
339
+ style: str = Form("seaborn")
 
 
340
  ):
341
  try:
342
+ # Validate file
343
+ file_ext, content = await validate_file(file)
344
+ if file_ext not in {"xlsx", "xls"}:
345
+ raise HTTPException(400, "Only Excel files are supported for visualization")
346
 
347
+ # Read Excel file
348
+ df = pd.read_excel(io.BytesIO(content))
349
+
350
+ # Generate visualization request
351
  vis_request = VisualizationRequest(
352
  chart_type=chart_type,
353
  x_column=x_column,
 
356
  title=title,
357
  x_label=x_label,
358
  y_label=y_label,
359
+ style=style
 
 
360
  )
361
 
362
+ # Generate and execute the visualization code
363
+ plt.figure()
364
+ exec(generate_visualization(df, vis_request), globals(), locals())
365
+
366
+ # Save the plot to a temporary file
367
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
368
+ plt.savefig(tmpfile.name, format='png', dpi=300)
369
+ plt.close()
370
+
371
+ # Read the image back as bytes
372
+ with open(tmpfile.name, "rb") as f:
373
+ image_bytes = f.read()
374
+
375
+ # Encode image as base64
376
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
377
 
378
  return {
379
  "status": "success",
380
+ "image": f"data:image/png;base64,{image_base64}",
381
+ "code": generate_visualization(df, vis_request)
382
  }
383
+
384
+ except HTTPException:
385
+ raise
 
 
 
386
  except Exception as e:
387
+ logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
388
+ raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
389
 
390
+ # Add this new endpoint for getting column names
391
+ @app.post("/get_columns")
392
  @limiter.limit("10/minute")
393
+ async def get_excel_columns(
394
  request: Request,
395
  file: UploadFile = File(...)
396
  ):
397
  try:
398
+ file_ext, content = await validate_file(file)
399
+ if file_ext not in {"xlsx", "xls"}:
400
+ raise HTTPException(400, "Only Excel files are supported")
401
 
402
+ df = pd.read_excel(io.BytesIO(content))
 
 
403
  return {
404
  "columns": list(df.columns),
405
+ "sample_data": df.head().to_dict(orient='records')
406
  }
 
 
 
 
 
 
407
  except Exception as e:
408
+ logger.error(f"Column extraction failed: {str(e)}")
409
+ raise HTTPException(500, detail="Failed to extract columns from Excel file")
410
 
411
+
412
+
 
 
 
 
413
 
414
  if __name__ == "__main__":
415
  uvicorn.run(app, host="0.0.0.0", port=7860)