# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys

import config
import nltk
import stanza
import torch
from langchain_community.embeddings import HuggingFaceEmbeddings

from .siglip_embedder import CustomSigLipEmbeddings

logger = logging.getLogger(__name__)

EMBEDDING_MODEL_ID = os.environ.get("EMBEDDING_MODEL_ID", None)
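# EMBEDDING_MODEL_ID selects the embedder loaded below: ids containing "siglip" are
# routed to CustomSigLipEmbeddings, anything else to HuggingFaceEmbeddings. Example
# values (illustrative, not prescribed by this repo): "sentence-transformers/all-MiniLM-L6-v2"
# or "google/siglip-base-patch16-224".
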
class ModelManager:
    """Handles the expensive, one-time setup of downloading and loading all AI models required for RAG."""

    def __init__(self):
        # Configuration for model identifiers
        self.embedding_model_id = EMBEDDING_MODEL_ID
        self.stanza_ner_package = "mimic"
        self.stanza_ner_processor = "i2b2"
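        # "mimic" + "i2b2" select Stanza's clinical English models; the i2b2 NER
        # processor tags clinical entities (problems, tests, treatments). They can be
        # swapped for other Stanza packages/processors if the domain changes.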
    def load_models(self) -> dict:
        """
        Initializes and returns a dictionary of model components.

        Note: The main LLM is accessed via API and is NOT loaded here.
        """
        logger.info("--- Initializing RAG-specific Models (Embedder, NER) ---")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device} for RAG models")

        models = {}

        # 1. Load Embedder
        try:
            logger.info(f"Loading embedding model: {self.embedding_model_id}")
            if "siglip" in self.embedding_model_id:
                models["embedder"] = CustomSigLipEmbeddings(
                    siglip_model_name=self.embedding_model_id,
                    device=device,
                    normalize_embeddings=True,
                )
            else:
                models["embedder"] = HuggingFaceEmbeddings(
                    model_name=self.embedding_model_id,
                    model_kwargs={"device": device},
                    encode_kwargs={"normalize_embeddings": True},
                )
            logger.info("✅ Embedding model loaded successfully.")
        except Exception as e:
            logger.error(f"⚠️ Failed to load embedding model: {e}", exc_info=True)
            # The embedder is essential for RAG, so abort startup if it cannot load.
            sys.exit(1)

        # 2. Load Stanza for NER
        try:
            logger.info("Downloading Stanza models...")
            stanza.download(
                "en",
                package=self.stanza_ner_package,
                processors={"ner": self.stanza_ner_processor},
                verbose=False,
            )
            logger.info("✅ Stanza models downloaded.")
logger.info("Loading Stanza NER Pipeline...") | |
models['ner_pipeline'] = stanza.Pipeline( | |
lang="en", | |
package=self.stanza_ner_package, | |
processors={"ner": "i2b2"}, | |
use_gpu=torch.cuda.is_available(), | |
verbose=False, | |
tokenize_no_ssplit=True, | |
) | |
logger.info("✅ Stanza NER Pipeline loaded successfully.") | |
except Exception as e: | |
logger.error(f"⚠️ Failed to set up Stanza NER pipeline: {e}", exc_info=True) | |
models['ner_pipeline'] = None | |
        if all(models.values()):
            logger.info("\n✅ All RAG-specific models initialized successfully.")
        else:
            logger.error("\n⚠️ One or more RAG models failed to initialize. Check errors above.")

        return models
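

# --- Usage sketch (illustrative only, not part of the original module) ---
# Assumes EMBEDDING_MODEL_ID is exported and that both embedder classes follow the
# LangChain Embeddings interface (embed_query / embed_documents). Run with
# `python -m <package>.<module>` so the relative import above resolves.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager()
    rag_models = manager.load_models()

    embedder = rag_models.get("embedder")
    if embedder is not None:
        # Embed a sample query; HuggingFaceEmbeddings returns a list of floats.
        vector = embedder.embed_query("patient presents with chest pain and dyspnea")
        logger.info("Query embedded into a vector of length %d", len(vector))

    ner = rag_models.get("ner_pipeline")
    if ner is not None:
        # Run clinical NER on a sample sentence; Stanza exposes entities via doc.ents.
        doc = ner("The patient was started on metformin for type 2 diabetes.")
        for ent in doc.ents:
            logger.info("Entity: %s (%s)", ent.text, ent.type)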