# meisaicheck-api/services/sentence_transformer_service.py
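"""Singleton service that loads the sentence-transformer model, the per-type
cached sentence embeddings (unit, subject, sub-subject, name, abstract), and
the CSV mapping tables used by the meisaicheck API."""
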
import os
import pickle
import sys
import warnings

import pandas as pd

# Suppress pandas warnings globally
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
pd.set_option("mode.chained_assignment", None)

from config import (
MODEL_NAME,
MODEL_TYPE,
DEVICE_TYPE,
SENTENCE_EMBEDDING_FILE,
STANDARD_NAME_MAP_DATA_FILE,
SUBJECT_DATA_FILE,
DATA_DIR,
HALF,
ABSTRACT_MAP_DATA_FILE,
NAME_ABSTRACT_MAP_DATA_FILE,
)

# Add the path to import modules from meisai-check-ai
# sys.path.append(os.path.join(os.path.dirname(__file__), "..", "meisai-check-ai"))
from sentence_transformer_lib.sentence_transformer_helper import SentenceTransformerHelper
from sentence_transformer_lib.cached_embedding_helper import CachedEmbeddingHelper

# Cache file paths for different types of embeddings
CACHED_EMBEDDINGS_SUBJECT_FILE = os.path.join(DATA_DIR, "cached_embeddings_subject.pkl")
CACHED_EMBEDDINGS_NAME_FILE = os.path.join(DATA_DIR, "cached_embeddings_name.pkl")
CACHED_EMBEDDINGS_ABSTRACT_FILE = os.path.join(
DATA_DIR, "cached_embeddings_abstract.pkl"
)
CACHED_EMBEDDINGS_SUB_SUBJECT_FILE = os.path.join(
DATA_DIR, "cached_embeddings_sub_subject.pkl"
)
CACHED_EMBEDDINGS_UNIT_FILE = os.path.join(DATA_DIR, "cached_embeddings_unit.pkl")
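
# Each cache file stores a pickled dict mapping a sentence (str) to its
# embedding, as managed by CachedEmbeddingHelper (the value type is assumed
# to be a numpy array; this module only reads, pickles, and counts entries).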
CACHE_FILES = {
    "subject": CACHED_EMBEDDINGS_SUBJECT_FILE,
    "name": CACHED_EMBEDDINGS_NAME_FILE,
    "abstract": CACHED_EMBEDDINGS_ABSTRACT_FILE,
    "sub_subject": CACHED_EMBEDDINGS_SUB_SUBJECT_FILE,
    "unit": CACHED_EMBEDDINGS_UNIT_FILE,
}


def load_cached_embeddings_by_type(cache_type):
    """Load cached embeddings from file based on type"""
    cache_file = CACHE_FILES.get(cache_type)
    if not cache_file:
        print(f"Unknown cache type: {cache_type}")
        return {}, False
if os.path.exists(cache_file):
try:
with open(cache_file, "rb") as f:
cached_embeddings = pickle.load(f)
print(
f"Loaded {cache_type} embeddings with {len(cached_embeddings)} entries from {cache_file}"
)
return cached_embeddings, True
except Exception as e:
print(f"Error loading {cache_type} embeddings: {e}")
return {}, False
else:
print(
f"No {cache_type} embeddings cache file found. Starting with empty cache."
)
return {}, False


def save_cached_embeddings_by_type(cached_embedding_helper, cache_type):
"""Save cached embeddings to file based on type"""
    cache_file = CACHE_FILES.get(cache_type)
if not cache_file:
print(f"Unknown cache type: {cache_type}")
return
try:
# Ensure directory exists
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
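        # Reads CachedEmbeddingHelper's private cache dict directly; the
        # helper does not appear to expose a public accessor for the full cache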
cached_embeddings = cached_embedding_helper._cached_sentence_embeddings
with open(cache_file, "wb") as f:
pickle.dump(cached_embeddings, f)
print(
f"Saved {cache_type} embeddings with {len(cached_embeddings)} entries to {cache_file}"
)
except Exception as e:
print(f"Error saving {cache_type} embeddings: {e}")


def create_cached_embedding_helper_for_type(sentence_transformer, cache_type):
"""Create a CachedEmbeddingHelper for specific embedding type"""
cached_embeddings, is_loaded = load_cached_embeddings_by_type(cache_type)
return CachedEmbeddingHelper(
sentence_transformer, cached_sentence_embeddings=cached_embeddings
), is_loaded
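
# Illustrative use of the factory above (names match this module): each
# embedding type gets its own helper, and is_loaded records whether its cache
# came from disk, e.g.
#
#     helper, is_loaded = create_cached_embedding_helper_for_type(st, "name")
#
# save_all_caches() later skips types whose cache was loaded from disk, since
# nothing new was computed for them in this run.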


class SentenceTransformerService:
def __init__(self):
self.sentenceTransformerHelper = None
# Different cached embedding helpers for different types
self.unit_cached_embedding_helper = None
self.unit_is_loaded = False
self.subject_cached_embedding_helper = None
self.subject_is_loaded = False
self.sub_subject_cached_embedding_helper = None
self.sub_subject_is_loaded = False
self.name_cached_embedding_helper = None
self.name_is_loaded = False
self.abstract_cached_embedding_helper = None
self.abstract_is_loaded = False
# Map data holders
self.df_unit_map_data = None
self.df_subject_map_data = None
self.df_standard_subject_map_data = None
self.df_sub_subject_map_data = None
self.df_name_map_data = None
self.df_abstract_map_data = None
self.df_name_and_subject_map_data = None
self.df_sub_subject_and_name_map_data = None
self.df_standard_name_map_data = None

    def load_model_data(self):
"""Load model and data only once at startup"""
if self.sentenceTransformerHelper is not None:
print("Model already loaded. Skipping reload.")
            return  # Do not reload if the model is already loaded
print("Loading models and data...")
# Load sentence transformer model
print(f"Loading model {MODEL_NAME} with type {MODEL_TYPE} and half={HALF}")
self.sentenceTransformerHelper = SentenceTransformerHelper(
model_name=MODEL_NAME, model_type=MODEL_TYPE, half=HALF
)
        # Create different cached embedding helpers for different types
        self.unit_cached_embedding_helper, self.unit_is_loaded = (
            create_cached_embedding_helper_for_type(self.sentenceTransformerHelper, "unit")
        )
        self.subject_cached_embedding_helper, self.subject_is_loaded = (
            create_cached_embedding_helper_for_type(self.sentenceTransformerHelper, "subject")
        )
        self.sub_subject_cached_embedding_helper, self.sub_subject_is_loaded = (
            create_cached_embedding_helper_for_type(self.sentenceTransformerHelper, "sub_subject")
        )
        self.name_cached_embedding_helper, self.name_is_loaded = (
            create_cached_embedding_helper_for_type(self.sentenceTransformerHelper, "name")
        )
        self.abstract_cached_embedding_helper, self.abstract_is_loaded = (
            create_cached_embedding_helper_for_type(self.sentenceTransformerHelper, "abstract")
        )
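        # All five helpers share the same underlying model; only their
        # per-type sentence -> embedding caches are kept separate.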
# Load map data from CSV files (assuming they exist)
self._load_map_data()
print("Models and data loaded successfully")

    def _load_map_data(self):
        """Load all mapping data from CSV files"""
        # (attribute name, CSV file name, log label, per-column dtype overrides)
        map_specs = [
            ("df_unit_map_data", "unitMapData.csv", "unit map data", None),
            ("df_subject_map_data", "subjectMapData.csv", "subject map data", None),
            ("df_standard_subject_map_data", "standardSubjectMapData.csv", "standard subject map data", None),
            ("df_sub_subject_map_data", "subSubjectMapData.csv", "sub subject map data", None),
            ("df_name_map_data", "nameMapData.csv", "name map data", {"外部・内部設定タイプ": str}),
            ("df_sub_subject_and_name_map_data", "subSubjectAndNameMapData.csv", "sub subject and name map data", None),
            ("df_abstract_map_data", "abstractMapData.csv", "abstract map data", {"摘要タイプ": str}),
            ("df_name_and_subject_map_data", "nameAndSubjectMapData.csv", "name and subject map data", None),
            ("df_standard_name_map_data", "standardNameMapData.csv", "standard name map data", None),
        ]
        try:
            for attr, file_name, label, dtype in map_specs:
                file_path = os.path.join(DATA_DIR, file_name)
                if not os.path.exists(file_path):
                    continue
                df = pd.read_csv(file_path, dtype=dtype)
                setattr(self, attr, df)
                print(f"Loaded {label}: {len(df)} entries")
        except Exception as e:
            print(f"Error loading map data: {e}")

    def save_all_caches(self):
        """Save all cached embeddings"""
        caches = [
            ("unit", "Unit", self.unit_cached_embedding_helper, self.unit_is_loaded),
            ("subject", "Subject", self.subject_cached_embedding_helper, self.subject_is_loaded),
            ("sub_subject", "Sub-subject", self.sub_subject_cached_embedding_helper, self.sub_subject_is_loaded),
            ("name", "Name", self.name_cached_embedding_helper, self.name_is_loaded),
            ("abstract", "Abstract", self.abstract_cached_embedding_helper, self.abstract_is_loaded),
        ]
        try:
            # Only persist caches built this run; caches loaded from disk are
            # unchanged and need no re-save
            for cache_type, _, helper, is_loaded in caches:
                if not is_loaded:
                    save_cached_embeddings_by_type(helper, cache_type)
            # Print cache statistics summary
            print("\n" + "=" * 60)
            print("EMBEDDING CACHE PERFORMANCE SUMMARY")
            print("=" * 60)
            total_cache_size = 0
            for _, label, helper, is_loaded in caches:
                if is_loaded:
                    continue
                size = len(helper._cached_sentence_embeddings)
                total_cache_size += size
                print(f"{label} cache: {size} embeddings")
            print(f"Total cached embeddings: {total_cache_size}")
            print("=" * 60)
        except Exception as e:
            print(f"Error saving caches: {e}")


# Global instance (singleton)
sentence_transformer_service = SentenceTransformerService()
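
# Illustrative wiring (a sketch, not part of this module; the FastAPI lifespan
# pattern and import path below are assumptions): the API would typically call
# load_model_data() once at startup and save_all_caches() at shutdown.
#
#     from contextlib import asynccontextmanager
#     from fastapi import FastAPI
#     from services.sentence_transformer_service import sentence_transformer_service
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         sentence_transformer_service.load_model_data()
#         yield
#         sentence_transformer_service.save_all_caches()
#
#     app = FastAPI(lifespan=lifespan)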