Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig | |
from typing import Dict, Any | |
import yaml | |
import os | |
from models import ModernBertForSentiment | |
class SentimentInference:
    """Sentiment classifier wrapper configured from a YAML file.

    The constructor reads the ``model`` and ``inference`` sections of the
    config, loads a tokenizer, and loads model weights either from a local
    ``.pt`` checkpoint (``inference.model_path``) or from the Hugging Face Hub
    (``model.name_or_path``). ``predict`` then scores a single text.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub.

        Args:
            config_path: Path to the YAML configuration file.

        Raises:
            ValueError: If the config names neither a tokenizer nor a model
                source, or no local checkpoint exists and ``model.name_or_path``
                is missing.
            TypeError: If the local checkpoint does not yield a state dict.
            OSError: If ``config_path`` cannot be opened.
        """
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---")
        with open(config_path, 'r', encoding="utf-8") as f:
            # An empty YAML document parses to None; fall back to an empty
            # mapping so the explicit ValueError checks below are reached
            # instead of an AttributeError on `.get`.
            config_data = yaml.safe_load(f) or {}
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---")

        # `or {}` also covers sections that are present but empty (`model:`),
        # which YAML parses as None.
        model_yaml_cfg = config_data.get('model') or {}
        inference_yaml_cfg = config_data.get('inference') or {}

        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path')  # Path for local .pt file
        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---")
        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---")

        # The inference section takes precedence over the model section.
        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        # --- Tokenizer Loading (always via from_pretrained on the configured id) ---
        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # --- Model Loading ---
        # A local .pt file wins when it is configured and actually exists;
        # otherwise the Hub id is required.
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}")
            print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---")
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")

        print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}")
        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---")

        if load_from_local_pt:
            self.model = self._load_local_checkpoint(model_yaml_cfg, local_model_weights_path)
        else:
            self.model = self._load_from_hub(model_yaml_cfg, model_hf_repo_id)

        self.model.eval()

    def _load_local_checkpoint(self, model_yaml_cfg: Dict[str, Any], local_model_weights_path: str):
        """Build a ModernBertForSentiment from config and load weights from a local .pt file."""
        print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...")
        print("--- Debug: Entering LOCAL .pt loading path ---")

        # The architecture config is still fetched by id; only the weights
        # come from the local file.
        base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
        if not base_model_for_config_id:
            raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}")

        # NOTE(review): num_weighted_layers defaults to 4 here but to 6 in the
        # Hub path below — confirm which default matches the trained checkpoints.
        model_config = ModernBertConfig.from_pretrained(
            base_model_for_config_id,
            num_labels=model_yaml_cfg.get('num_labels', 1),
            pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'),
            num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4),
        )
        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}")
        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.")
        model = ModernBertForSentiment(config=model_config)

        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}")
        # SECURITY NOTE: torch.load unpickles the file; only point model_path
        # at checkpoints from a trusted source.
        checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))

        # Accept either a wrapped checkpoint ('model_state_dict' / 'state_dict')
        # or a bare state dict. Guard the `.get` so a non-dict checkpoint
        # reaches the TypeError below rather than an AttributeError.
        if isinstance(checkpoint, dict):
            state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
        else:
            state_dict_to_load = checkpoint
        if not isinstance(state_dict_to_load, dict):
            raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict or does not contain 'model_state_dict' or 'state_dict'.")

        # Log first few keys for debugging
        first_few_keys = list(state_dict_to_load.keys())[:5]
        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}")
        model.load_state_dict(state_dict_to_load)
        print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.")
        return model

    def _load_from_hub(self, model_yaml_cfg: Dict[str, Any], model_hf_repo_id: str):
        """Load the model via from_pretrained, falling back to AutoModelForSequenceClassification."""
        print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}")
        # These values are passed to from_pretrained as keyword overrides on
        # top of whatever config ships with the model.
        hub_config_overrides = {
            "num_labels": model_yaml_cfg.get('num_labels', 1),
            "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
            "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6),
        }
        print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}")
        try:
            model = ModernBertForSentiment.from_pretrained(
                model_hf_repo_id,
                **hub_config_overrides
            )
            print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.")
        except Exception as e:
            # Deliberate best-effort: any failure of the custom class is
            # logged and the generic sequence-classification loader is tried.
            # A failure of the fallback itself propagates to the caller.
            print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}")
            print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_hf_repo_id,
                **hub_config_overrides
            )
            print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.")
        return model

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify one text as positive or negative.

        Args:
            text: The input text to score.

        Returns:
            A dict with ``"sentiment"`` (``"positive"`` when the sigmoid of the
            logit exceeds 0.5, otherwise ``"negative"``) and ``"confidence"``
            (that same sigmoid value, i.e. it is below 0.5 for negative
            predictions).

        Raises:
            ValueError: If the model output has no ``"logits"`` entry.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits")  # Use .get for safety
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
        # .item() requires a single-element tensor, so this expects exactly
        # one logit for the one input text.
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}