# nidra/app.py
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
import psutil
from functools import lru_cache
# [... rest of your existing code until ModelManager class ...]
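# The elided section above is assumed (from the references below) to define
# `logger`, the FastAPI `app`, the `MODELS` name -> Hub-repo mapping, and
# `HF_TOKEN` (presumably read from the environment).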
class ModelManager:
    """Caches one (model, tokenizer) pair per model name so each
    checkpoint is downloaded and loaded at most once per process."""

    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]

                logger.info(f"Loading tokenizer for {model_name}")
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                )

                logger.info(f"Loading model {model_name}")
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                    low_cpu_mem_usage=True,  # stream weights to reduce peak RAM during load
                    torch_dtype=torch.float32,  # full precision; this Space runs on CPU
                ).cpu()

                cls._instances[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model {model_name}: {str(e)}",
                )
        return cls._instances[model_name]
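
# Illustrative sketch only: one way an inference endpoint could use the cache
# above. `GenerateRequest`, the "/generate" route, and the "nidra-v1" default
# are assumptions for illustration, not part of the original file.
class GenerateRequest(BaseModel):
    text: str
    model: str = "nidra-v1"  # hypothetical default key into MODELS

@app.post("/generate")
async def generate(request: GenerateRequest):
    model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output_ids = model.generate(**inputs, max_new_tokens=128)
    return {"generated": tokenizer.decode(output_ids[0], skip_special_tokens=True)}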
# [... rest of your existing code until before @app.get("/version") ...]
@app.get("/debug/memory")
async def memory_usage():
    """Report resident memory and CPU usage for the current process."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        "memory_used_mb": memory_info.rss / 1024 / 1024,  # resident set size in MiB
        "memory_percent": process.memory_percent(),
        # Non-blocking call: measures CPU usage since the previous call,
        # so the first request after startup reports 0.0.
        "cpu_percent": process.cpu_percent(),
    }
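# Example check (port is an assumption; 7860 is the usual Spaces default):
#   curl http://localhost:7860/debug/memory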
# [... rest of your existing code ...]