import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
import psutil
from functools import lru_cache
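# Third-party dependencies implied by the imports: fastapi, pydantic, torch,
# transformers, psutil, and sentencepiece (required by the slow T5Tokenizer).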

[... rest of your existing code until ModelManager class ...]

class ModelManager:
    """Process-wide cache so each model/tokenizer pair is loaded only once per worker."""

    _instances: ClassVar[Dict[str, tuple]] = {}
    
    @classmethod
    def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]
                logger.info(f"Loading tokenizer for {model_name}")
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False
                )
                
                logger.info(f"Loading model {model_name}")
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float32
                ).cpu()
                
                cls._instances[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model {model_name}: {str(e)}"
                )
        return cls._instances[model_name]

[... rest of your existing code until before @app.get("/version") ...]
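
# A minimal sketch of an inference route that consumes the cached pair above.
# The route path, request schema, default model key ("flan-t5-small"), and
# generation settings are illustrative assumptions, not part of the original app.
class GenerateRequest(BaseModel):
    text: str
    model_name: str = "flan-t5-small"  # hypothetical key into MODELS
    max_new_tokens: int = 128

@app.post("/generate")
async def generate(request: GenerateRequest):
    model, tokenizer = ModelManager.get_model_and_tokenizer(request.model_name)
    inputs = tokenizer(request.text, return_tensors="pt", truncation=True)
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        output_ids = model.generate(**inputs, max_new_tokens=request.max_new_tokens)
    return {"output": tokenizer.decode(output_ids[0], skip_special_tokens=True)}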

@app.get("/debug/memory")
async def memory_usage():
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        "memory_used_mb": memory_info.rss / 1024 / 1024,
        "memory_percent": process.memory_percent(),
        "cpu_percent": process.cpu_percent()
    }
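
# Quick smoke test once the server is running (host and port are assumptions):
#   curl http://localhost:8000/debug/memory
# The response is a JSON object with memory_used_mb, memory_percent, and cpu_percent.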

[... rest of your existing code ...]