chenguittiMaroua committed (verified)
Commit d69a8fd · 1 Parent(s): 9b8b8b8

Update main.py

Files changed (1):
  1. main.py +98 -64
main.py CHANGED
@@ -136,56 +136,70 @@ import torch
 
 # Model options (ordered by preference)
 QA_MODELS = [
-    "google/flan-t5-small",    # Lightweight default
-    "google/flan-t5-base",     # Medium option
-    "facebook/bart-large-cnn"  # Fallback option
+    {"name": "google/flan-t5-small", "max_length": 512},
+    {"name": "facebook/bart-large-cnn", "max_length": 1024}
 ]
 
-qa_model = None
-current_model_name = None
-
-def get_qa_model():
-    global qa_model, current_model_name
-
-    if qa_model is not None:
-        return qa_model
-
-    # Try each model in order until one works
-    for model_name in QA_MODELS:
-        try:
-            logger.info(f"Attempting to load model: {model_name}")
-
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name,
-                device_map="auto" if torch.cuda.is_available() else None,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-            )
-
-            qa_model = pipeline(
-                "text2text-generation",
-                model=model,
-                tokenizer=tokenizer,
-                device=0 if torch.cuda.is_available() else -1
-            )
-
-            current_model_name = model_name
-            logger.info(f"Successfully loaded model: {model_name}")
-            return qa_model
-
-        except Exception as e:
-            logger.warning(f"Failed to load {model_name}: {str(e)}")
-            continue
-
-    logger.error("All model loading attempts failed")
-    raise HTTPException(
-        status_code=500,
-        detail={
-            "error": "QA system initialization failed",
-            "tried_models": QA_MODELS,
-            "suggestion": "Check available memory or try smaller models"
-        }
-    )
+class QASystem:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.current_model = None
+        # Explicit torch device string: "cuda"/"cpu" is valid for tensor.to(),
+        # unlike the pipeline-style 0/-1 index.
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def load_model(self):
+        for model_info in QA_MODELS:
+            try:
+                logger.info(f"Loading model: {model_info['name']}")
+
+                self.tokenizer = AutoTokenizer.from_pretrained(model_info["name"])
+                self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                    model_info["name"],
+                    device_map="auto" if torch.cuda.is_available() else None,
+                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                )
+                self.current_model = model_info
+                logger.info(f"Successfully loaded {model_info['name']}")
+                return True
+
+            except Exception as e:
+                logger.warning(f"Failed to load {model_info['name']}: {str(e)}")
+                continue
+
+        logger.error("All model loading attempts failed")
+        return False
+
+    def generate_answer(self, question: str, context: Optional[str] = None):
+        try:
+            if context:
+                input_text = f"question: {question} context: {context[:2000]}"
+            else:
+                input_text = f"question: {question}"
+
+            inputs = self.tokenizer(
+                input_text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=self.current_model["max_length"]
+            ).to(self.device)
+
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=200,
+                num_beams=4,
+                early_stopping=True
+            )
+
+            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        except Exception as e:
+            logger.error(f"Generation failed: {str(e)}")
+            raise
+
+# Initialize QA system (the model itself is loaded lazily on first request)
+qa_system = QASystem()
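A minimal sketch of how the new QASystem is meant to be driven, for review purposes only (not part of the commit; the import path is hypothetical, and it assumes transformers, torch, and the logger from main.py are available):

    from main import QASystem  # hypothetical import path, for illustration

    qa = QASystem()
    if not qa.load_model():  # walks QA_MODELS in order; False if every load fails
        raise RuntimeError("no QA model could be loaded")

    # Plain question, no document context
    print(qa.generate_answer("What is FastAPI?"))

    # With document context (generate_answer truncates it to 2,000 chars)
    snippet = "FastAPI is a Python web framework built on Starlette and Pydantic."
    print(qa.generate_answer("What is FastAPI built on?", context=snippet))

Deferring load_model() to the first request keeps app startup fast and lets the endpoint report a structured error if no model fits in memory.
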
@@ -878,43 +892,63 @@ from typing import Optional
 
 @app.post("/qa")
 async def question_answering(
-    request: Request,
     question: str = Form(...),
     file: Optional[UploadFile] = File(None),
     language: str = Form("en")
 ):
+    # Initialize model if not loaded
+    if not qa_system.model:
+        if not qa_system.load_model():
+            raise HTTPException(
+                500,
+                detail={
+                    "error": "System initialization failed",
+                    "tried_models": [m["name"] for m in QA_MODELS],
+                    "suggestion": "Check logs for loading errors"
+                }
+            )
+
     try:
-        # Initialize model (with fallback)
-        try:
-            qa_pipeline = get_qa_model()
-        except Exception as e:
-            logger.critical(f"Model loading failed: {str(e)}")
-            raise HTTPException(500, "Could not initialize any QA model")
-
-        # Rest of your existing endpoint logic...
-        # [Keep all your existing file processing and QA code]
-
-        return {
-            "question": question,
-            "answer": result[0]["generated_text"],
-            "model_used": current_model_name,  # Add this to responses
-            "source": "document" if file else "general knowledge",
-            "language": language
-        }
+        # Process file if provided
+        context = None
+        if file:
+            try:
+                file_ext, content = await process_uploaded_file(file)
+                context = extract_text(content, file_ext)
+                context = re.sub(r'\s+', ' ', context).strip()[:3000]
+            except Exception as e:
+                logger.error(f"File processing failed: {str(e)}")
+                raise HTTPException(422, detail=f"File processing error: {str(e)}")
+
+        # Generate answer
+        try:
+            answer = qa_system.generate_answer(question, context)
+
+            return {
+                "question": question,
+                "answer": answer,
+                "model": qa_system.current_model["name"],
+                "source": "document" if context else "general",
+                "language": language
+            }
+
+        except Exception as e:
+            logger.error(f"Answer generation failed: {str(e)}")
+            raise HTTPException(
+                500,
+                detail={
+                    "error": "Answer generation failed",
+                    "model": qa_system.current_model["name"],
+                    "input_length": len(question) + (len(context) if context else 0),
+                    "suggestion": "Try simplifying your question or reducing document size"
+                }
+            )
 
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"QA processing failed: {str(e)}")
-        raise HTTPException(
-            500,
-            detail={
-                "error": "QA processing failed",
-                "model": current_model_name,
-                "input_question": question[:100] + "..." if question else None,
-                "file_type": file.filename.split('.')[-1] if file else None
-            }
-        )
+        logger.critical(f"Unexpected error: {str(e)}")
+        raise HTTPException(500, "Internal server error")
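Once the app is served (for example with uvicorn main:app on port 8000; an assumption, not something this commit configures), the reworked endpoint can be exercised with multipart form fields matching its Form/File parameters. A sketch using the third-party requests client:

    import requests  # assumed installed for this sketch

    # General-knowledge question, no file upload
    r = requests.post(
        "http://localhost:8000/qa",
        data={"question": "What is the capital of France?", "language": "en"},
    )
    print(r.json())  # keys: question, answer, model, source, language

    # Document-grounded question: the file is parsed server-side and its
    # extracted text is passed to QASystem.generate_answer as context
    with open("report.pdf", "rb") as f:  # hypothetical local file
        r = requests.post(
            "http://localhost:8000/qa",
            data={"question": "Summarize the key findings.", "language": "en"},
            files={"file": ("report.pdf", f, "application/pdf")},
        )
    print(r.json())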