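"""Service wrapper around SentenceTransformerHelper: loads the model once,
keeps per-type embedding caches (unit, subject, sub_subject, name, abstract)
persisted as pickle files, and holds the CSV mapping tables used downstream."""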
import os
import pickle
import sys
import warnings

import pandas as pd

# Suppress pandas warnings globally
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
pd.set_option("mode.chained_assignment", None)
from config import (
    MODEL_NAME,
    MODEL_TYPE,
    DEVICE_TYPE,
    SENTENCE_EMBEDDING_FILE,
    STANDARD_NAME_MAP_DATA_FILE,
    SUBJECT_DATA_FILE,
    DATA_DIR,
    HALF,
    ABSTRACT_MAP_DATA_FILE,
    NAME_ABSTRACT_MAP_DATA_FILE,
)
# Add the path to import modules from meisai-check-ai
# sys.path.append(os.path.join(os.path.dirname(__file__), "..", "meisai-check-ai"))
from sentence_transformer_lib.sentence_transformer_helper import (
    SentenceTransformerHelper,
)
from sentence_transformer_lib.cached_embedding_helper import CachedEmbeddingHelper

# Cache file paths for different types of embeddings
CACHED_EMBEDDINGS_SUBJECT_FILE = os.path.join(DATA_DIR, "cached_embeddings_subject.pkl")
CACHED_EMBEDDINGS_NAME_FILE = os.path.join(DATA_DIR, "cached_embeddings_name.pkl")
CACHED_EMBEDDINGS_ABSTRACT_FILE = os.path.join(DATA_DIR, "cached_embeddings_abstract.pkl")
CACHED_EMBEDDINGS_SUB_SUBJECT_FILE = os.path.join(DATA_DIR, "cached_embeddings_sub_subject.pkl")
CACHED_EMBEDDINGS_UNIT_FILE = os.path.join(DATA_DIR, "cached_embeddings_unit.pkl")

# Shared lookup table for the cache load/save helpers below
CACHE_FILES = {
    "subject": CACHED_EMBEDDINGS_SUBJECT_FILE,
    "name": CACHED_EMBEDDINGS_NAME_FILE,
    "abstract": CACHED_EMBEDDINGS_ABSTRACT_FILE,
    "sub_subject": CACHED_EMBEDDINGS_SUB_SUBJECT_FILE,
    "unit": CACHED_EMBEDDINGS_UNIT_FILE,
}
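
# Each cache file is the helper's cached-sentence-embeddings mapping pickled
# as-is (presumably keyed by sentence text), so embeddings survive restarts.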


def load_cached_embeddings_by_type(cache_type):
    """Load cached embeddings from file based on type."""
    cache_file = CACHE_FILES.get(cache_type)
    if not cache_file:
        print(f"Unknown cache type: {cache_type}")
        return {}
    if not os.path.exists(cache_file):
        print(f"No {cache_type} embeddings cache file found. Starting with empty cache.")
        return {}
    try:
        with open(cache_file, "rb") as f:
            cached_embeddings = pickle.load(f)
        print(
            f"Loaded {cache_type} embeddings with {len(cached_embeddings)} entries from {cache_file}"
        )
        return cached_embeddings
    except Exception as e:
        print(f"Error loading {cache_type} embeddings: {e}")
        return {}


def save_cached_embeddings_by_type(cached_embedding_helper, cache_type):
    """Save cached embeddings to file based on type."""
    cache_file = CACHE_FILES.get(cache_type)
    if not cache_file:
        print(f"Unknown cache type: {cache_type}")
        return
    try:
        # Ensure directory exists
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        cached_embeddings = cached_embedding_helper._cached_sentence_embeddings
        with open(cache_file, "wb") as f:
            pickle.dump(cached_embeddings, f)
        print(
            f"Saved {cache_type} embeddings with {len(cached_embeddings)} entries to {cache_file}"
        )
    except Exception as e:
        print(f"Error saving {cache_type} embeddings: {e}")


def create_cached_embedding_helper_for_type(sentence_transformer, cache_type):
    """Create a CachedEmbeddingHelper for a specific embedding type."""
    cached_embeddings = load_cached_embeddings_by_type(cache_type)
    return CachedEmbeddingHelper(
        sentence_transformer, cached_sentence_embeddings=cached_embeddings
    )
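
# Example (sketch, using only the calls defined above): pre-seed a helper's
# cache from disk, embed sentences through it, then persist what it gathered.
#   helper = create_cached_embedding_helper_for_type(transformer, "name")
#   ...embed sentences via the helper...
#   save_cached_embeddings_by_type(helper, "name")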


class SentenceTransformerService:
    def __init__(self):
        self.sentenceTransformerHelper = None
        # Different cached embedding helpers for different types
        self.unit_cached_embedding_helper = None
        self.subject_cached_embedding_helper = None
        self.sub_subject_cached_embedding_helper = None
        self.name_cached_embedding_helper = None
        self.abstract_cached_embedding_helper = None
        # Map data holders
        self.df_unit_map_data = None
        self.df_subject_map_data = None
        self.df_standard_subject_map_data = None
        self.df_sub_subject_map_data = None
        self.df_name_map_data = None
        self.df_abstract_map_data = None
        self.df_name_and_subject_map_data = None
        self.df_sub_subject_and_name_map_data = None

    def load_model_data(self):
        """Load model and data only once at startup."""
        if self.sentenceTransformerHelper is not None:
            print("Model already loaded. Skipping reload.")
            return  # Do not reload if the model is already loaded

        print("Loading models and data...")

        # Load sentence transformer model
        print(f"Loading model {MODEL_NAME} with type {MODEL_TYPE} and half={HALF}")
        self.sentenceTransformerHelper = SentenceTransformerHelper(
            model_name=MODEL_NAME, model_type=MODEL_TYPE, half=HALF
        )

        # Create different cached embedding helpers for different types
        self.unit_cached_embedding_helper = create_cached_embedding_helper_for_type(
            self.sentenceTransformerHelper, "unit"
        )
        self.subject_cached_embedding_helper = create_cached_embedding_helper_for_type(
            self.sentenceTransformerHelper, "subject"
        )
        self.sub_subject_cached_embedding_helper = (
            create_cached_embedding_helper_for_type(
                self.sentenceTransformerHelper, "sub_subject"
            )
        )
        self.name_cached_embedding_helper = create_cached_embedding_helper_for_type(
            self.sentenceTransformerHelper, "name"
        )
        self.abstract_cached_embedding_helper = create_cached_embedding_helper_for_type(
            self.sentenceTransformerHelper, "abstract"
        )

        # Load map data from CSV files (assuming they exist)
        self._load_map_data()
        print("Models and data loaded successfully")

    def _load_map_data(self):
        """Load all mapping data from CSV files."""
        try:
            # Load unit map data
            unit_map_file = os.path.join(DATA_DIR, "unitMapData.csv")
            if os.path.exists(unit_map_file):
                self.df_unit_map_data = pd.read_csv(unit_map_file)
                print(f"Loaded unit map data: {len(self.df_unit_map_data)} entries")

            # Load subject map data
            subject_map_file = os.path.join(DATA_DIR, "subjectMapData.csv")
            if os.path.exists(subject_map_file):
                self.df_subject_map_data = pd.read_csv(subject_map_file)
                print(
                    f"Loaded subject map data: {len(self.df_subject_map_data)} entries"
                )

            # Load standard subject map data
            standard_subject_map_file = os.path.join(
                DATA_DIR, "standardSubjectMapData.csv"
            )
            if os.path.exists(standard_subject_map_file):
                self.df_standard_subject_map_data = pd.read_csv(
                    standard_subject_map_file
                )
                print(
                    f"Loaded standard subject map data: {len(self.df_standard_subject_map_data)} entries"
                )

            # Load sub subject map data
            sub_subject_map_file = os.path.join(DATA_DIR, "subSubjectMapData.csv")
            if os.path.exists(sub_subject_map_file):
                self.df_sub_subject_map_data = pd.read_csv(sub_subject_map_file)
                print(
                    f"Loaded sub subject map data: {len(self.df_sub_subject_map_data)} entries"
                )

            # Load name map data
            name_map_file = os.path.join(DATA_DIR, "nameMapData.csv")
            if os.path.exists(name_map_file):
                self.df_name_map_data = pd.read_csv(name_map_file)
                print(f"Loaded name map data: {len(self.df_name_map_data)} entries")

            # Load abstract map data
            abstract_map_file = os.path.join(DATA_DIR, "abstractMapData.csv")
            if os.path.exists(abstract_map_file):
                self.df_abstract_map_data = pd.read_csv(abstract_map_file)
                print(
                    f"Loaded abstract map data: {len(self.df_abstract_map_data)} entries"
                )
                print(
                    f"DEBUG: Abstract map data columns: {list(self.df_abstract_map_data.columns)}"
                )
                print("DEBUG: Abstract map data sample:")
                print(self.df_abstract_map_data.head(3).to_string())
            else:
                print(f"DEBUG: Abstract map file not found: {abstract_map_file}")

            # Load name and subject map data
            name_and_subject_map_file = os.path.join(
                DATA_DIR, "nameAndSubjectMapData.csv"
            )
            if os.path.exists(name_and_subject_map_file):
                self.df_name_and_subject_map_data = pd.read_csv(
                    name_and_subject_map_file
                )
                print(
                    f"Loaded name and subject map data: {len(self.df_name_and_subject_map_data)} entries"
                )

            # Load sub subject and name map data
            sub_subject_and_name_map_file = os.path.join(
                DATA_DIR, "subSubjectAndNameMapData.csv"
            )
            if os.path.exists(sub_subject_and_name_map_file):
                self.df_sub_subject_and_name_map_data = pd.read_csv(
                    sub_subject_and_name_map_file
                )
                print(
                    f"Loaded sub subject and name map data: {len(self.df_sub_subject_and_name_map_data)} entries"
                )
        except Exception as e:
            print(f"Error loading map data: {e}")

    def save_all_caches(self):
        """Save all cached embeddings and print a cache statistics summary."""
        helpers = [
            ("unit", self.unit_cached_embedding_helper),
            ("subject", self.subject_cached_embedding_helper),
            ("sub_subject", self.sub_subject_cached_embedding_helper),
            ("name", self.name_cached_embedding_helper),
            ("abstract", self.abstract_cached_embedding_helper),
        ]
        try:
            for cache_type, helper in helpers:
                if helper:
                    save_cached_embeddings_by_type(helper, cache_type)

            # Print cache statistics summary
            print("\n" + "=" * 60)
            print("EMBEDDING CACHE PERFORMANCE SUMMARY")
            print("=" * 60)
            total_cache_size = 0
            for cache_type, helper in helpers:
                if helper:
                    size = len(helper._cached_sentence_embeddings)
                    total_cache_size += size
                    label = cache_type.replace("_", "-").capitalize()
                    print(f"{label} cache: {size} embeddings")
            print(f"Total cached embeddings: {total_cache_size}")
            print("=" * 60)
        except Exception as e:
            print(f"Error saving caches: {e}")


# Global instance (singleton)
sentence_transformer_service = SentenceTransformerService()
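

# Minimal smoke-test sketch: assumes config.py points at a valid model and
# DATA_DIR. Loads the model and map data once, then persists whatever the
# cached-embedding helpers currently hold.
if __name__ == "__main__":
    sentence_transformer_service.load_model_data()
    sentence_transformer_service.save_all_caches()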