# models/model_loader.py
from functools import lru_cache
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from .logging_config import logger
import os
MODEL_MAPPING = {
    "zero-shot-classification": {
        "primary": "distilbert-base-uncased",  # Much smaller than BART
        "fallback": "microsoft/DialoGPT-small",  # Very small
        "local_fallback": "distilbert-base-uncased"
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",  # Already small
        "fallback": "t5-small",  # Very small
        "local_fallback": "t5-small"
    },
    "text-classification": {
        "primary": "distilbert-base-uncased",  # Already small
        "fallback": "distilbert-base-uncased",
        "local_fallback": "distilbert-base-uncased"
    },
    # Use a much smaller model for text generation
    "text-generation": {
        "primary": "distilgpt2",  # Much smaller than TinyLlama
        "fallback": "gpt2"  # Small fallback
    }
}
_model_cache = {}
@lru_cache(maxsize=2)
def load_model(task, model_name=None):
    """Load a transformers pipeline for the given task, reusing cached pipelines,
    trying the configured fallback models, and returning a static fallback object
    if every model fails to load."""
    try:
        fallback_used = None
        if task == "text-generation":
            model_name = "distilgpt2"  # Use distilgpt2 instead of TinyLlama
        elif model_name is None or model_name in MODEL_MAPPING.get(task, {}):
            model_config = MODEL_MAPPING.get(task, {})
            if model_name is None:
                model_name = model_config.get("primary", "distilbert-base-uncased")
        cache_key = f"{task}_{model_name}"
        if cache_key in _model_cache:
            logger.info(f"Using cached model: {model_name} for task: {task}")
            return _model_cache[cache_key]
        logger.info(f"Loading model: {model_name} for task: {task}")
        model_kwargs = {"device": -1, "truncation": True}
        if task == "zero-shot-classification":
            model_kwargs.update({"max_length": 256, "truncation": True})  # Reduced max_length
        elif task == "summarization":
            model_kwargs.update({"max_length": 100, "min_length": 20, "do_sample": False, "num_beams": 1, "truncation": True})  # Reduced lengths
        elif task == "text-generation":
            model_kwargs.update({"max_length": 256, "do_sample": True, "temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.1, "truncation": True})  # Reduced max_length
        try:
            if task == "text-generation":
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
                pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=-1,
                    pad_token_id=pad_token_id,
                    truncation=True
                )
                pipe.fallback_used = False
                _model_cache[cache_key] = pipe
                logger.info(f"Successfully loaded text-generation model: {model_name}")
                return pipe
            else:
                model = pipeline(task, model=model_name, **model_kwargs)
                model.fallback_used = False
                _model_cache[cache_key] = model
                logger.info(f"Successfully loaded model: {model_name}")
                return model
        except Exception as e:
            logger.warning(f"Failed to load primary model {model_name} for {task}: {str(e)}")
            # Try the configured fallback and local_fallback models in order
            model_config = MODEL_MAPPING.get(task, {})
            for fallback_key in ["fallback", "local_fallback"]:
                fallback_model = model_config.get(fallback_key)
                if fallback_model and fallback_model != model_name:  # Don't retry the same model
                    try:
                        logger.info(f"Trying fallback model: {fallback_model} for {task}")
                        model = pipeline(task, model=fallback_model, device=-1, truncation=True)
                        model.fallback_used = True
                        model.fallback_model = fallback_model
                        _model_cache[f"{task}_{fallback_model}"] = model
                        logger.info(f"Loaded fallback model: {fallback_model} for {task}")
                        return model
                    except Exception as e2:
                        logger.warning(f"Failed to load fallback model {fallback_model} for {task}: {str(e2)}")
            logger.error(f"All model loading failed for {task}, using static fallback.")
            return create_text_fallback(task)
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}")
        return create_text_fallback(task)
def create_text_fallback(task):
    """Return a lightweight object that mimics a pipeline's call interface so
    callers still receive sensible output when no model can be loaded."""
    class TextFallback:
        def __init__(self, task_type):
            self.task_type = task_type
            self.fallback_used = True
            self.fallback_model = "static_fallback"

        def __call__(self, text, *args, **kwargs):
            if self.task_type == "text-generation":
                return [{"generated_text": "Summary unavailable: unable to load a text-generation model. Please check system memory or model availability."}]
            elif self.task_type == "zero-shot-classification":
                text_lower = text.lower()
                # Labels may arrive positionally or as candidate_labels, matching the pipeline API
                labels = kwargs.get("candidate_labels") or (args[0] if args else ["positive", "negative"])
                scores = []
                for label in labels:
                    if label.lower() in text_lower:
                        scores.append(0.8)
                    else:
                        scores.append(0.2)
                return {"labels": labels, "scores": scores}
            elif self.task_type == "summarization":
                sentences = text.split('.')
                if len(sentences) > 3:
                    summary = '. '.join(sentences[:2]) + '.'
                else:
                    summary = text[:200] + ('...' if len(text) > 200 else '')
                return [{"summary_text": summary}]
            else:
                return {"result": "Model unavailable, using fallback"}

    return TextFallback(task)
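# Illustrative only: the shape of the static fallback's output mirrors what the
# corresponding transformers pipelines return (scores below come from this
# module's hard-coded keyword heuristic, not from a model):
#   fb = create_text_fallback("zero-shot-classification")
#   fb("the outlook is positive", ["positive", "negative"])
#   # -> {"labels": ["positive", "negative"], "scores": [0.8, 0.2]}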
def clear_model_cache():
    global _model_cache
    _model_cache.clear()
    logger.info("Model cache cleared")
def get_available_models():
    return MODEL_MAPPING
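
# --- Illustrative usage sketch (not part of the original pipeline code) ---
# Assumes the models in MODEL_MAPPING can be downloaded in the current
# environment; the sample text and labels below are made up for demonstration.
if __name__ == "__main__":
    classifier = load_model("zero-shot-classification")
    result = classifier(
        "Spacious 3BHK apartment with covered parking and 24/7 security.",
        ["residential", "commercial", "land"],
    )
    logger.info(f"Zero-shot result (fallback_used={getattr(classifier, 'fallback_used', None)}): {result}")

    summarizer = load_model("summarization")
    summary = summarizer(
        "The property is located near the city center. It has three bedrooms, "
        "two bathrooms, a modular kitchen, and covered parking. The building "
        "was completed in 2018 and includes a gym and a swimming pool."
    )
    logger.info(f"Summary: {summary}")

    clear_model_cache()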