DawnC committed on
Commit
c4c78dc
·
verified ·
1 Parent(s): c3fc925

Update semantic_breed_recommender.py

Browse files
Files changed (1) hide show
  1. semantic_breed_recommender.py +45 -14
semantic_breed_recommender.py CHANGED
@@ -37,6 +37,7 @@ class SemanticBreedRecommender:
37
  """Initialize the semantic recommender"""
38
  self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
39
  self.sbert_model = None
 
40
  self.breed_vectors = {}
41
  self.breed_list = self._get_breed_list()
42
  self.comparative_keywords = {
@@ -44,13 +45,9 @@ class SemanticBreedRecommender:
44
  'then': 0.7, 'second': 0.7, 'followed': 0.6,
45
  'third': 0.5, 'least': 0.3, 'dislike': 0.2
46
  }
47
- # self.query_engine = QueryUnderstandingEngine()
48
- # self.constraint_manager = ConstraintManager()
49
- # self.multi_head_scorer = None # Will be initialized with SBERT model
50
- # self.score_calibrator = ScoreCalibrator()
51
- # self.config_manager = get_config_manager()
52
- self._initialize_model()
53
- self._build_breed_vectors()
54
 
55
  # Initialize multi-head scorer with SBERT model if enhanced mode is enabled
56
  # if self.sbert_model:
@@ -74,18 +71,24 @@ class SemanticBreedRecommender:
74
  'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
75
 
76
  def _initialize_model(self):
77
- """Initialize SBERT model with fallback"""
 
 
 
78
  try:
79
- print("Loading SBERT model...")
80
  # Try different model names if the primary one fails
81
  model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
82
 
83
  for model_name in model_options:
84
  try:
85
- self.sbert_model = SentenceTransformer(model_name)
 
 
 
86
  self.model_name = model_name
87
- print(f"SBERT model {model_name} loaded successfully")
88
- return
89
  except Exception as model_e:
90
  print(f"Failed to load {model_name}: {str(model_e)}")
91
  continue
@@ -93,12 +96,16 @@ class SemanticBreedRecommender:
93
  # If all models fail
94
  print("All SBERT models failed to load. Using basic text matching fallback.")
95
  self.sbert_model = None
 
96
 
97
  except Exception as e:
98
  print(f"Failed to initialize any SBERT model: {str(e)}")
99
  print(traceback.format_exc())
100
  print("Will provide basic text-based recommendations without embeddings")
101
  self.sbert_model = None
 
 
 
102
 
103
  def _create_breed_description(self, breed: str) -> str:
104
  """Create comprehensive natural language description for breed with all key characteristics"""
@@ -321,10 +328,14 @@ class SemanticBreedRecommender:
321
  return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
322
 
323
  def _build_breed_vectors(self):
324
- """Build vector representations for all breeds"""
325
  try:
326
  print("Building breed vector database...")
327
 
 
 
 
 
328
  # Skip if model is not available
329
  if self.sbert_model is None:
330
  print("SBERT model not available, skipping vector building")
@@ -959,12 +970,20 @@ class SemanticBreedRecommender:
959
  try:
960
  print(f"Processing user input: {user_input}")
961
 
 
 
 
 
962
  # Check if model is available - if not, raise error
963
  if self.sbert_model is None:
964
  error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
965
  print(f"ERROR: {error_msg}")
966
  raise RuntimeError(error_msg)
967
 
 
 
 
 
968
  # Generate user input embedding
969
  user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
970
 
@@ -1584,6 +1603,10 @@ def get_breed_recommendations_by_description(user_description: str,
1584
  try:
1585
  print("Initializing Enhanced SemanticBreedRecommender...")
1586
  recommender = SemanticBreedRecommender()
 
 
 
 
1587
 
1588
  # 優先使用整合統一評分系統的增強推薦
1589
  print("Using enhanced recommendation system with unified scoring")
@@ -1628,11 +1651,19 @@ def get_enhanced_recommendations_with_unified_scoring(user_description: str, top
1628
  # 創建基本推薦器實例
1629
  recommender = SemanticBreedRecommender()
1630
 
 
 
 
 
1631
  if not recommender.sbert_model:
1632
  print("SBERT model not available, using basic text matching...")
1633
  # 使用基本文字匹配邏輯
1634
  return _get_basic_text_matching_recommendations(user_description, top_k)
1635
 
 
 
 
 
1636
  # 使用語意相似度推薦
1637
  recommendations = []
1638
  user_embedding = recommender.sbert_model.encode(user_description)
@@ -2212,4 +2243,4 @@ def _get_basic_text_matching_recommendations(user_description: str, top_k: int =
2212
  except Exception as e:
2213
  error_msg = f"Error in basic text matching: {str(e)}"
2214
  print(f"ERROR: {error_msg}")
2215
- raise RuntimeError(error_msg) from e
 
37
  """Initialize the semantic recommender"""
38
  self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
39
  self.sbert_model = None
40
+ self._sbert_loading_attempted = False
41
  self.breed_vectors = {}
42
  self.breed_list = self._get_breed_list()
43
  self.comparative_keywords = {
 
45
  'then': 0.7, 'second': 0.7, 'followed': 0.6,
46
  'third': 0.5, 'least': 0.3, 'dislike': 0.2
47
  }
48
+ # Defer SBERT model loading until needed in GPU context
49
+ # This prevents CUDA initialization issues in ZeroGPU environment
50
+ print("SemanticBreedRecommender initialized (SBERT loading deferred)")
 
 
 
 
51
 
52
  # Initialize multi-head scorer with SBERT model if enhanced mode is enabled
53
  # if self.sbert_model:
 
71
  'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
72
 
73
  def _initialize_model(self):
74
+ """Initialize SBERT model with fallback - designed for ZeroGPU compatibility"""
75
+ if self.sbert_model is not None or self._sbert_loading_attempted:
76
+ return self.sbert_model
77
+
78
  try:
79
+ print("Loading SBERT model in GPU context...")
80
  # Try different model names if the primary one fails
81
  model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
82
 
83
  for model_name in model_options:
84
  try:
85
+ # Specify device explicitly to handle ZeroGPU environment
86
+ import torch
87
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
88
+ self.sbert_model = SentenceTransformer(model_name, device=device)
89
  self.model_name = model_name
90
+ print(f"SBERT model {model_name} loaded successfully on {device}")
91
+ return self.sbert_model
92
  except Exception as model_e:
93
  print(f"Failed to load {model_name}: {str(model_e)}")
94
  continue
 
96
  # If all models fail
97
  print("All SBERT models failed to load. Using basic text matching fallback.")
98
  self.sbert_model = None
99
+ return None
100
 
101
  except Exception as e:
102
  print(f"Failed to initialize any SBERT model: {str(e)}")
103
  print(traceback.format_exc())
104
  print("Will provide basic text-based recommendations without embeddings")
105
  self.sbert_model = None
106
+ return None
107
+ finally:
108
+ self._sbert_loading_attempted = True
109
 
110
  def _create_breed_description(self, breed: str) -> str:
111
  """Create comprehensive natural language description for breed with all key characteristics"""
 
328
  return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
329
 
330
  def _build_breed_vectors(self):
331
+ """Build vector representations for all breeds - called lazily when needed"""
332
  try:
333
  print("Building breed vector database...")
334
 
335
+ # Initialize model if not already done
336
+ if self.sbert_model is None:
337
+ self._initialize_model()
338
+
339
  # Skip if model is not available
340
  if self.sbert_model is None:
341
  print("SBERT model not available, skipping vector building")
 
970
  try:
971
  print(f"Processing user input: {user_input}")
972
 
973
+ # 嘗試載入SBERT模型(如果尚未載入)
974
+ if self.sbert_model is None:
975
+ self._initialize_model()
976
+
977
  # Check if model is available - if not, raise error
978
  if self.sbert_model is None:
979
  error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
980
  print(f"ERROR: {error_msg}")
981
  raise RuntimeError(error_msg)
982
 
983
+ # 確保breed vectors已建構
984
+ if not self.breed_vectors:
985
+ self._build_breed_vectors()
986
+
987
  # Generate user input embedding
988
  user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
989
 
 
1603
  try:
1604
  print("Initializing Enhanced SemanticBreedRecommender...")
1605
  recommender = SemanticBreedRecommender()
1606
+
1607
+ # 嘗試載入SBERT模型(如果尚未載入)
1608
+ if not recommender.sbert_model:
1609
+ recommender._initialize_model()
1610
 
1611
  # 優先使用整合統一評分系統的增強推薦
1612
  print("Using enhanced recommendation system with unified scoring")
 
1651
  # 創建基本推薦器實例
1652
  recommender = SemanticBreedRecommender()
1653
 
1654
+ # 嘗試載入SBERT模型(如果尚未載入)
1655
+ if not recommender.sbert_model:
1656
+ recommender._initialize_model()
1657
+
1658
  if not recommender.sbert_model:
1659
  print("SBERT model not available, using basic text matching...")
1660
  # 使用基本文字匹配邏輯
1661
  return _get_basic_text_matching_recommendations(user_description, top_k)
1662
 
1663
+ # 確保breed vectors已建構
1664
+ if not recommender.breed_vectors:
1665
+ recommender._build_breed_vectors()
1666
+
1667
  # 使用語意相似度推薦
1668
  recommendations = []
1669
  user_embedding = recommender.sbert_model.encode(user_description)
 
2243
  except Exception as e:
2244
  error_msg = f"Error in basic text matching: {str(e)}"
2245
  print(f"ERROR: {error_msg}")
2246
+ raise RuntimeError(error_msg) from e