chenguittiMaroua committed on
Commit
473762c
·
verified ·
1 Parent(s): 38bf145

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +245 -41
main.py CHANGED
@@ -74,42 +74,260 @@ def get_qa_model():
74
  #########################################################
75
 
76
 
77
# Loaded pipelines, keyed by "<model_name>_<task>".
models_cache: Dict[str, pipeline] = {}


def get_model(model_name: str, task: str):
    """Return a cached Hugging Face pipeline, loading it on first use.

    Args:
        model_name: Model identifier passed to ``pipeline(..., model=...)``.
        task: Task name passed as the first argument to ``pipeline``.

    Returns:
        The object returned by ``pipeline(task, model=model_name)``.
    """
    # Key on model AND task: keying on the model name alone would return a
    # pipeline built for a different task when a checkpoint is requested twice.
    cache_key = f"{model_name}_{task}"
    if cache_key not in models_cache:
        models_cache[cache_key] = pipeline(task, model=model_name)
    return models_cache[cache_key]
 
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
 
85
 
86
-
87
def extract_text_from_file(file_content: bytes, file_ext: str):
    """Extract plain text from an uploaded document or image.

    Args:
        file_content: Raw bytes of the uploaded file.
        file_ext: Lower-case extension without the dot (``docx``, ``xls``,
            ``xlsx``, ``pptx``, ``pdf``, ``jpg``, ``jpeg`` or ``png``).

    Returns:
        The extracted text.

    Raises:
        HTTPException: 400 if the extension is not supported or no text could
            be extracted; 500 if the underlying parser fails.
    """
    text = ""
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(file_content))
            text = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
        elif file_ext in ["xls", "xlsx"]:
            df = pd.read_excel(io.BytesIO(file_content))
            # Only the first column of the sheet returned by read_excel is used.
            text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(file_content))
            text = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
        elif file_ext == "pdf":
            pdf = fitz.open(stream=file_content, filetype="pdf")
            try:
                text = " ".join([page.get_text("text") for page in pdf])
            finally:
                # Release the document even if a page fails to render.
                pdf.close()
        elif file_ext in ["jpg", "jpeg", "png"]:
            image = Image.open(io.BytesIO(file_content))
            text = pytesseract.image_to_string(image, config='--psm 6')
        else:
            raise HTTPException(status_code=400, detail="Unsupported file format.")
    except HTTPException:
        # Keep the 400 raised above; without this clause the generic handler
        # below would turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error extracting text: {str(e)}") from e

    if not text.strip():
        raise HTTPException(status_code=400, detail="No extractable text found.")
    return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  ########################################################
115
  @app.get("/", response_class=HTMLResponse)
@@ -160,20 +378,6 @@ async def summarize_document(file: UploadFile = File(...)):
160
  except Exception as e:
161
  raise HTTPException(500, f"Error processing document: {str(e)}")
162
  #################################################################
163
@app.post("/qa")
async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
    """Answer a question about the text of an uploaded file.

    The file's text is extracted; if it is longer than 2000 characters, its
    first 2000 characters are summarized and the summary is used as context.

    Args:
        file: Uploaded document or image.
        question: Question to answer from the document.

    Returns:
        A dict with the question, the answer and the context that was used.
    """
    content = await file.read()
    # The filename may be missing on some uploads; an empty extension is then
    # rejected by extract_text_from_file as unsupported.
    file_ext = (file.filename or "").split(".")[-1].lower()
    extracted_text = extract_text_from_file(content, file_ext)

    if len(extracted_text) > 2000:
        # Load the summarizer only when needed, and use a summarization
        # checkpoint: an extractive SQuAD BERT model cannot serve the
        # "summarization" task.
        summarizer = get_model("facebook/bart-large-cnn", "summarization")
        extracted_text = summarizer(extracted_text[:2000], max_length=500, min_length=100, do_sample=False)[0]["summary_text"]

    qa_model = get_model("distilbert-base-cased-distilled-squad", "question-answering")
    answer = qa_model(question=question, context=extracted_text)

    return {"question": question, "answer": answer["answer"], "context_used": extracted_text}
177
 
178
  ###############################################
179
 
 
74
  #########################################################
75
 
76
 
77
# CORS configuration.
# Any origin may call the API, so credentialed requests are disabled: a
# wildcard origin must not be combined with allow_credentials=True. To support
# cookies or auth headers, list the allowed origins explicitly instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
85
 
86
# ---------------------------------------------------------------------------
# Limits and defaults
# ---------------------------------------------------------------------------
MAX_FILE_SIZE = 10 * 1024 ** 2  # upload size limit in bytes (10MB)
MAX_TEXT_LENGTH = 2000  # extracted text is cut to this many characters
MIN_QUESTION_LENGTH = 3
MAX_QUESTION_LENGTH = 500
DEFAULT_LANGUAGE = "fr"
SUPPORTED_LANGUAGES = {"fr", "en", "es", "de"}

# Accepted upload extensions (without the dot) -> human-readable label.
SUPPORTED_FILE_TYPES = dict(
    docx="Word Document",
    xlsx="Excel Spreadsheet",
    pptx="PowerPoint Presentation",
    pdf="PDF Document",
    jpg="JPEG Image",
    jpeg="JPEG Image",
    png="PNG Image",
)

# Hugging Face checkpoints per language code, plus language-independent
# entries under "default".
MODEL_MAPPING = {
    "fr": dict(
        qa="illuin/camembert-base-fquad",
        summarization="moussaKam/barthez-orangesum-abstract",
        translation="Helsinki-NLP/opus-mt-fr-en",
    ),
    "en": dict(
        qa="deepset/roberta-base-squad2",
        summarization="facebook/bart-large-cnn",
        translation="Helsinki-NLP/opus-mt-en-fr",
    ),
    "default": dict(
        image_captioning="Salesforce/blip-image-captioning-large",
        multilingual_translation="facebook/nllb-200-distilled-600M",
    ),
}
124
+
125
# Loaded pipelines, keyed by "<model_name>_<task>"; get_model() fills it on demand.
models_cache: Dict[str, Pipeline] = dict()
127
+
128
+ # Pydantic Models
129
class TranslationRequest(BaseModel):
    """Request body for a translation call.

    ``text`` is 1-5000 characters; the language codes are 2-5 characters by
    field constraint and further narrowed to exactly 2 or 5 by the validator.
    """

    text: constr(min_length=1, max_length=5000)
    target_lang: constr(min_length=2, max_length=5)
    src_lang: Optional[constr(min_length=2, max_length=5)] = None

    @validator('target_lang', 'src_lang')
    def validate_language_code(cls, v):
        """Accept empty values, and codes of exactly 2 or 5 characters."""
        if not v:
            return v
        if len(v) in {2, 5}:
            return v
        raise ValueError("Language code must be 2 or 5 characters")
139
+
140
class QARequest(BaseModel):
    """Question-answering request: a bounded question and a 2-letter language."""

    question: constr(min_length=MIN_QUESTION_LENGTH, max_length=MAX_QUESTION_LENGTH)
    language: constr(min_length=2, max_length=2) = DEFAULT_LANGUAGE

    @validator('language')
    def validate_language(cls, v):
        """Lower-case the code and reject anything outside SUPPORTED_LANGUAGES."""
        normalized = v.lower()
        if normalized in SUPPORTED_LANGUAGES:
            return normalized
        raise ValueError(f"Unsupported language. Supported: {SUPPORTED_LANGUAGES}")
149
+
150
class ErrorResponse(BaseModel):
    """JSON body returned by the HTTPException handler."""

    # Human-readable error message.
    error: str
    # Always False by default, so clients can branch on one flag.
    success: bool = False
    # HTTP status code, repeated in the body.
    status_code: int
    # Timestamp string supplied by the handler.
    timestamp: str
    # Optional structured context about the failure.
    details: Optional[dict] = None
156
+
157
+ # Exception Handler
158
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as an ``ErrorResponse`` JSON body.

    Args:
        request: The incoming request (unused).
        exc: The raised HTTPException. An optional ``details`` attribute on
            the exception, if present, is copied into the body.

    Returns:
        A JSONResponse with the exception's status code and headers.
    """
    detail = exc.detail
    error_response = ErrorResponse(
        # ``detail`` may be a dict or list; ErrorResponse.error must be a
        # string, so stringify it rather than fail validation in the handler.
        error=detail if isinstance(detail, str) else str(detail),
        status_code=exc.status_code,
        timestamp=datetime.now().isoformat(),
        details=getattr(exc, 'details', None)
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(error_response),
        # Forward headers set on the exception (e.g. WWW-Authenticate).
        headers=getattr(exc, 'headers', None)
    )
170
+
171
+ # Helper Functions
172
def get_model(model_name: str, task: str) -> Pipeline:
    """Get or load a Hugging Face pipeline, caching it per (model, task).

    Args:
        model_name: Model identifier passed to ``pipeline(..., model=...)``.
        task: Pipeline task name.

    Returns:
        The cached or newly loaded pipeline.

    Raises:
        HTTPException: 503 if the pipeline cannot be loaded. The model name
            and the original error text are attached as ``details``.
    """
    cache_key = f"{model_name}_{task}"
    if cache_key not in models_cache:
        try:
            logger.info(f"Loading model: {model_name} for task: {task}")
            models_cache[cache_key] = pipeline(task, model=model_name)
        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            # HTTPException takes no ``details`` keyword (passing one raises
            # TypeError), so attach the extra context as an attribute; the
            # exception handler reads it with getattr(exc, 'details', None).
            error = HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model service unavailable"
            )
            error.details = {"model": model_name, "error": str(e)}
            raise error from e
    return models_cache[cache_key]
187
+
188
async def validate_and_read_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an uploaded file's type and size, then return its content.

    Args:
        file: The uploaded file.

    Returns:
        ``(file_ext, content)``: the lower-case extension without the dot and
        the file's bytes. The file is rewound to the start before returning.

    Raises:
        HTTPException: 400 if the extension is missing or unsupported; 413 if
            the file is larger than ``MAX_FILE_SIZE``.
    """
    # The filename can be absent; treat that as "no extension".
    file_ext = Path(file.filename or "").suffix[1:].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        # dict.fromkeys drops repeated labels (jpg and jpeg share one) while
        # keeping their order.
        supported = ', '.join(dict.fromkeys(SUPPORTED_FILE_TYPES.values()))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported file type. Supported: {supported}"
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading all of it into memory.
    content = await file.read(MAX_FILE_SIZE + 1)
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File exceeds maximum size of {MAX_FILE_SIZE//1024//1024}MB"
        )

    await file.seek(0)
    return file_ext, content
208
 
209
def extract_text(content: bytes, file_ext: str) -> str:
    """Extract text from various file formats.

    Args:
        content: Raw bytes of the file.
        file_ext: Lower-case extension without the dot (``docx``, ``xls``,
            ``xlsx``, ``pptx``, ``pdf``, ``jpg``, ``jpeg`` or ``png``).

    Returns:
        The extracted text (possibly empty).

    Raises:
        HTTPException: 422 if the parser fails (the error text and file type
            are attached as ``details``); 400 if the extension is not handled.
    """
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return " ".join(p.text for p in doc.paragraphs if p.text.strip())

        if file_ext in {"xls", "xlsx"}:
            df = pd.read_excel(io.BytesIO(content))
            # Only the first column of the sheet returned by read_excel is used.
            return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())

        if file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            return " ".join(shape.text for slide in ppt.slides
                            for shape in slide.shapes if hasattr(shape, "text"))

        if file_ext == "pdf":
            pdf = fitz.open(stream=content, filetype="pdf")
            try:
                return " ".join(page.get_text("text") for page in pdf)
            finally:
                # Release the document even if a page fails to render.
                pdf.close()

        if file_ext in {"jpg", "jpeg", "png"}:
            image = Image.open(io.BytesIO(content))
            return pytesseract.image_to_string(image, config='--psm 6')

    except Exception as e:
        logger.error(f"Text extraction failed: {str(e)}")
        # HTTPException takes no ``details`` keyword (passing one raises
        # TypeError), so attach the extra context as an attribute instead.
        error = HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Failed to extract text from file"
        )
        error.details = {"error": str(e), "file_type": file_ext}
        raise error from e

    # No branch matched: fail explicitly instead of returning None, which the
    # declared ``str`` return type does not allow.
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Unsupported file type: {file_ext}"
    )
240
+
241
def preprocess_text(text: str) -> str:
    """Collapse whitespace runs to single spaces, trim, and cap the length.

    The result is at most ``MAX_TEXT_LENGTH`` characters long.
    """
    normalized = re.sub(r'\s+', ' ', text).strip()
    # Slicing leaves shorter strings unchanged, so no length check is needed.
    return normalized[:MAX_TEXT_LENGTH]
245
 
246
+ # API Endpoints
247
@app.post("/qa")
async def question_answering(
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form(DEFAULT_LANGUAGE)
) -> JSONResponse:
    """Answer a question about an uploaded document.

    Questions that contain a "theme" keyword are answered with a summary of
    the document; all other questions go through an extractive QA pipeline.

    Args:
        file: Uploaded document or image.
        question: The question to answer.
        language: Language code used to pick the theme keywords and QA model.
            Codes without a QA model fall back to the English one.

    Returns:
        A JSONResponse with the question, answer, confidence and language.

    Raises:
        HTTPException: propagated from file validation, text extraction and
            model loading; 422 if the file contains no text; 500 for any
            other failure (the error text is attached as ``details``).
    """
    try:
        language = language.strip().lower()

        # Validate the upload and extract its text.
        file_ext, content = await validate_and_read_file(file)
        text = preprocess_text(extract_text(content, file_ext))
        if not text:
            # Fail early with a clear error rather than passing an empty
            # context to a model.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No extractable text found in file"
            )

        # Special handling for questions about the document's theme.
        theme_keywords = {
            "fr": ["thème", "sujet principal", "quoi le sujet"],
            "en": ["theme", "main topic", "what is about"]
        }

        is_theme_question = any(
            kw in question.lower()
            for kw in theme_keywords.get(language, theme_keywords["en"])
        )

        if is_theme_question:
            # The checkpoint is a seq2seq summarizer, so load it under the
            # "summarization" task and read "summary_text". It does not follow
            # instruction prompts, so the text is passed on its own.
            summarizer = get_model(MODEL_MAPPING["fr"]["summarization"], "summarization")
            response = summarizer(
                text,
                max_length=200,
                do_sample=False,
                truncation=True
            )

            # Clean up the answer.
            theme = response[0]["summary_text"].strip()
            theme = re.sub(r"^(Le|La)\s+", "", theme)  # drop a leading article

            return JSONResponse({
                "question": question,
                "answer": f"Le document traite principalement de : {theme}",
                "confidence": 0.95,  # fixed value, not computed by the model
                "language": language,
                "processing_method": "theme_analysis",
                "success": True
            })

        # Standard QA processing. Use the language's QA checkpoint, falling
        # back to English (as the theme keywords do) when there is none.
        qa_checkpoint = MODEL_MAPPING.get(language, {}).get("qa") or MODEL_MAPPING["en"]["qa"]
        qa_model = get_model(qa_checkpoint, "question-answering")
        result = qa_model(question=question, context=text)
        # Ensure a plain float so the score is JSON-serializable.
        confidence = float(result["score"])

        if confidence < 0.1:  # Low confidence threshold
            return JSONResponse({
                "question": question,
                "answer": "No clear answer found in the document" if language == "en" else "Aucune réponse claire trouvée dans le document",
                "confidence": confidence,
                "language": language,
                "warning": "low_confidence",
                "success": True
            })

        return JSONResponse({
            "question": question,
            "answer": result["answer"],
            "confidence": confidence,
            "language": language,
            "success": True
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        # HTTPException takes no ``details`` keyword (passing one raises
        # TypeError), so attach the extra context as an attribute instead.
        error = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Document analysis failed"
        )
        error.details = {"error": str(e)}
        raise error from e
331
 
332
  ########################################################
333
  @app.get("/", response_class=HTMLResponse)
 
378
  except Exception as e:
379
  raise HTTPException(500, f"Error processing document: {str(e)}")
380
  #################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
  ###############################################
383