chenguittiMaroua committed on
Commit
051e65c
·
verified ·
1 Parent(s): fc8ee72

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +129 -537
main.py CHANGED
@@ -1,45 +1,28 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
2
- from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
- from transformers import pipeline
5
- from typing import Tuple, Optional
6
- import io
7
- import fitz # PyMuPDF
8
- from PIL import Image
9
- import pandas as pd
10
- import uvicorn
11
- from docx import Document
12
- from pptx import Presentation
13
- import pytesseract
14
- import logging
15
- import re
16
  from slowapi import Limiter
17
  from slowapi.util import get_remote_address
18
  from slowapi.errors import RateLimitExceeded
19
- from slowapi.middleware import SlowAPIMiddleware
20
- import matplotlib.pyplot as plt
21
- import seaborn as sns
22
- import tempfile
23
- import base64
24
- from io import BytesIO
25
- from pydantic import BaseModel
26
- import traceback
27
- import ast
28
 
29
- # Initialize rate limiter
30
- limiter = Limiter(key_func=get_remote_address)
 
 
 
 
 
31
 
32
- # Configure logging
33
- logging.basicConfig(level=logging.INFO)
34
- logger = logging.getLogger(__name__)
 
 
35
 
36
  app = FastAPI()
37
 
38
- # Apply rate limiting middleware
39
- app.state.limiter = limiter
40
- app.add_middleware(SlowAPIMiddleware)
41
-
42
- # CORS Configuration
43
  app.add_middleware(
44
  CORSMiddleware,
45
  allow_origins=["*"],
@@ -47,515 +30,124 @@ app.add_middleware(
47
  allow_headers=["*"],
48
  )
49
 
50
- # Constants
51
- MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
52
- SUPPORTED_FILE_TYPES = {
53
- "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
54
- }
55
-
56
- # Model caching
57
- summarizer = None
58
- qa_model = None
59
- image_captioner = None
60
-
61
- def get_summarizer():
62
- global summarizer
63
- if summarizer is None:
64
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
65
- return summarizer
66
-
67
- def get_qa_model():
68
- global qa_model
69
- if qa_model is None:
70
- qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
71
- return qa_model
72
-
73
- def get_image_captioner():
74
- global image_captioner
75
- if image_captioner is None:
76
- image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
77
- return image_captioner
78
-
79
- async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
80
- """Validate and process uploaded file with special handling for each type"""
81
- if not file.filename:
82
- raise HTTPException(400, "No filename provided")
83
-
84
- file_ext = file.filename.split('.')[-1].lower()
85
- if file_ext not in SUPPORTED_FILE_TYPES:
86
- raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
87
-
88
- content = await file.read()
89
- if len(content) > MAX_FILE_SIZE:
90
- raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
91
-
92
- # Special validation for PDFs
93
- if file_ext == "pdf":
94
- try:
95
- with fitz.open(stream=content, filetype="pdf") as doc:
96
- if doc.is_encrypted:
97
- if not doc.authenticate(""):
98
- raise ValueError("Encrypted PDF - cannot extract text")
99
- if len(doc) > 50:
100
- raise ValueError("PDF too large (max 50 pages)")
101
- except Exception as e:
102
- logger.error(f"PDF validation failed: {str(e)}")
103
- raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")
104
-
105
- await file.seek(0) # Reset file pointer for processing
106
- return file_ext, content
107
-
108
- def extract_text(content: bytes, file_ext: str) -> str:
109
- """Extract text from various file formats with enhanced support"""
110
- try:
111
- if file_ext == "docx":
112
- doc = Document(io.BytesIO(content))
113
- return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
114
-
115
- elif file_ext in {"xlsx", "xls"}:
116
- df = pd.read_excel(io.BytesIO(content), sheet_name=None)
117
- all_text = []
118
- for sheet_name, sheet_data in df.items():
119
- sheet_text = []
120
- for column in sheet_data.columns:
121
- sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
122
- all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
123
- return "\n\n".join(all_text)
124
-
125
- elif file_ext == "pptx":
126
- ppt = Presentation(io.BytesIO(content))
127
- text = []
128
- for slide in ppt.slides:
129
- for shape in slide.shapes:
130
- if hasattr(shape, "text") and shape.text.strip():
131
- text.append(shape.text)
132
- return "\n".join(text)
133
-
134
- elif file_ext == "pdf":
135
- pdf = fitz.open(stream=content, filetype="pdf")
136
- return "\n".join(page.get_text("text") for page in pdf)
137
-
138
- elif file_ext in {"jpg", "jpeg", "png"}:
139
- # First try OCR
140
- try:
141
- image = Image.open(io.BytesIO(content))
142
- text = pytesseract.image_to_string(image, config='--psm 6')
143
- if text.strip():
144
- return text
145
-
146
- # If OCR fails, try image captioning
147
- captioner = get_image_captioner()
148
- result = captioner(image)
149
- return result[0]['generated_text']
150
- except Exception as img_e:
151
- logger.error(f"Image processing failed: {str(img_e)}")
152
- raise ValueError("Could not extract text or caption from image")
153
-
154
- except Exception as e:
155
- logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
156
- raise HTTPException(422, f"Failed to extract text from {file_ext} file")
157
-
158
- # Visualization Models
159
- class VisualizationRequest(BaseModel):
160
- chart_type: str
161
- x_column: Optional[str] = None
162
- y_column: Optional[str] = None
163
- hue_column: Optional[str] = None
164
- title: Optional[str] = None
165
- x_label: Optional[str] = None
166
- y_label: Optional[str] = None
167
- style: str = "seaborn"
168
- filters: Optional[dict] = None
169
-
170
- class NaturalLanguageRequest(BaseModel):
171
- prompt: str
172
- style: str = "seaborn"
173
-
174
- def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
175
- """Generate Python code for visualization based on request parameters"""
176
- code_lines = [
177
- "import matplotlib.pyplot as plt",
178
- "import seaborn as sns",
179
- "import pandas as pd",
180
- "",
181
- "# Data preparation",
182
- f"df = pd.DataFrame({df.to_dict(orient='list')})",
183
- ]
184
-
185
- # Apply filters if specified
186
- if request.filters:
187
- filter_conditions = []
188
- for column, condition in request.filters.items():
189
- if isinstance(condition, dict):
190
- if 'min' in condition and 'max' in condition:
191
- filter_conditions.append(f"(df['{column}'] >= {condition['min']}) & (df['{column}'] <= {condition['max']})")
192
- elif 'values' in condition:
193
- values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
194
- filter_conditions.append(f"df['{column}'].isin([{values}])")
195
- else:
196
- filter_conditions.append(f"df['{column}'] == {repr(condition)}")
197
-
198
- if filter_conditions:
199
- code_lines.extend([
200
- "",
201
- "# Apply filters",
202
- f"df = df[{' & '.join(filter_conditions)}]"
203
- ])
204
-
205
- code_lines.extend([
206
- "",
207
- "# Visualization",
208
- f"plt.style.use('{request.style}')",
209
- f"plt.figure(figsize=(10, 6))"
210
- ])
211
-
212
- # Chart type specific code
213
- if request.chart_type == "line":
214
- if request.hue_column:
215
- code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
216
- else:
217
- code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
218
- elif request.chart_type == "bar":
219
- if request.hue_column:
220
- code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
221
- else:
222
- code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
223
- elif request.chart_type == "scatter":
224
- if request.hue_column:
225
- code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
226
- else:
227
- code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
228
- elif request.chart_type == "histogram":
229
- code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
230
- elif request.chart_type == "boxplot":
231
- if request.hue_column:
232
- code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
233
- else:
234
- code_lines.append(f"sns.boxplot(data=df, x='{request.x_column}', y='{request.y_column}')")
235
- elif request.chart_type == "heatmap":
236
- code_lines.append(f"corr = df.corr()")
237
- code_lines.append(f"sns.heatmap(corr, annot=True, cmap='coolwarm')")
238
- else:
239
- raise ValueError(f"Unsupported chart type: {request.chart_type}")
240
-
241
- # Add labels and title
242
- if request.title:
243
- code_lines.append(f"plt.title('{request.title}')")
244
- if request.x_label:
245
- code_lines.append(f"plt.xlabel('{request.x_label}')")
246
- if request.y_label:
247
- code_lines.append(f"plt.ylabel('{request.y_label}')")
248
-
249
- code_lines.extend([
250
- "plt.tight_layout()",
251
- "plt.show()"
252
- ])
253
-
254
- return "\n".join(code_lines)
255
 
256
- def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
257
- """Convert natural language prompt to visualization parameters"""
258
- # Simple keyword-based interpretation (could be enhanced with NLP)
259
- prompt = prompt.lower()
260
-
261
- # Determine chart type
262
- chart_type = "bar"
263
- if "line" in prompt:
264
- chart_type = "line"
265
- elif "scatter" in prompt:
266
- chart_type = "scatter"
267
- elif "histogram" in prompt:
268
- chart_type = "histogram"
269
- elif "box" in prompt:
270
- chart_type = "boxplot"
271
- elif "heatmap" in prompt or "correlation" in prompt:
272
- chart_type = "heatmap"
273
-
274
- # Try to detect columns
275
- x_col = None
276
- y_col = None
277
- hue_col = None
278
-
279
- for col in df_columns:
280
- if col.lower() in prompt:
281
- if not x_col:
282
- x_col = col
283
- elif not y_col:
284
- y_col = col
285
- else:
286
- hue_col = col
287
-
288
- # Default to first columns if not detected
289
- if not x_col and len(df_columns) > 0:
290
- x_col = df_columns[0]
291
- if not y_col and len(df_columns) > 1:
292
- y_col = df_columns[1]
293
-
294
- return VisualizationRequest(
295
- chart_type=chart_type,
296
- x_column=x_col,
297
- y_column=y_col,
298
- hue_column=hue_col,
299
- title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
300
- style="seaborn"
301
  )
302
 
303
- @app.post("/summarize")
304
- @limiter.limit("5/minute")
305
- async def summarize_document(request: Request, file: UploadFile = File(...)):
306
- try:
307
- file_ext, content = await process_uploaded_file(file)
308
- text = extract_text(content, file_ext)
309
-
310
- if not text.strip():
311
- raise HTTPException(400, "No extractable text found")
312
-
313
- # Clean and chunk text
314
- text = re.sub(r'\s+', ' ', text).strip()
315
- chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
316
-
317
- # Summarize each chunk
318
- summarizer = get_summarizer()
319
- summaries = []
320
- for chunk in chunks:
321
- summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
322
- summaries.append(summary)
323
-
324
- return {"summary": " ".join(summaries)}
325
-
326
- except HTTPException:
327
- raise
328
- except Exception as e:
329
- logger.error(f"Summarization failed: {str(e)}")
330
- raise HTTPException(500, "Document summarization failed")
331
-
332
- @app.post("/qa")
333
- @limiter.limit("5/minute")
334
- async def question_answering(
335
- request: Request,
336
- file: UploadFile = File(...),
337
- question: str = Form(...),
338
- language: str = Form("fr")
339
- ):
340
- try:
341
- file_ext, content = await process_uploaded_file(file)
342
- text = extract_text(content, file_ext)
343
-
344
- if not text.strip():
345
- raise HTTPException(400, "No extractable text found")
346
-
347
- # Clean and truncate text
348
- text = re.sub(r'\s+', ' ', text).strip()[:5000]
349
-
350
- # Theme detection
351
- theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
352
- if any(kw in question.lower() for kw in theme_keywords):
353
- try:
354
- summarizer = get_summarizer()
355
- summary_output = summarizer(
356
- text,
357
- max_length=min(100, len(text)//4),
358
- min_length=30,
359
- do_sample=False,
360
- truncation=True
361
- )
362
-
363
- theme = summary_output[0].get("summary_text", text[:200] + "...")
364
- return {
365
- "question": question,
366
- "answer": f"Le document traite principalement de : {theme}",
367
- "confidence": 0.95,
368
- "language": language
369
- }
370
- except Exception:
371
- theme = text[:200] + ("..." if len(text) > 200 else "")
372
- return {
373
- "question": question,
374
- "answer": f"D'après le document : {theme}",
375
- "confidence": 0.7,
376
- "language": language,
377
- "warning": "theme_summary_fallback"
378
- }
379
-
380
- # Standard QA
381
- qa = get_qa_model()
382
- result = qa(question=question, context=text[:3000])
383
-
384
- return {
385
- "question": question,
386
- "answer": result["answer"],
387
- "confidence": result["score"],
388
- "language": language
389
- }
390
-
391
- except HTTPException:
392
- raise
393
- except Exception as e:
394
- logger.error(f"QA processing failed: {str(e)}")
395
- raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
396
-
397
- @app.post("/visualize/code")
398
- @limiter.limit("5/minute")
399
- async def visualize_with_code(
400
- request: Request,
401
- file: UploadFile = File(...),
402
- chart_type: str = Form(...),
403
- x_column: Optional[str] = Form(None),
404
- y_column: Optional[str] = Form(None),
405
- hue_column: Optional[str] = Form(None),
406
- title: Optional[str] = Form(None),
407
- x_label: Optional[str] = Form(None),
408
- y_label: Optional[str] = Form(None),
409
- style: str = Form("seaborn"),
410
- filters: Optional[str] = Form(None)
411
- ):
412
- try:
413
- # Validate file
414
- file_ext, content = await process_uploaded_file(file)
415
- if file_ext not in {"xlsx", "xls"}:
416
- raise HTTPException(400, "Only Excel files are supported for visualization")
417
-
418
- # Read Excel file
419
- df = pd.read_excel(io.BytesIO(content))
420
-
421
- # Parse filters if provided
422
- filter_dict = {}
423
- if filters:
424
- try:
425
- filter_dict = ast.literal_eval(filters)
426
- if not isinstance(filter_dict, dict):
427
- filter_dict = {}
428
- except:
429
- filter_dict = {}
430
-
431
- # Create visualization request
432
- vis_request = VisualizationRequest(
433
- chart_type=chart_type,
434
- x_column=x_column,
435
- y_column=y_column,
436
- hue_column=hue_column,
437
- title=title,
438
- x_label=x_label,
439
- y_label=y_label,
440
- style=style,
441
- filters=filter_dict
442
- )
443
-
444
- # Generate visualization code
445
- visualization_code = generate_visualization_code(df, vis_request)
446
-
447
- # Execute the code to generate the plot
448
- plt.figure()
449
- local_vars = {}
450
- exec(visualization_code, globals(), local_vars)
451
-
452
- # Save the plot to a temporary file
453
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
454
- plt.savefig(tmpfile.name, format='png', dpi=300)
455
- plt.close()
456
-
457
- # Read the image back as bytes
458
- with open(tmpfile.name, "rb") as f:
459
- image_bytes = f.read()
460
-
461
- # Encode image as base64
462
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
463
-
464
- return {
465
- "status": "success",
466
- "image": f"data:image/png;base64,{image_base64}",
467
- "code": visualization_code,
468
- "data_preview": df.head().to_dict(orient='records')
469
- }
470
-
471
- except HTTPException:
472
- raise
473
- except Exception as e:
474
- logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
475
- raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
476
-
477
- @app.post("/visualize/natural")
478
- @limiter.limit("5/minute")
479
- async def visualize_with_natural_language(
480
- request: Request,
481
- file: UploadFile = File(...),
482
- prompt: str = Form(...),
483
- style: str = Form("seaborn")
484
- ):
485
- try:
486
- # Validate file
487
- file_ext, content = await process_uploaded_file(file)
488
- if file_ext not in {"xlsx", "xls"}:
489
- raise HTTPException(400, "Only Excel files are supported for visualization")
490
-
491
- # Read Excel file
492
- df = pd.read_excel(io.BytesIO(content))
493
-
494
- # Convert natural language to visualization parameters
495
- nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
496
- vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
497
-
498
- # Generate visualization code
499
- visualization_code = generate_visualization_code(df, vis_request)
500
-
501
- # Execute the code to generate the plot
502
- plt.figure()
503
- local_vars = {}
504
- exec(visualization_code, globals(), local_vars)
505
-
506
- # Save the plot to a temporary file
507
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
508
- plt.savefig(tmpfile.name, format='png', dpi=300)
509
- plt.close()
510
-
511
- # Read the image back as bytes
512
- with open(tmpfile.name, "rb") as f:
513
- image_bytes = f.read()
514
-
515
- # Encode image as base64
516
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
517
-
518
- return {
519
- "status": "success",
520
- "image": f"data:image/png;base64,{image_base64}",
521
- "code": visualization_code,
522
- "interpreted_parameters": vis_request.dict(),
523
- "data_preview": df.head().to_dict(orient='records')
524
- }
525
-
526
- except HTTPException:
527
- raise
528
- except Exception as e:
529
- logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
530
- raise HTTPException(500, detail=f"Visualization failed: {str(e)}")
531
-
532
- @app.post("/get_columns")
533
  @limiter.limit("10/minute")
534
- async def get_excel_columns(
535
  request: Request,
536
- file: UploadFile = File(...)
 
 
537
  ):
538
- try:
539
- file_ext, content = await process_uploaded_file(file)
540
- if file_ext not in {"xlsx", "xls"}:
541
- raise HTTPException(400, "Only Excel files are supported")
542
-
543
- df = pd.read_excel(io.BytesIO(content))
544
- return {
545
- "columns": list(df.columns),
546
- "sample_data": df.head().to_dict(orient='records'),
547
- "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
548
- }
549
- except Exception as e:
550
- logger.error(f"Column extraction failed: {str(e)}")
551
- raise HTTPException(500, detail="Failed to extract columns from Excel file")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
- @app.exception_handler(RateLimitExceeded)
554
- async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
555
- return JSONResponse(
556
- status_code=429,
557
- content={"detail": "Too many requests. Please try again later."}
558
- )
559
 
560
- if __name__ == "__main__":
561
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 
2
  from fastapi.responses import JSONResponse
 
 
 
 
 
 
 
 
 
 
 
 
3
  from slowapi import Limiter
4
  from slowapi.util import get_remote_address
5
  from slowapi.errors import RateLimitExceeded
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from starlette.requests import Request
 
 
 
 
 
 
 
8
 
9
+ import pytesseract
10
+ from PIL import Image
11
+ import fitz # PyMuPDF
12
+ import docx
13
+ import pptx
14
+ import pandas as pd
15
+ import io
16
 
17
+ from transformers import pipeline
18
+ import matplotlib.pyplot as plt
19
+ import seaborn as sns
20
+ import uuid
21
+ import os
22
 
23
  app = FastAPI()
24
 
25
+ # CORS (optional, for frontend access)
 
 
 
 
26
  app.add_middleware(
27
  CORSMiddleware,
28
  allow_origins=["*"],
 
30
  allow_headers=["*"],
31
  )
32
 
33
+ # Rate Limiting
34
+ limiter = Limiter(key_func=get_remote_address)
35
+ app.state.limiter = limiter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ @app.exception_handler(RateLimitExceeded)
38
+ async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
39
+ return JSONResponse(
40
+ status_code=429,
41
+ content={"error": "Rate limit exceeded. Please try again later."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
 
44
+ # Hugging Face Pipelines
45
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
46
+ qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
47
+ image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
48
+
49
+ # Utility: Save image and return path
50
+ def save_temp_image(upload: UploadFile):
51
+ image_path = f"temp/{uuid.uuid4().hex}_{upload.filename}"
52
+ with open(image_path, "wb") as f:
53
+ f.write(upload.file.read())
54
+ return image_path
55
+
56
+ # --- File Parsing Utilities ---
57
+ def extract_text_from_pdf(file_bytes: bytes) -> str:
58
+ doc = fitz.open(stream=file_bytes, filetype="pdf")
59
+ return "\n".join(page.get_text() for page in doc)
60
+
61
+ def extract_text_from_docx(file_bytes: bytes) -> str:
62
+ doc = docx.Document(io.BytesIO(file_bytes))
63
+ return "\n".join(p.text for p in doc.paragraphs)
64
+
65
+ def extract_text_from_pptx(file_bytes: bytes) -> str:
66
+ prs = pptx.Presentation(io.BytesIO(file_bytes))
67
+ text = ""
68
+ for slide in prs.slides:
69
+ for shape in slide.shapes:
70
+ if hasattr(shape, "text"):
71
+ text += shape.text + "\n"
72
+ return text
73
+
74
+ def extract_text_from_image(file_bytes: bytes) -> str:
75
+ img = Image.open(io.BytesIO(file_bytes))
76
+ return pytesseract.image_to_string(img)
77
+
78
+ def extract_data_from_excel(file_bytes: bytes) -> pd.DataFrame:
79
+ return pd.read_excel(io.BytesIO(file_bytes))
80
+
81
+ # --- API Endpoints ---
82
+ @app.post("/process/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  @limiter.limit("10/minute")
84
+ async def process_file(
85
  request: Request,
86
+ file: UploadFile = File(...),
87
+ task: str = Form(...),
88
+ question: str = Form(None)
89
  ):
90
+ content_type = file.content_type
91
+ file_bytes = await file.read()
92
+
93
+ # --- Task: Summarization or QA ---
94
+ if task in ["summarization", "question_answering"]:
95
+ if content_type == "application/pdf":
96
+ text = extract_text_from_pdf(file_bytes)
97
+ elif content_type in ["application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
98
+ text = extract_text_from_docx(file_bytes)
99
+ elif content_type in ["application/vnd.openxmlformats-officedocument.presentationml.presentation"]:
100
+ text = extract_text_from_pptx(file_bytes)
101
+ elif content_type in ["image/png", "image/jpeg"]:
102
+ text = extract_text_from_image(file_bytes)
103
+ else:
104
+ raise HTTPException(status_code=400, detail="Unsupported file format for this task.")
105
+
106
+ if task == "summarization":
107
+ summary = summarizer(text[:3000])[0]["summary_text"] # truncate long text
108
+ return {"summary": summary}
109
+
110
+ if task == "question_answering":
111
+ if not question:
112
+ raise HTTPException(status_code=400, detail="Question is required for QA.")
113
+ answer = qa_pipeline(question=question, context=text)
114
+ return {"answer": answer["answer"]}
115
+
116
+ # --- Task: Image Captioning ---
117
+ elif task == "captioning":
118
+ if content_type not in ["image/png", "image/jpeg"]:
119
+ raise HTTPException(status_code=400, detail="Only image files supported for captioning.")
120
+ image_path = save_temp_image(file)
121
+ caption = image_captioner(image_path)[0]["generated_text"]
122
+ os.remove(image_path)
123
+ return {"caption": caption}
124
+
125
+ # --- Task: Data Visualization ---
126
+ elif task == "visualization":
127
+ if content_type != "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
128
+ raise HTTPException(status_code=400, detail="Only Excel files supported for visualization.")
129
+ df = extract_data_from_excel(file_bytes)
130
+
131
+ if df.empty:
132
+ raise HTTPException(status_code=400, detail="No data found in Excel file.")
133
+
134
+ # Example visualization: correlation heatmap
135
+ numeric_df = df.select_dtypes(include="number")
136
+ if numeric_df.empty:
137
+ raise HTTPException(status_code=400, detail="No numeric data available for visualization.")
138
+
139
+ plt.figure(figsize=(10, 6))
140
+ sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm")
141
+ viz_path = f"temp/viz_{uuid.uuid4().hex}.png"
142
+ plt.savefig(viz_path)
143
+ plt.close()
144
+
145
+ with open(viz_path, "rb") as img_file:
146
+ img_bytes = img_file.read()
147
+ os.remove(viz_path)
148
+
149
+ return JSONResponse(content={"image_bytes": list(img_bytes)})
150
 
151
+ else:
152
+ raise HTTPException(status_code=400, detail="Unsupported task.")
 
 
 
 
153