Spaces:
Running
on
Zero
Running
on
Zero
Update semantic_breed_recommender.py
Browse files - semantic_breed_recommender.py +45 -14
semantic_breed_recommender.py
CHANGED
@@ -37,6 +37,7 @@ class SemanticBreedRecommender:
|
|
37 |
"""Initialize the semantic recommender"""
|
38 |
self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
|
39 |
self.sbert_model = None
|
|
|
40 |
self.breed_vectors = {}
|
41 |
self.breed_list = self._get_breed_list()
|
42 |
self.comparative_keywords = {
|
@@ -44,13 +45,9 @@ class SemanticBreedRecommender:
|
|
44 |
'then': 0.7, 'second': 0.7, 'followed': 0.6,
|
45 |
'third': 0.5, 'least': 0.3, 'dislike': 0.2
|
46 |
}
|
47 |
-
#
|
48 |
-
#
|
49 |
-
|
50 |
-
# self.score_calibrator = ScoreCalibrator()
|
51 |
-
# self.config_manager = get_config_manager()
|
52 |
-
self._initialize_model()
|
53 |
-
self._build_breed_vectors()
|
54 |
|
55 |
# Initialize multi-head scorer with SBERT model if enhanced mode is enabled
|
56 |
# if self.sbert_model:
|
@@ -74,18 +71,24 @@ class SemanticBreedRecommender:
|
|
74 |
'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
|
75 |
|
76 |
def _initialize_model(self):
|
77 |
-
"""Initialize SBERT model with fallback"""
|
|
|
|
|
|
|
78 |
try:
|
79 |
-
print("Loading SBERT model...")
|
80 |
# Try different model names if the primary one fails
|
81 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
82 |
|
83 |
for model_name in model_options:
|
84 |
try:
|
85 |
-
|
|
|
|
|
|
|
86 |
self.model_name = model_name
|
87 |
-
print(f"SBERT model {model_name} loaded successfully")
|
88 |
-
return
|
89 |
except Exception as model_e:
|
90 |
print(f"Failed to load {model_name}: {str(model_e)}")
|
91 |
continue
|
@@ -93,12 +96,16 @@ class SemanticBreedRecommender:
|
|
93 |
# If all models fail
|
94 |
print("All SBERT models failed to load. Using basic text matching fallback.")
|
95 |
self.sbert_model = None
|
|
|
96 |
|
97 |
except Exception as e:
|
98 |
print(f"Failed to initialize any SBERT model: {str(e)}")
|
99 |
print(traceback.format_exc())
|
100 |
print("Will provide basic text-based recommendations without embeddings")
|
101 |
self.sbert_model = None
|
|
|
|
|
|
|
102 |
|
103 |
def _create_breed_description(self, breed: str) -> str:
|
104 |
"""Create comprehensive natural language description for breed with all key characteristics"""
|
@@ -321,10 +328,14 @@ class SemanticBreedRecommender:
|
|
321 |
return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
|
322 |
|
323 |
def _build_breed_vectors(self):
|
324 |
-
"""Build vector representations for all breeds"""
|
325 |
try:
|
326 |
print("Building breed vector database...")
|
327 |
|
|
|
|
|
|
|
|
|
328 |
# Skip if model is not available
|
329 |
if self.sbert_model is None:
|
330 |
print("SBERT model not available, skipping vector building")
|
@@ -959,12 +970,20 @@ class SemanticBreedRecommender:
|
|
959 |
try:
|
960 |
print(f"Processing user input: {user_input}")
|
961 |
|
|
|
|
|
|
|
|
|
962 |
# Check if model is available - if not, raise error
|
963 |
if self.sbert_model is None:
|
964 |
error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
|
965 |
print(f"ERROR: {error_msg}")
|
966 |
raise RuntimeError(error_msg)
|
967 |
|
|
|
|
|
|
|
|
|
968 |
# Generate user input embedding
|
969 |
user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
|
970 |
|
@@ -1584,6 +1603,10 @@ def get_breed_recommendations_by_description(user_description: str,
|
|
1584 |
try:
|
1585 |
print("Initializing Enhanced SemanticBreedRecommender...")
|
1586 |
recommender = SemanticBreedRecommender()
|
|
|
|
|
|
|
|
|
1587 |
|
1588 |
# 優先使用整合統一評分系統的增強推薦
|
1589 |
print("Using enhanced recommendation system with unified scoring")
|
@@ -1628,11 +1651,19 @@ def get_enhanced_recommendations_with_unified_scoring(user_description: str, top
|
|
1628 |
# 創建基本推薦器實例
|
1629 |
recommender = SemanticBreedRecommender()
|
1630 |
|
|
|
|
|
|
|
|
|
1631 |
if not recommender.sbert_model:
|
1632 |
print("SBERT model not available, using basic text matching...")
|
1633 |
# 使用基本文字匹配邏輯
|
1634 |
return _get_basic_text_matching_recommendations(user_description, top_k)
|
1635 |
|
|
|
|
|
|
|
|
|
1636 |
# 使用語意相似度推薦
|
1637 |
recommendations = []
|
1638 |
user_embedding = recommender.sbert_model.encode(user_description)
|
@@ -2212,4 +2243,4 @@ def _get_basic_text_matching_recommendations(user_description: str, top_k: int =
|
|
2212 |
except Exception as e:
|
2213 |
error_msg = f"Error in basic text matching: {str(e)}"
|
2214 |
print(f"ERROR: {error_msg}")
|
2215 |
-
raise RuntimeError(error_msg) from e
|
|
|
37 |
"""Initialize the semantic recommender"""
|
38 |
self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
|
39 |
self.sbert_model = None
|
40 |
+
self._sbert_loading_attempted = False
|
41 |
self.breed_vectors = {}
|
42 |
self.breed_list = self._get_breed_list()
|
43 |
self.comparative_keywords = {
|
|
|
45 |
'then': 0.7, 'second': 0.7, 'followed': 0.6,
|
46 |
'third': 0.5, 'least': 0.3, 'dislike': 0.2
|
47 |
}
|
48 |
+
# Defer SBERT model loading until needed in GPU context
|
49 |
+
# This prevents CUDA initialization issues in ZeroGPU environment
|
50 |
+
print("SemanticBreedRecommender initialized (SBERT loading deferred)")
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# Initialize multi-head scorer with SBERT model if enhanced mode is enabled
|
53 |
# if self.sbert_model:
|
|
|
71 |
'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
|
72 |
|
73 |
def _initialize_model(self):
|
74 |
+
"""Initialize SBERT model with fallback - designed for ZeroGPU compatibility"""
|
75 |
+
if self.sbert_model is not None or self._sbert_loading_attempted:
|
76 |
+
return self.sbert_model
|
77 |
+
|
78 |
try:
|
79 |
+
print("Loading SBERT model in GPU context...")
|
80 |
# Try different model names if the primary one fails
|
81 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
82 |
|
83 |
for model_name in model_options:
|
84 |
try:
|
85 |
+
# Specify device explicitly to handle ZeroGPU environment
|
86 |
+
import torch
|
87 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
88 |
+
self.sbert_model = SentenceTransformer(model_name, device=device)
|
89 |
self.model_name = model_name
|
90 |
+
print(f"SBERT model {model_name} loaded successfully on {device}")
|
91 |
+
return self.sbert_model
|
92 |
except Exception as model_e:
|
93 |
print(f"Failed to load {model_name}: {str(model_e)}")
|
94 |
continue
|
|
|
96 |
# If all models fail
|
97 |
print("All SBERT models failed to load. Using basic text matching fallback.")
|
98 |
self.sbert_model = None
|
99 |
+
return None
|
100 |
|
101 |
except Exception as e:
|
102 |
print(f"Failed to initialize any SBERT model: {str(e)}")
|
103 |
print(traceback.format_exc())
|
104 |
print("Will provide basic text-based recommendations without embeddings")
|
105 |
self.sbert_model = None
|
106 |
+
return None
|
107 |
+
finally:
|
108 |
+
self._sbert_loading_attempted = True
|
109 |
|
110 |
def _create_breed_description(self, breed: str) -> str:
|
111 |
"""Create comprehensive natural language description for breed with all key characteristics"""
|
|
|
328 |
return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
|
329 |
|
330 |
def _build_breed_vectors(self):
|
331 |
+
"""Build vector representations for all breeds - called lazily when needed"""
|
332 |
try:
|
333 |
print("Building breed vector database...")
|
334 |
|
335 |
+
# Initialize model if not already done
|
336 |
+
if self.sbert_model is None:
|
337 |
+
self._initialize_model()
|
338 |
+
|
339 |
# Skip if model is not available
|
340 |
if self.sbert_model is None:
|
341 |
print("SBERT model not available, skipping vector building")
|
|
|
970 |
try:
|
971 |
print(f"Processing user input: {user_input}")
|
972 |
|
973 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
974 |
+
if self.sbert_model is None:
|
975 |
+
self._initialize_model()
|
976 |
+
|
977 |
# Check if model is available - if not, raise error
|
978 |
if self.sbert_model is None:
|
979 |
error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
|
980 |
print(f"ERROR: {error_msg}")
|
981 |
raise RuntimeError(error_msg)
|
982 |
|
983 |
+
# 確保breed vectors已建構
|
984 |
+
if not self.breed_vectors:
|
985 |
+
self._build_breed_vectors()
|
986 |
+
|
987 |
# Generate user input embedding
|
988 |
user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
|
989 |
|
|
|
1603 |
try:
|
1604 |
print("Initializing Enhanced SemanticBreedRecommender...")
|
1605 |
recommender = SemanticBreedRecommender()
|
1606 |
+
|
1607 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
1608 |
+
if not recommender.sbert_model:
|
1609 |
+
recommender._initialize_model()
|
1610 |
|
1611 |
# 優先使用整合統一評分系統的增強推薦
|
1612 |
print("Using enhanced recommendation system with unified scoring")
|
|
|
1651 |
# 創建基本推薦器實例
|
1652 |
recommender = SemanticBreedRecommender()
|
1653 |
|
1654 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
1655 |
+
if not recommender.sbert_model:
|
1656 |
+
recommender._initialize_model()
|
1657 |
+
|
1658 |
if not recommender.sbert_model:
|
1659 |
print("SBERT model not available, using basic text matching...")
|
1660 |
# 使用基本文字匹配邏輯
|
1661 |
return _get_basic_text_matching_recommendations(user_description, top_k)
|
1662 |
|
1663 |
+
# 確保breed vectors已建構
|
1664 |
+
if not recommender.breed_vectors:
|
1665 |
+
recommender._build_breed_vectors()
|
1666 |
+
|
1667 |
# 使用語意相似度推薦
|
1668 |
recommendations = []
|
1669 |
user_embedding = recommender.sbert_model.encode(user_description)
|
|
|
2243 |
except Exception as e:
|
2244 |
error_msg = f"Error in basic text matching: {str(e)}"
|
2245 |
print(f"ERROR: {error_msg}")
|
2246 |
+
raise RuntimeError(error_msg) from e
|