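"""Model-loading utilities built on Hugging Face ``transformers`` pipelines.

Maps tasks to small default models, caches loaded pipelines, falls back to
alternative models when a load fails, and degrades to a static text fallback
when no model can be loaded at all.

Minimal usage sketch (illustrative only; assumes the models can be downloaded):

    pipe = load_model("summarization")
    result = pipe("Some long document text ...")
"""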
from functools import lru_cache

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

from .logging_config import logger
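
# Task -> model configuration. Each task lists a primary model plus optional
# fallbacks that load_model tries, in order, when the primary fails to load.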
MODEL_MAPPING = {
    "zero-shot-classification": {
        "primary": "distilbert-base-uncased",
        "fallback": "microsoft/DialoGPT-small",
        "local_fallback": "distilbert-base-uncased"
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",
        "fallback": "t5-small",
        "local_fallback": "t5-small"
    },
    "text-classification": {
        "primary": "distilbert-base-uncased",
        "fallback": "distilbert-base-uncased",
        "local_fallback": "distilbert-base-uncased"
    },
    "text-generation": {
        "primary": "distilgpt2",
        "fallback": "gpt2"
    }
}
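
# In-process cache of loaded pipelines, keyed by "<task>_<model_name>".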
_model_cache = {}

@lru_cache(maxsize=2)
def load_model(task, model_name=None):
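    """Load a ``transformers`` pipeline for ``task``, with caching and fallbacks.

    Results are cached both by ``lru_cache`` and in ``_model_cache``. If the
    requested or primary model fails to load, the fallbacks configured in
    ``MODEL_MAPPING`` are tried in order; if all of them fail, a static
    ``TextFallback`` object with the same call interface is returned.
    """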
    try:
        if task == "text-generation":
            # Text generation is pinned to distilgpt2 regardless of the requested model.
            model_name = "distilgpt2"
        elif model_name is None:
            model_name = MODEL_MAPPING.get(task, {}).get("primary", "distilbert-base-uncased")

        cache_key = f"{task}_{model_name}"
        if cache_key in _model_cache:
            logger.info(f"Using cached model: {model_name} for task: {task}")
            return _model_cache[cache_key]

        logger.info(f"Loading model: {model_name} for task: {task}")
        model_kwargs = {"device": -1, "truncation": True}
        if task == "zero-shot-classification":
            model_kwargs.update({"max_length": 256, "truncation": True})
        elif task == "summarization":
            model_kwargs.update({"max_length": 100, "min_length": 20, "do_sample": False,
                                 "num_beams": 1, "truncation": True})
        elif task == "text-generation":
            model_kwargs.update({"max_length": 256, "do_sample": True, "temperature": 0.7,
                                 "top_p": 0.9, "repetition_penalty": 1.1, "truncation": True})
        try:
            if task == "text-generation":
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
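                # GPT-2-style tokenizers ship without a pad token; reuse EOS for padding.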
                pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=-1,
                    pad_token_id=pad_token_id,
                    truncation=True
                )
                pipe.fallback_used = False
                _model_cache[cache_key] = pipe
                logger.info(f"Successfully loaded text-generation model: {model_name}")
                return pipe
            else:
                model = pipeline(task, model=model_name, **model_kwargs)
                model.fallback_used = False
                _model_cache[cache_key] = model
                logger.info(f"Successfully loaded model: {model_name}")
                return model
        except Exception as e:
            logger.warning(f"Failed to load primary model {model_name} for {task}: {str(e)}")
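
        # Primary model failed to load; try the configured fallbacks in order.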
        model_config = MODEL_MAPPING.get(task, {})
        for fallback_key in ["fallback", "local_fallback"]:
            fallback_model = model_config.get(fallback_key)
            if fallback_model and fallback_model != model_name:
                try:
                    logger.info(f"Trying fallback model: {fallback_model} for {task}")
                    model = pipeline(task, model=fallback_model, device=-1, truncation=True)
                    model.fallback_used = True
                    model.fallback_model = fallback_model
                    _model_cache[f"{task}_{fallback_model}"] = model
                    logger.info(f"Loaded fallback model: {fallback_model} for {task}")
                    return model
                except Exception as e2:
                    logger.warning(f"Failed to load fallback model {fallback_model} for {task}: {str(e2)}")

        logger.error(f"All model loading failed for {task}, using static fallback.")
        return create_text_fallback(task)
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}")
        return create_text_fallback(task)


def create_text_fallback(task):
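    """Build a static, model-free stand-in that mimics a pipeline's call interface."""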
    class TextFallback:
        def __init__(self, task_type):
            self.task_type = task_type
            self.fallback_used = True
            self.fallback_model = "static_fallback"

        def __call__(self, text, *args, **kwargs):
            if self.task_type == "text-generation":
return [{"generated_text": "Summary unavailable: Unable to load TinyLlama model. Please check system memory or model availability."}] |
            elif self.task_type == "zero-shot-classification":
                text_lower = text.lower()
                # Accept candidate labels positionally or via candidate_labels=...,
                # matching how zero-shot pipelines are usually called.
                labels = kwargs.get("candidate_labels") or (args[0] if args else ["positive", "negative"])
                scores = []
                for label in labels:
                    if label.lower() in text_lower:
                        scores.append(0.8)
                    else:
                        scores.append(0.2)
                return {"labels": labels, "scores": scores}
            elif self.task_type == "summarization":
                sentences = text.split('.')
                if len(sentences) > 3:
                    summary = '. '.join(sentences[:2]) + '.'
                else:
                    summary = text[:200] + ('...' if len(text) > 200 else '')
                return [{"summary_text": summary}]
            else:
                return {"result": "Model unavailable, using fallback"}

    return TextFallback(task)


def clear_model_cache():
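    """Clear all cached pipelines, including entries memoized by ``lru_cache``."""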
    global _model_cache
    _model_cache.clear()
    # Also reset the lru_cache on load_model so stale pipelines are not served.
    load_model.cache_clear()
    logger.info("Model cache cleared")


def get_available_models():
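    """Return the task-to-model mapping used by ``load_model``."""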
    return MODEL_MAPPING