chenguittiMaroua committed on
Commit d728ee4 · verified · 1 Parent(s): 2001581

Update main.py

Files changed (1)
  1. main.py +78 -300
main.py CHANGED
@@ -1,12 +1,8 @@
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, status, Request
  from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, HTMLResponse
- from fastapi.encoders import jsonable_encoder
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from transformers import pipeline, Pipeline
- from typing import Dict, Optional, Tuple, List
- from pydantic import BaseModel, constr, validator
  import io
  import fitz # PyMuPDF
  from PIL import Image
@@ -16,161 +12,66 @@ from docx import Document
  from pptx import Presentation
  import pytesseract
  import logging
- import os
- from datetime import datetime
- from pathlib import Path
  import re
- import torch

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # Initialize FastAPI app
- app = FastAPI(
-     title="AI Document Analysis API",
-     description="Advanced document processing with multilingual support",
-     version="2.0.0",
-     docs_url="/docs",
-     redoc_url="/redoc"
- )

- # Configure CORS
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
-     allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

- # Set up templates
- templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
-
- # Serve static files
- app.mount("/static", StaticFiles(directory="static"), name="static")
-
  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
- MAX_TEXT_LENGTH = 2000
- MAX_QUESTION_LENGTH = 500
- MIN_QUESTION_LENGTH = 3
- SUPPORTED_LANGUAGES = {"fr", "en", "es", "de"}
- DEFAULT_LANGUAGE = "fr"
-
  SUPPORTED_FILE_TYPES = {
-     "docx": "Word Document",
-     "xlsx": "Excel Spreadsheet",
-     "pptx": "PowerPoint Presentation",
-     "pdf": "PDF Document",
-     "jpg": "JPEG Image",
-     "jpeg": "JPEG Image",
-     "png": "PNG Image"
- }
-
- MODEL_MAPPING = {
-     "fr": {
-         "qa": "illuin/camembert-base-fquad",
-         "summarization": "moussaKam/barthez-orangesum-abstract",
-         "translation": "Helsinki-NLP/opus-mt-fr-en"
-     },
-     "en": {
-         "qa": "deepset/roberta-base-squad2",
-         "summarization": "facebook/bart-large-cnn",
-         "translation": "Helsinki-NLP/opus-mt-en-fr"
-     },
-     "default": {
-         "image_captioning": "Salesforce/blip-image-captioning-large",
-         "multilingual_translation": "facebook/nllb-200-distilled-600M"
-     }
  }

- # Models cache
- models_cache: Dict[str, Pipeline] = {}
-
- # Pydantic Models
- class TranslationRequest(BaseModel):
-     text: constr(min_length=1, max_length=5000)
-     target_lang: constr(min_length=2, max_length=5)
-     src_lang: Optional[constr(min_length=2, max_length=5)] = None
-
-     @validator('target_lang', 'src_lang')
-     def validate_language_code(cls, v):
-         if v and len(v) not in {2, 5}:
-             raise ValueError("Language code must be 2 or 5 characters")
-         return v
-
- class QARequest(BaseModel):
-     question: constr(min_length=MIN_QUESTION_LENGTH, max_length=MAX_QUESTION_LENGTH)
-     language: constr(min_length=2, max_length=2) = DEFAULT_LANGUAGE
-
-     @validator('language')
-     def validate_language(cls, v):
-         if v.lower() not in SUPPORTED_LANGUAGES:
-             raise ValueError(f"Unsupported language. Supported: {SUPPORTED_LANGUAGES}")
-         return v.lower()
-
- class ErrorResponse(BaseModel):
-     error: str
-     success: bool = False
-     status_code: int
-     timestamp: str
-     details: Optional[dict] = None
-
- # Exception Handler
- @app.exception_handler(HTTPException)
- async def http_exception_handler(request: Request, exc: HTTPException):
-     error_response = ErrorResponse(
-         error=exc.detail,
-         status_code=exc.status_code,
-         timestamp=datetime.now().isoformat(),
-         details=getattr(exc, 'details', None)
-     )
-     return JSONResponse(
-         status_code=exc.status_code,
-         content=jsonable_encoder(error_response)
-     )
-
- # Helper Functions
- def get_model(model_name: str, task: str) -> Pipeline:
-     """Get or load a Hugging Face model with caching."""
-     cache_key = f"{model_name}_{task}"
-     if cache_key not in models_cache:
-         try:
-             logger.info(f"Loading model: {model_name} for task: {task}")
-             models_cache[cache_key] = pipeline(task, model=model_name)
-         except Exception as e:
-             logger.error(f"Model loading failed: {str(e)}")
-             raise HTTPException(
-                 status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                 detail="Model service unavailable",
-                 details={"model": model_name, "error": str(e)}
-             )
-     return models_cache[cache_key]
-
- async def validate_and_read_file(file: UploadFile) -> Tuple[str, bytes]:
-     """Validate and read uploaded file."""
-     # Check file extension
-     file_ext = Path(file.filename).suffix[1:].lower()
      if file_ext not in SUPPORTED_FILE_TYPES:
-         raise HTTPException(
-             status_code=status.HTTP_400_BAD_REQUEST,
-             detail=f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES.values())}"
-         )
-
-     # Read and check file size
      content = await file.read()
      if len(content) > MAX_FILE_SIZE:
-         raise HTTPException(
-             status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
-             detail=f"File exceeds maximum size of {MAX_FILE_SIZE//1024//1024}MB"
-         )
-
-     await file.seek(0)
      return file_ext, content

  def extract_text(content: bytes, file_ext: str) -> str:
-     """Extract text from various file formats."""
      try:
          if file_ext == "docx":
              doc = Document(io.BytesIO(content))
@@ -187,7 +88,12 @@ def extract_text(content: bytes, file_ext: str) -> str:

          elif file_ext == "pdf":
              pdf = fitz.open(stream=content, filetype="pdf")
-             return " ".join(page.get_text("text") for page in pdf)

          elif file_ext in {"jpg", "jpeg", "png"}:
              image = Image.open(io.BytesIO(content))
@@ -195,209 +101,81 @@ def extract_text(content: bytes, file_ext: str) -> str:

      except Exception as e:
          logger.error(f"Text extraction failed: {str(e)}")
-         raise HTTPException(
-             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-             detail="Failed to extract text from file",
-             details={"error": str(e), "file_type": file_ext}
-         )
-
- def preprocess_text(text: str) -> str:
-     """Clean and normalize extracted text."""
-     text = re.sub(r'\s+', ' ', text).strip()
-     return text[:MAX_TEXT_LENGTH] if len(text) > MAX_TEXT_LENGTH else text
-
- def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
-     """Split text into chunks for processing."""
-     return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
-
- # API Endpoints
- @app.get("/", response_class=HTMLResponse)
- async def home(request: Request):
-     return templates.TemplateResponse("index.html", {"request": request})
-
- @app.get("/health")
- async def health_check():
-     return {"status": "healthy", "timestamp": datetime.now().isoformat()}

  @app.post("/summarize")
  async def summarize_document(file: UploadFile = File(...)):
      try:
-         file_ext, content = await validate_and_read_file(file)
-         text = preprocess_text(extract_text(content, file_ext))

          if not text.strip():
-             raise HTTPException(
-                 status_code=status.HTTP_400_BAD_REQUEST,
-                 detail="No extractable text found in document"
-             )
-
-         model_name = MODEL_MAPPING.get("en", {}).get("summarization", "facebook/bart-large-cnn")
-         summarizer = get_model(model_name, "summarization")

-         chunks = chunk_text(text)
          summaries = []
          for chunk in chunks:
              summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
              summaries.append(summary)

-         return {
-             "success": True,
-             "summary": " ".join(summaries),
-             "language": "en",
-             "processed_chunks": len(chunks)
-         }
      except HTTPException:
          raise
      except Exception as e:
          logger.error(f"Summarization failed: {str(e)}")
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail="Document summarization failed",
-             details={"error": str(e)}
-         )

  @app.post("/qa")
  async def question_answering(
      file: UploadFile = File(...),
      question: str = Form(...),
-     language: str = Form(DEFAULT_LANGUAGE)
  ):
      try:
-         file_ext, content = await validate_and_read_file(file)
-         text = preprocess_text(extract_text(content, file_ext))

-         # Theme detection
-         theme_keywords = {
-             "fr": ["thème", "sujet principal", "quoi le sujet"],
-             "en": ["theme", "main topic", "what is about"]
-         }

-         is_theme_question = any(
-             kw in question.lower()
-             for kw in theme_keywords.get(language, theme_keywords["en"])
-         )

-         if is_theme_question:
-             model_name = MODEL_MAPPING.get(language, {}).get("summarization")
-             if not model_name:
-                 model_name = MODEL_MAPPING["default"].get("summarization")
-
-             generator = get_model(model_name, "text-generation")
-             theme_prompt = (
-                 "Extract the main theme of this text in 1-2 sentences. "
-                 "Respond as if explaining to a beginner. "
-                 "Text: {text}"
-             )
-
-             response = generator(
-                 theme_prompt.format(text=text[:2000]),
-                 max_length=200,
-                 num_return_sequences=1,
-                 do_sample=False
-             )
-
-             theme = response[0]["generated_text"].split(":")[-1].strip()
-             theme = re.sub(r"^(Le|La)\s+", "", theme)
-
              return {
                  "question": question,
-                 "answer": f"The document mainly discusses: {theme}",
                  "confidence": 0.95,
-                 "language": language,
-                 "processing_method": "theme_analysis",
-                 "success": True
              }

          # Standard QA processing
-         model_name = MODEL_MAPPING.get(language, {}).get("qa")
-         if not model_name:
-             model_name = MODEL_MAPPING["default"].get("qa")
-
-         qa_model = get_model(model_name, "question-answering")
-         result = qa_model(question=question, context=text)
-
-         if result["score"] < 0.1:
-             return {
-                 "question": question,
-                 "answer": "No clear answer found in the document" if language == "en" else "Aucune réponse claire trouvée dans le document",
-                 "confidence": result["score"],
-                 "language": language,
-                 "warning": "low_confidence",
-                 "success": True
-             }

          return {
              "question": question,
              "answer": result["answer"],
              "confidence": result["score"],
-             "language": language,
-             "success": True
          }

      except HTTPException:
          raise
      except Exception as e:
          logger.error(f"QA processing failed: {str(e)}")
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail="Document analysis failed",
-             details={"error": str(e)}
-         )
-
- @app.post("/api/caption")
- async def caption_image(file: UploadFile = File(...)):
-     try:
-         file_ext, content = await validate_and_read_file(file)
-         if file_ext not in {"jpg", "jpeg", "png"}:
-             raise HTTPException(
-                 status_code=status.HTTP_400_BAD_REQUEST,
-                 detail="Only image files are supported for captioning"
-             )
-
-         image = Image.open(io.BytesIO(content)).convert("RGB")
-         captioner = get_model(MODEL_MAPPING["default"]["image_captioning"], "image-to-text")
-         caption = captioner(image)[0]['generated_text']
-
-         return {
-             "success": True,
-             "caption": caption,
-             "file_type": file_ext
-         }
-     except HTTPException:
-         raise
-     except Exception as e:
-         logger.error(f"Image captioning failed: {str(e)}")
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail="Image captioning failed",
-             details={"error": str(e)}
-         )
-
- @app.post("/translate")
- async def translate_text(
-     text: str = Form(...),
-     target_lang: str = Form(...),
-     src_lang: str = Form("eng_Latn")
- ):
-     try:
-         translator = get_model(MODEL_MAPPING["default"]["multilingual_translation"], "translation")
-         translated = translator(text, src_lang=src_lang, tgt_lang=target_lang)
-
-         return {
-             "success": True,
-             "translated_text": translated[0]["translation_text"],
-             "source_language": src_lang,
-             "target_language": target_lang
-         }
-     except Exception as e:
-         logger.error(f"Translation failed: {str(e)}")
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail="Text translation failed",
-             details={"error": str(e)}
-         )

- # Run the application
  if __name__ == "__main__":
-     port = int(os.environ.get("PORT", 7860))
-     uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
 
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from transformers import pipeline
+ from typing import Optional
  import io
  import fitz # PyMuPDF
  from PIL import Image
  from pptx import Presentation
  import pytesseract
  import logging
  import re

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ app = FastAPI()

+ # CORS Configuration
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Constants
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
  SUPPORTED_FILE_TYPES = {
+     "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
  }

+ # Model caching
+ summarizer = None
+ qa_model = None
+ image_captioner = None
+
+ def get_summarizer():
+     global summarizer
+     if summarizer is None:
+         summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+     return summarizer
+
+ def get_qa_model():
+     global qa_model
+     if qa_model is None:
+         qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
+     return qa_model
+
+ def get_image_captioner():
+     global image_captioner
+     if image_captioner is None:
+         image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
+     return image_captioner
+
+ async def process_uploaded_file(file: UploadFile):
+     if not file.filename:
+         raise HTTPException(400, "No file provided")
+
+     file_ext = file.filename.split('.')[-1].lower()
      if file_ext not in SUPPORTED_FILE_TYPES:
+         raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
+
      content = await file.read()
      if len(content) > MAX_FILE_SIZE:
+         raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
+
      return file_ext, content

  def extract_text(content: bytes, file_ext: str) -> str:
      try:
          if file_ext == "docx":
              doc = Document(io.BytesIO(content))

          elif file_ext == "pdf":
              pdf = fitz.open(stream=content, filetype="pdf")
+             text = []
+             for page in pdf:
+                 page_text = page.get_text("text")
+                 if page_text.strip():
+                     text.append(page_text)
+             return " ".join(text)

          elif file_ext in {"jpg", "jpeg", "png"}:
              image = Image.open(io.BytesIO(content))

      except Exception as e:
          logger.error(f"Text extraction failed: {str(e)}")
+         raise HTTPException(422, f"Failed to extract text from {file_ext} file")

  @app.post("/summarize")
  async def summarize_document(file: UploadFile = File(...)):
      try:
+         file_ext, content = await process_uploaded_file(file)
+         text = extract_text(content, file_ext)

          if not text.strip():
+             raise HTTPException(400, "No extractable text found")

+         # Clean and chunk text
+         text = re.sub(r'\s+', ' ', text).strip()
+         chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
+
+         # Summarize each chunk
+         summarizer = get_summarizer()
          summaries = []
          for chunk in chunks:
              summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
              summaries.append(summary)

+         return {"summary": " ".join(summaries)}
+
      except HTTPException:
          raise
      except Exception as e:
          logger.error(f"Summarization failed: {str(e)}")
+         raise HTTPException(500, "Document summarization failed")

  @app.post("/qa")
  async def question_answering(
      file: UploadFile = File(...),
      question: str = Form(...),
+     language: str = Form("fr")
  ):
      try:
+         file_ext, content = await process_uploaded_file(file)
+         text = extract_text(content, file_ext)

+         if not text.strip():
+             raise HTTPException(400, "No extractable text found")

+         # Clean text
+         text = re.sub(r'\s+', ' ', text).strip()

+         # Handle theme questions
+         theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
+         if any(kw in question.lower() for kw in theme_keywords):
+             # Use summarization for theme detection
+             summarizer = get_summarizer()
+             theme = summarizer(text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
              return {
                  "question": question,
+                 "answer": f"Le document traite principalement de : {theme}",
                  "confidence": 0.95,
+                 "language": language
              }

          # Standard QA processing
+         qa = get_qa_model()
+         result = qa(question=question, context=text)

          return {
              "question": question,
              "answer": result["answer"],
              "confidence": result["score"],
+             "language": language
          }

      except HTTPException:
          raise
      except Exception as e:
          logger.error(f"QA processing failed: {str(e)}")
+         raise HTTPException(500, "Document analysis failed")

  if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
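
The committed main.py now exposes two multipart endpoints, /summarize and /qa, served by uvicorn on port 7860. A minimal client sketch with the requests library follows; the base URL and the "report.pdf" file name are placeholders for illustration, not part of the commit.

# Hypothetical client for the committed endpoints; adjust host, port and file name as needed.
import requests

BASE_URL = "http://localhost:7860"  # assumes the uvicorn defaults from the __main__ block

# Summarize a document: the endpoint expects a multipart "file" field.
with open("report.pdf", "rb") as f:  # placeholder document
    resp = requests.post(f"{BASE_URL}/summarize", files={"file": ("report.pdf", f, "application/pdf")})
print(resp.json()["summary"])

# Ask a question about the same document: "question" and "language" are Form fields.
with open("report.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/qa",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"question": "Quel est le sujet principal ?", "language": "fr"},
    )
answer = resp.json()
print(answer["answer"], answer["confidence"])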