Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -74,42 +74,260 @@ def get_qa_model():
|
|
74 |
#########################################################
|
75 |
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
|
|
|
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
text = ""
|
89 |
try:
|
90 |
if file_ext == "docx":
|
91 |
-
doc = Document(io.BytesIO(
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
96 |
elif file_ext == "pptx":
|
97 |
-
ppt = Presentation(io.BytesIO(
|
98 |
-
|
|
|
|
|
99 |
elif file_ext == "pdf":
|
100 |
-
pdf = fitz.open(stream=
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
except Exception as e:
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
########################################################
|
115 |
@app.get("/", response_class=HTMLResponse)
|
@@ -160,20 +378,6 @@ async def summarize_document(file: UploadFile = File(...)):
|
|
160 |
except Exception as e:
|
161 |
raise HTTPException(500, f"Error processing document: {str(e)}")
|
162 |
#################################################################
|
163 |
-
@app.post("/qa")
|
164 |
-
async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
|
165 |
-
content = await file.read()
|
166 |
-
file_ext = file.filename.split(".")[-1].lower()
|
167 |
-
extracted_text = extract_text_from_file(content, file_ext)
|
168 |
-
# Use a pipeline as a high-level helper
|
169 |
-
summarizer = get_model("google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", "summarization")
|
170 |
-
if len(extracted_text) > 2000:
|
171 |
-
extracted_text = summarizer(extracted_text[:2000], max_length=500, min_length=100, do_sample=False)[0]["summary_text"]
|
172 |
-
|
173 |
-
qa_model = get_model("distilbert-base-cased-distilled-squad", "question-answering")
|
174 |
-
answer = qa_model(question=question, context=extracted_text)
|
175 |
-
|
176 |
-
return {"question": question, "answer": answer["answer"], "context_used": extracted_text}
|
177 |
|
178 |
###############################################
|
179 |
|
|
|
74 |
#########################################################
|
75 |
|
76 |
|
77 |
+
# CORS Configuration
|
78 |
+
app.add_middleware(
|
79 |
+
CORSMiddleware,
|
80 |
+
allow_origins=["*"],
|
81 |
+
allow_credentials=True,
|
82 |
+
allow_methods=["*"],
|
83 |
+
allow_headers=["*"],
|
84 |
+
)
|
85 |
|
86 |
+
# Constants
|
87 |
+
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
88 |
+
MAX_TEXT_LENGTH = 2000
|
89 |
+
MAX_QUESTION_LENGTH = 500
|
90 |
+
MIN_QUESTION_LENGTH = 3
|
91 |
+
SUPPORTED_LANGUAGES = {"fr", "en", "es", "de"}
|
92 |
+
DEFAULT_LANGUAGE = "fr"
|
93 |
+
|
94 |
+
SUPPORTED_FILE_TYPES = {
|
95 |
+
"docx": "Word Document",
|
96 |
+
"xlsx": "Excel Spreadsheet",
|
97 |
+
"pptx": "PowerPoint Presentation",
|
98 |
+
"pdf": "PDF Document",
|
99 |
+
"jpg": "JPEG Image",
|
100 |
+
"jpeg": "JPEG Image",
|
101 |
+
"png": "PNG Image"
|
102 |
+
}
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
MODEL_MAPPING = {
|
109 |
+
"fr": {
|
110 |
+
"qa": "illuin/camembert-base-fquad",
|
111 |
+
"summarization": "moussaKam/barthez-orangesum-abstract",
|
112 |
+
"translation": "Helsinki-NLP/opus-mt-fr-en"
|
113 |
+
},
|
114 |
+
"en": {
|
115 |
+
"qa": "deepset/roberta-base-squad2",
|
116 |
+
"summarization": "facebook/bart-large-cnn",
|
117 |
+
"translation": "Helsinki-NLP/opus-mt-en-fr"
|
118 |
+
},
|
119 |
+
"default": {
|
120 |
+
"image_captioning": "Salesforce/blip-image-captioning-large",
|
121 |
+
"multilingual_translation": "facebook/nllb-200-distilled-600M"
|
122 |
+
}
|
123 |
+
}
|
124 |
+
|
125 |
+
# Models cache
|
126 |
+
models_cache: Dict[str, Pipeline] = {}
|
127 |
+
|
128 |
+
# Pydantic Models
|
129 |
+
class TranslationRequest(BaseModel):
|
130 |
+
text: constr(min_length=1, max_length=5000)
|
131 |
+
target_lang: constr(min_length=2, max_length=5)
|
132 |
+
src_lang: Optional[constr(min_length=2, max_length=5)] = None
|
133 |
+
|
134 |
+
@validator('target_lang', 'src_lang')
|
135 |
+
def validate_language_code(cls, v):
|
136 |
+
if v and len(v) not in {2, 5}:
|
137 |
+
raise ValueError("Language code must be 2 or 5 characters")
|
138 |
+
return v
|
139 |
+
|
140 |
+
class QARequest(BaseModel):
|
141 |
+
question: constr(min_length=MIN_QUESTION_LENGTH, max_length=MAX_QUESTION_LENGTH)
|
142 |
+
language: constr(min_length=2, max_length=2) = DEFAULT_LANGUAGE
|
143 |
+
|
144 |
+
@validator('language')
|
145 |
+
def validate_language(cls, v):
|
146 |
+
if v.lower() not in SUPPORTED_LANGUAGES:
|
147 |
+
raise ValueError(f"Unsupported language. Supported: {SUPPORTED_LANGUAGES}")
|
148 |
+
return v.lower()
|
149 |
+
|
150 |
+
class ErrorResponse(BaseModel):
|
151 |
+
error: str
|
152 |
+
success: bool = False
|
153 |
+
status_code: int
|
154 |
+
timestamp: str
|
155 |
+
details: Optional[dict] = None
|
156 |
+
|
157 |
+
# Exception Handler
|
158 |
+
@app.exception_handler(HTTPException)
|
159 |
+
async def http_exception_handler(request, exc):
|
160 |
+
error_response = ErrorResponse(
|
161 |
+
error=exc.detail,
|
162 |
+
status_code=exc.status_code,
|
163 |
+
timestamp=datetime.now().isoformat(),
|
164 |
+
details=getattr(exc, 'details', None)
|
165 |
+
)
|
166 |
+
return JSONResponse(
|
167 |
+
status_code=exc.status_code,
|
168 |
+
content=jsonable_encoder(error_response)
|
169 |
+
)
|
170 |
+
|
171 |
+
# Helper Functions
|
172 |
+
def get_model(model_name: str, task: str) -> Pipeline:
|
173 |
+
"""Get or load a Hugging Face model with caching."""
|
174 |
+
cache_key = f"{model_name}_{task}"
|
175 |
+
if cache_key not in models_cache:
|
176 |
+
try:
|
177 |
+
logger.info(f"Loading model: {model_name} for task: {task}")
|
178 |
+
models_cache[cache_key] = pipeline(task, model=model_name)
|
179 |
+
except Exception as e:
|
180 |
+
logger.error(f"Model loading failed: {str(e)}")
|
181 |
+
raise HTTPException(
|
182 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
183 |
+
detail="Model service unavailable",
|
184 |
+
details={"model": model_name, "error": str(e)}
|
185 |
+
)
|
186 |
+
return models_cache[cache_key]
|
187 |
+
|
188 |
+
async def validate_and_read_file(file: UploadFile) -> Tuple[str, bytes]:
|
189 |
+
"""Validate and read uploaded file."""
|
190 |
+
# Check file extension
|
191 |
+
file_ext = Path(file.filename).suffix[1:].lower()
|
192 |
+
if file_ext not in SUPPORTED_FILE_TYPES:
|
193 |
+
raise HTTPException(
|
194 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
195 |
+
detail=f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES.values())}"
|
196 |
+
)
|
197 |
+
|
198 |
+
# Read and check file size
|
199 |
+
content = await file.read()
|
200 |
+
if len(content) > MAX_FILE_SIZE:
|
201 |
+
raise HTTPException(
|
202 |
+
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
203 |
+
detail=f"File exceeds maximum size of {MAX_FILE_SIZE//1024//1024}MB"
|
204 |
+
)
|
205 |
|
206 |
+
await file.seek(0)
|
207 |
+
return file_ext, content
|
208 |
|
209 |
+
def extract_text(content: bytes, file_ext: str) -> str:
|
210 |
+
"""Extract text from various file formats."""
|
|
|
211 |
try:
|
212 |
if file_ext == "docx":
|
213 |
+
doc = Document(io.BytesIO(content))
|
214 |
+
return " ".join(p.text for p in doc.paragraphs if p.text.strip())
|
215 |
+
|
216 |
+
elif file_ext in {"xls", "xlsx"}:
|
217 |
+
df = pd.read_excel(io.BytesIO(content))
|
218 |
+
return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
|
219 |
+
|
220 |
elif file_ext == "pptx":
|
221 |
+
ppt = Presentation(io.BytesIO(content))
|
222 |
+
return " ".join(shape.text for slide in ppt.slides
|
223 |
+
for shape in slide.shapes if hasattr(shape, "text"))
|
224 |
+
|
225 |
elif file_ext == "pdf":
|
226 |
+
pdf = fitz.open(stream=content, filetype="pdf")
|
227 |
+
return " ".join(page.get_text("text") for page in pdf)
|
228 |
+
|
229 |
+
elif file_ext in {"jpg", "jpeg", "png"}:
|
230 |
+
image = Image.open(io.BytesIO(content))
|
231 |
+
return pytesseract.image_to_string(image, config='--psm 6')
|
232 |
+
|
233 |
except Exception as e:
|
234 |
+
logger.error(f"Text extraction failed: {str(e)}")
|
235 |
+
raise HTTPException(
|
236 |
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
237 |
+
detail="Failed to extract text from file",
|
238 |
+
details={"error": str(e), "file_type": file_ext}
|
239 |
+
)
|
240 |
+
|
241 |
+
def preprocess_text(text: str) -> str:
|
242 |
+
"""Clean and normalize extracted text."""
|
243 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
244 |
+
return text[:MAX_TEXT_LENGTH] if len(text) > MAX_TEXT_LENGTH else text
|
245 |
|
246 |
+
# API Endpoints
|
247 |
+
@app.post("/qa")
|
248 |
+
async def question_answering(
|
249 |
+
file: UploadFile = File(...),
|
250 |
+
question: str = Form(...),
|
251 |
+
language: str = Form(DEFAULT_LANGUAGE)
|
252 |
+
) -> JSONResponse:
|
253 |
+
try:
|
254 |
+
# Validation et extraction du texte
|
255 |
+
file_ext, content = await validate_and_read_file(file)
|
256 |
+
text = preprocess_text(extract_text(content, file_ext))
|
257 |
+
|
258 |
+
# Détection spéciale pour les questions sur le thème
|
259 |
+
theme_keywords = {
|
260 |
+
"fr": ["thème", "sujet principal", "quoi le sujet"],
|
261 |
+
"en": ["theme", "main topic", "what is about"]
|
262 |
+
}
|
263 |
+
|
264 |
+
is_theme_question = any(
|
265 |
+
kw in question.lower()
|
266 |
+
for kw in theme_keywords.get(language, theme_keywords["en"])
|
267 |
+
)
|
268 |
+
|
269 |
+
if is_theme_question:
|
270 |
+
# Utilisation d'un prompt spécialisé pour l'analyse thématique
|
271 |
+
theme_prompt = (
|
272 |
+
"Extrayez le thème principal de ce texte en 1-2 phrases. "
|
273 |
+
"Répondez comme si vous expliquiez à un novice. "
|
274 |
+
"Texte : {text}"
|
275 |
+
)
|
276 |
+
|
277 |
+
# Utilisation d'un LLM plus puissant pour l'analyse thématique
|
278 |
+
generator = get_model("moussaKam/barthez-orangesum-abstract", "text-generation")
|
279 |
+
response = generator(
|
280 |
+
theme_prompt.format(text=text[:2000]),
|
281 |
+
max_length=200,
|
282 |
+
num_return_sequences=1,
|
283 |
+
do_sample=False
|
284 |
+
)
|
285 |
+
|
286 |
+
# Nettoyage de la réponse
|
287 |
+
theme = response[0]["generated_text"].split(":")[-1].strip()
|
288 |
+
theme = re.sub(r"^(Le|La)\s+", "", theme) # Retire les articles en début de phrase
|
289 |
+
|
290 |
+
return JSONResponse({
|
291 |
+
"question": question,
|
292 |
+
"answer": f"Le document traite principalement de : {theme}",
|
293 |
+
"confidence": 0.95, # Haut confiance car méthode spécialisée
|
294 |
+
"language": language,
|
295 |
+
"processing_method": "theme_analysis",
|
296 |
+
"success": True
|
297 |
+
})
|
298 |
+
|
299 |
+
# ... reste du code pour les questions normales ...
|
300 |
+
|
301 |
+
# Standard QA processing
|
302 |
+
result = qa_model(question=request.question, context=clean_text)
|
303 |
+
|
304 |
+
if result["score"] < 0.1: # Low confidence threshold
|
305 |
+
return JSONResponse({
|
306 |
+
"question": request.question,
|
307 |
+
"answer": "No clear answer found in the document" if language == "en" else "Aucune réponse claire trouvée dans le document",
|
308 |
+
"confidence": result["score"],
|
309 |
+
"language": language,
|
310 |
+
"warning": "low_confidence",
|
311 |
+
"success": True
|
312 |
+
})
|
313 |
+
|
314 |
+
return JSONResponse({
|
315 |
+
"question": request.question,
|
316 |
+
"answer": result["answer"],
|
317 |
+
"confidence": result["score"],
|
318 |
+
"language": language,
|
319 |
+
"success": True
|
320 |
+
})
|
321 |
+
|
322 |
+
except HTTPException:
|
323 |
+
raise
|
324 |
+
except Exception as e:
|
325 |
+
logger.error(f"QA processing failed: {str(e)}")
|
326 |
+
raise HTTPException(
|
327 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
328 |
+
detail="Document analysis failed",
|
329 |
+
details={"error": str(e)}
|
330 |
+
)
|
331 |
|
332 |
########################################################
|
333 |
@app.get("/", response_class=HTMLResponse)
|
|
|
378 |
except Exception as e:
|
379 |
raise HTTPException(500, f"Error processing document: {str(e)}")
|
380 |
#################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
|
382 |
###############################################
|
383 |
|