chenguittiMaroua committed on
Commit
2001581
·
verified ·
1 Parent(s): b649976

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +132 -177
main.py CHANGED
@@ -1,32 +1,11 @@
1
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.staticfiles import StaticFiles
4
- from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
5
- from sentence_transformers import SentenceTransformer, util
6
- from typing import Optional
7
- import io
8
- import fitz # PyMuPDF
9
- from PIL import Image
10
- import pandas as pd
11
- import uvicorn
12
- from functools import lru_cache
13
- from docx import Document
14
- from pptx import Presentation
15
- import pytesseract
16
- import torch
17
- from typing import Dict
18
- from transformers import Pipeline
19
-
20
-
21
-
22
-
23
-
24
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, status
25
- from fastapi.middleware.cors import CORSMiddleware
26
- from fastapi.responses import JSONResponse
27
  from fastapi.encoders import jsonable_encoder
 
 
28
  from transformers import pipeline, Pipeline
29
- from typing import Dict, Optional, Tuple
30
  from pydantic import BaseModel, constr, validator
31
  import io
32
  import fitz # PyMuPDF
@@ -41,74 +20,13 @@ import os
41
  from datetime import datetime
42
  from pathlib import Path
43
  import re
44
-
45
-
46
-
47
-
48
-
49
-
50
-
51
- from fastapi.responses import HTMLResponse
52
- from fastapi.templating import Jinja2Templates
53
- from fastapi import Request
54
- from pathlib import Path
55
-
56
-
57
-
58
- import os
59
- print(os.getcwd()) # This prints the current working directory
60
-
61
-
62
-
63
-
64
-
65
-
66
-
67
-
68
- # Initialize FastAPI app
69
- app = FastAPI()
70
- print(os.getcwd())
71
- templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
72
- # Configure CORS
73
- app.add_middleware(
74
- CORSMiddleware,
75
- allow_origins=[
76
- "https://*.hf.space",
77
- "http://localhost",
78
- "http://localhost:8000"
79
- ],
80
- allow_credentials=True,
81
- allow_methods=["*"],
82
- allow_headers=["*"],
83
- )
84
-
85
- # Serve static files (frontend)
86
- app.mount("/static", StaticFiles(directory="static"), name="static")
87
-
88
- # Model loading with caching
89
- @lru_cache()
90
- def get_summarizer():
91
- return pipeline("summarization", model="facebook/bart-large-cnn")
92
-
93
- @lru_cache()
94
- def get_image_captioning():
95
- return pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
96
-
97
- @lru_cache()
98
- def get_translator():
99
- return pipeline("translation", model="facebook/nllb-200-distilled-600M")
100
- @lru_cache()
101
- def get_qa_model():
102
- return pipeline("question-answering", model="deepset/roberta-base-squad2")
103
-
104
-
105
- #########################################################
106
-
107
 
108
  # Configure logging
109
  logging.basicConfig(level=logging.INFO)
110
  logger = logging.getLogger(__name__)
111
 
 
112
  app = FastAPI(
113
  title="AI Document Analysis API",
114
  description="Advanced document processing with multilingual support",
@@ -117,7 +35,7 @@ app = FastAPI(
117
  redoc_url="/redoc"
118
  )
119
 
120
- # CORS Configuration
121
  app.add_middleware(
122
  CORSMiddleware,
123
  allow_origins=["*"],
@@ -126,6 +44,12 @@ app.add_middleware(
126
  allow_headers=["*"],
127
  )
128
 
 
 
 
 
 
 
129
  # Constants
130
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
131
  MAX_TEXT_LENGTH = 2000
@@ -195,7 +119,7 @@ class ErrorResponse(BaseModel):
195
 
196
  # Exception Handler
197
  @app.exception_handler(HTTPException)
198
- async def http_exception_handler(request, exc):
199
  error_response = ErrorResponse(
200
  error=exc.detail,
201
  status_code=exc.status_code,
@@ -282,18 +206,67 @@ def preprocess_text(text: str) -> str:
282
  text = re.sub(r'\s+', ' ', text).strip()
283
  return text[:MAX_TEXT_LENGTH] if len(text) > MAX_TEXT_LENGTH else text
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  @app.post("/qa")
286
  async def question_answering(
287
  file: UploadFile = File(...),
288
  question: str = Form(...),
289
  language: str = Form(DEFAULT_LANGUAGE)
290
- ) -> JSONResponse:
291
  try:
292
- # Validation et extraction du texte
293
  file_ext, content = await validate_and_read_file(file)
294
  text = preprocess_text(extract_text(content, file_ext))
295
 
296
- # Détection spéciale pour les questions sur le thème
297
  theme_keywords = {
298
  "fr": ["thème", "sujet principal", "quoi le sujet"],
299
  "en": ["theme", "main topic", "what is about"]
@@ -305,15 +278,17 @@ async def question_answering(
305
  )
306
 
307
  if is_theme_question:
308
- # Utilisation d'un prompt spécialisé pour l'analyse thématique
 
 
 
 
309
  theme_prompt = (
310
- "Extrayez le thème principal de ce texte en 1-2 phrases. "
311
- "Répondez comme si vous expliquiez à un novice. "
312
- "Texte : {text}"
313
  )
314
 
315
- # Utilisation d'un LLM plus puissant pour l'analyse thématique
316
- generator = get_model("moussaKam/barthez-orangesum-abstract", "text-generation")
317
  response = generator(
318
  theme_prompt.format(text=text[:2000]),
319
  max_length=200,
@@ -321,43 +296,43 @@ async def question_answering(
321
  do_sample=False
322
  )
323
 
324
- # Nettoyage de la réponse
325
  theme = response[0]["generated_text"].split(":")[-1].strip()
326
- theme = re.sub(r"^(Le|La)\s+", "", theme) # Retire les articles en début de phrase
327
 
328
- return JSONResponse({
329
  "question": question,
330
- "answer": f"Le document traite principalement de : {theme}",
331
- "confidence": 0.95, # Haut confiance car méthode spécialisée
332
  "language": language,
333
  "processing_method": "theme_analysis",
334
  "success": True
335
- })
336
-
337
- # ... reste du code pour les questions normales ...
338
-
339
- # ... reste du code pour les questions normales ...
340
 
341
  # Standard QA processing
342
- result = qa_model(question=request.question, context=clean_text)
 
 
343
 
344
- if result["score"] < 0.1: # Low confidence threshold
345
- return JSONResponse({
346
- "question": request.question,
 
 
 
347
  "answer": "No clear answer found in the document" if language == "en" else "Aucune réponse claire trouvée dans le document",
348
  "confidence": result["score"],
349
  "language": language,
350
  "warning": "low_confidence",
351
  "success": True
352
- })
353
 
354
- return JSONResponse({
355
- "question": request.question,
356
  "answer": result["answer"],
357
  "confidence": result["score"],
358
  "language": language,
359
  "success": True
360
- })
361
 
362
  except HTTPException:
363
  raise
@@ -369,66 +344,34 @@ async def question_answering(
369
  details={"error": str(e)}
370
  )
371
 
372
- ########################################################
373
- @app.get("/", response_class=HTMLResponse)
374
- def home ():
375
- with open("static/indexAI.html","r") as file :
376
- return file.read()
377
- # API Endpoints
378
- @app.get("/health")
379
- async def health_check():
380
- return {"status": "healthy"}
381
-
382
- @app.post("/summarize")
383
- async def summarize_document(file: UploadFile = File(...)):
384
- try:
385
- content = await file.read()
386
- file_ext = file.filename.split(".")[-1].lower()
387
- text = ""
388
-
389
- if file_ext == "docx":
390
- doc = Document(io.BytesIO(content))
391
- text = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
392
- elif file_ext in ["xls", "xlsx"]:
393
- df = pd.read_excel(io.BytesIO(content))
394
- text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
395
- elif file_ext == "pptx":
396
- ppt = Presentation(io.BytesIO(content))
397
- text = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
398
- elif file_ext == "pdf":
399
- pdf = fitz.open(stream=content, filetype="pdf")
400
- text = " ".join([page.get_text("text") for page in pdf])
401
- elif file_ext in ["jpg", "jpeg", "png"]:
402
- image = Image.open(io.BytesIO(content))
403
- text = get_image_captioning()(image)[0]['generated_text']
404
- else:
405
- raise HTTPException(400, "Unsupported file format")
406
-
407
- if not text.strip():
408
- raise HTTPException(400, "No extractable text found")
409
-
410
- summarizer = get_summarizer()
411
- chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]
412
- summary = " ".join([
413
- summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
414
- for chunk in chunks
415
- ])
416
-
417
- return {"summary": summary}
418
- except Exception as e:
419
- raise HTTPException(500, f"Error processing document: {str(e)}")
420
- #################################################################
421
-
422
- ###############################################
423
-
424
  @app.post("/api/caption")
425
  async def caption_image(file: UploadFile = File(...)):
426
  try:
427
- image = Image.open(io.BytesIO(await file.read())).convert("RGB")
428
- caption = get_image_captioning()(image)[0]['generated_text']
429
- return {"caption": caption}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  except Exception as e:
431
- raise HTTPException(500, f"Error processing image: {str(e)}")
 
 
 
 
 
432
 
433
  @app.post("/translate")
434
  async def translate_text(
@@ -437,10 +380,22 @@ async def translate_text(
437
  src_lang: str = Form("eng_Latn")
438
  ):
439
  try:
440
- translated = get_translator()(text, src_lang=src_lang, tgt_lang=target_lang)
441
- return {"translated_text": translated[0]["translation_text"]}
 
 
 
 
 
 
 
442
  except Exception as e:
443
- raise HTTPException(500, f"Error translating text: {str(e)}")
 
 
 
 
 
444
 
445
  # Run the application
446
  if __name__ == "__main__":
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, status, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, HTMLResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from fastapi.encoders import jsonable_encoder
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.templating import Jinja2Templates
7
  from transformers import pipeline, Pipeline
8
+ from typing import Dict, Optional, Tuple, List
9
  from pydantic import BaseModel, constr, validator
10
  import io
11
  import fitz # PyMuPDF
 
20
  from datetime import datetime
21
  from pathlib import Path
22
  import re
23
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Configure logging
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
29
+ # Initialize FastAPI app
30
  app = FastAPI(
31
  title="AI Document Analysis API",
32
  description="Advanced document processing with multilingual support",
 
35
  redoc_url="/redoc"
36
  )
37
 
38
+ # Configure CORS
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
 
44
  allow_headers=["*"],
45
  )
46
 
47
+ # Set up templates
48
+ templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
49
+
50
+ # Serve static files
51
+ app.mount("/static", StaticFiles(directory="static"), name="static")
52
+
53
  # Constants
54
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
55
  MAX_TEXT_LENGTH = 2000
 
119
 
120
  # Exception Handler
121
  @app.exception_handler(HTTPException)
122
+ async def http_exception_handler(request: Request, exc: HTTPException):
123
  error_response = ErrorResponse(
124
  error=exc.detail,
125
  status_code=exc.status_code,
 
206
  text = re.sub(r'\s+', ' ', text).strip()
207
  return text[:MAX_TEXT_LENGTH] if len(text) > MAX_TEXT_LENGTH else text
208
 
209
+ def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
210
+ """Split text into chunks for processing."""
211
+ return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
212
+
213
+ # API Endpoints
214
+ @app.get("/", response_class=HTMLResponse)
215
+ async def home(request: Request):
216
+ return templates.TemplateResponse("index.html", {"request": request})
217
+
218
+ @app.get("/health")
219
+ async def health_check():
220
+ return {"status": "healthy", "timestamp": datetime.now().isoformat()}
221
+
222
+ @app.post("/summarize")
223
+ async def summarize_document(file: UploadFile = File(...)):
224
+ try:
225
+ file_ext, content = await validate_and_read_file(file)
226
+ text = preprocess_text(extract_text(content, file_ext))
227
+
228
+ if not text.strip():
229
+ raise HTTPException(
230
+ status_code=status.HTTP_400_BAD_REQUEST,
231
+ detail="No extractable text found in document"
232
+ )
233
+
234
+ model_name = MODEL_MAPPING.get("en", {}).get("summarization", "facebook/bart-large-cnn")
235
+ summarizer = get_model(model_name, "summarization")
236
+
237
+ chunks = chunk_text(text)
238
+ summaries = []
239
+ for chunk in chunks:
240
+ summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
241
+ summaries.append(summary)
242
+
243
+ return {
244
+ "success": True,
245
+ "summary": " ".join(summaries),
246
+ "language": "en",
247
+ "processed_chunks": len(chunks)
248
+ }
249
+ except HTTPException:
250
+ raise
251
+ except Exception as e:
252
+ logger.error(f"Summarization failed: {str(e)}")
253
+ raise HTTPException(
254
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
255
+ detail="Document summarization failed",
256
+ details={"error": str(e)}
257
+ )
258
+
259
  @app.post("/qa")
260
  async def question_answering(
261
  file: UploadFile = File(...),
262
  question: str = Form(...),
263
  language: str = Form(DEFAULT_LANGUAGE)
264
+ ):
265
  try:
 
266
  file_ext, content = await validate_and_read_file(file)
267
  text = preprocess_text(extract_text(content, file_ext))
268
 
269
+ # Theme detection
270
  theme_keywords = {
271
  "fr": ["thème", "sujet principal", "quoi le sujet"],
272
  "en": ["theme", "main topic", "what is about"]
 
278
  )
279
 
280
  if is_theme_question:
281
+ model_name = MODEL_MAPPING.get(language, {}).get("summarization")
282
+ if not model_name:
283
+ model_name = MODEL_MAPPING["default"].get("summarization")
284
+
285
+ generator = get_model(model_name, "text-generation")
286
  theme_prompt = (
287
+ "Extract the main theme of this text in 1-2 sentences. "
288
+ "Respond as if explaining to a beginner. "
289
+ "Text: {text}"
290
  )
291
 
 
 
292
  response = generator(
293
  theme_prompt.format(text=text[:2000]),
294
  max_length=200,
 
296
  do_sample=False
297
  )
298
 
 
299
  theme = response[0]["generated_text"].split(":")[-1].strip()
300
+ theme = re.sub(r"^(Le|La)\s+", "", theme)
301
 
302
+ return {
303
  "question": question,
304
+ "answer": f"The document mainly discusses: {theme}",
305
+ "confidence": 0.95,
306
  "language": language,
307
  "processing_method": "theme_analysis",
308
  "success": True
309
+ }
 
 
 
 
310
 
311
  # Standard QA processing
312
+ model_name = MODEL_MAPPING.get(language, {}).get("qa")
313
+ if not model_name:
314
+ model_name = MODEL_MAPPING["default"].get("qa")
315
 
316
+ qa_model = get_model(model_name, "question-answering")
317
+ result = qa_model(question=question, context=text)
318
+
319
+ if result["score"] < 0.1:
320
+ return {
321
+ "question": question,
322
  "answer": "No clear answer found in the document" if language == "en" else "Aucune réponse claire trouvée dans le document",
323
  "confidence": result["score"],
324
  "language": language,
325
  "warning": "low_confidence",
326
  "success": True
327
+ }
328
 
329
+ return {
330
+ "question": question,
331
  "answer": result["answer"],
332
  "confidence": result["score"],
333
  "language": language,
334
  "success": True
335
+ }
336
 
337
  except HTTPException:
338
  raise
 
344
  details={"error": str(e)}
345
  )
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  @app.post("/api/caption")
348
  async def caption_image(file: UploadFile = File(...)):
349
  try:
350
+ file_ext, content = await validate_and_read_file(file)
351
+ if file_ext not in {"jpg", "jpeg", "png"}:
352
+ raise HTTPException(
353
+ status_code=status.HTTP_400_BAD_REQUEST,
354
+ detail="Only image files are supported for captioning"
355
+ )
356
+
357
+ image = Image.open(io.BytesIO(content)).convert("RGB")
358
+ captioner = get_model(MODEL_MAPPING["default"]["image_captioning"], "image-to-text")
359
+ caption = captioner(image)[0]['generated_text']
360
+
361
+ return {
362
+ "success": True,
363
+ "caption": caption,
364
+ "file_type": file_ext
365
+ }
366
+ except HTTPException:
367
+ raise
368
  except Exception as e:
369
+ logger.error(f"Image captioning failed: {str(e)}")
370
+ raise HTTPException(
371
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
372
+ detail="Image captioning failed",
373
+ details={"error": str(e)}
374
+ )
375
 
376
  @app.post("/translate")
377
  async def translate_text(
 
380
  src_lang: str = Form("eng_Latn")
381
  ):
382
  try:
383
+ translator = get_model(MODEL_MAPPING["default"]["multilingual_translation"], "translation")
384
+ translated = translator(text, src_lang=src_lang, tgt_lang=target_lang)
385
+
386
+ return {
387
+ "success": True,
388
+ "translated_text": translated[0]["translation_text"],
389
+ "source_language": src_lang,
390
+ "target_language": target_lang
391
+ }
392
  except Exception as e:
393
+ logger.error(f"Translation failed: {str(e)}")
394
+ raise HTTPException(
395
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
396
+ detail="Text translation failed",
397
+ details={"error": str(e)}
398
+ )
399
 
400
  # Run the application
401
  if __name__ == "__main__":