# # model.py - Optimized version
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch
# from functools import lru_cache
# import os
# import asyncio
# from concurrent.futures import ThreadPoolExecutor
# import logging

# logger = logging.getLogger(__name__)

# # Global variables to store loaded model
# _tokenizer = None
# _model = None
# _model_loading = False
# _model_loaded = False

# @lru_cache(maxsize=1)
# def get_model_config():
#     """Cache model configuration"""
#     return {
#         "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
#         "torch_dtype": torch.bfloat16,
#         "device_map": "auto",
#         "trust_remote_code": True,
#         # Add these optimizations
#         "low_cpu_mem_usage": True,
#         "use_cache": True,
#     }

# def load_model_sync():
#     """Synchronous model loading with optimizations"""
#     global _tokenizer, _model, _model_loaded
    
#     if _model_loaded:
#         return _tokenizer, _model
    
#     config = get_model_config()
#     model_id = config["model_id"]
    
#     logger.info(f"🔧 Loading model {model_id}...")
    
#     try:
#         # Set cache directory to avoid re-downloading
#         cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
#         os.makedirs(cache_dir, exist_ok=True)
        
#         # Load tokenizer first (faster)
#         logger.info("📝 Loading tokenizer...")
#         _tokenizer = AutoTokenizer.from_pretrained(
#             model_id,
#             trust_remote_code=config["trust_remote_code"],
#             cache_dir=cache_dir,
#             use_fast=True,  # Use fast tokenizer if available
#         )
        
#         # Load model with optimizations
#         logger.info("🧠 Loading model...")
#         _model = AutoModelForCausalLM.from_pretrained(
#             model_id,
#             trust_remote_code=config["trust_remote_code"],
#             torch_dtype=config["torch_dtype"],
#             device_map=config["device_map"],
#             low_cpu_mem_usage=config["low_cpu_mem_usage"],
#             cache_dir=cache_dir,
#             offload_folder="offload",
#             offload_state_dict=True
#         )
        
#         # Set to evaluation mode
#         _model.eval()
        
#         _model_loaded = True
#         logger.info("✅ Model loaded successfully!")
#         return _tokenizer, _model
        
#     except Exception as e:
#         logger.error(f"❌ Failed to load model: {e}")
#         raise

# async def load_model_async():
#     """Asynchronous model loading"""
#     global _model_loading
    
#     if _model_loaded:
#         return _tokenizer, _model
    
#     if _model_loading:
#         # Wait for ongoing loading to complete
#         while _model_loading and not _model_loaded:
#             await asyncio.sleep(0.1)
#         return _tokenizer, _model
    
#     _model_loading = True
    
#     try:
#         # Run model loading in thread pool to avoid blocking
#         loop = asyncio.get_event_loop()
#         with ThreadPoolExecutor(max_workers=1) as executor:
#             tokenizer, model = await loop.run_in_executor(
#                 executor, load_model_sync
#             )
#         return tokenizer, model
#     finally:
#         _model_loading = False

# def get_model():
#     """Get the loaded model (for synchronous access)"""
#     if not _model_loaded:
#         return load_model_sync()
#     return _tokenizer, _model

# def is_model_loaded():
#     """Check if model is loaded"""
#     return _model_loaded

# def get_model_info():
#     """Get model information without loading"""
#     config = get_model_config()
#     return {
#         "model_id": config["model_id"],
#         "loaded": _model_loaded,
#         "loading": _model_loading,
#     }
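
# ---------------------------------------------------------------------------
# Active implementation.
# The commented-out block above is an earlier loader for
# deepseek-ai/deepseek-coder-1.3b-instruct; the code below loads the much
# lighter Salesforce/codet5p-220m seq2seq checkpoint instead.
# ---------------------------------------------------------------------------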

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from functools import lru_cache
import logging
import asyncio

logger = logging.getLogger(__name__)

# Module-level state for the lazily loaded model.
_tokenizer = None
_model = None
_model_loaded = False
_model_loading = False


@lru_cache(maxsize=1)
def get_model_config():
    """Return the (cached) model configuration."""
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }


def load_model_sync():
    """Load tokenizer and model synchronously (idempotent)."""
    global _tokenizer, _model, _model_loaded

    if _model_loaded:
        return _tokenizer, _model

    config = get_model_config()
    model_id = config["model_id"]

    try:
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id, trust_remote_code=config["trust_remote_code"]
        )
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id, trust_remote_code=config["trust_remote_code"]
        )
        _model.eval()  # inference only, no gradients needed
        _model_loaded = True
        logger.info(f"✅ Model {model_id} loaded successfully.")
        return _tokenizer, _model

    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise


async def load_model_async():
    """Load the model without blocking the event loop.

    Delegates to load_model_sync() in a worker thread so an async server
    can keep serving requests while the weights are downloaded and loaded.
    Concurrent callers wait for the first load to finish instead of
    loading the model twice.
    """
    global _model_loading

    if _model_loaded:
        return _tokenizer, _model

    if _model_loading:
        # Another coroutine is already loading the model; wait for it.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        return _tokenizer, _model

    _model_loading = True
    try:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, load_model_sync)
        return _tokenizer, _model
    finally:
        _model_loading = False
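
# Example of warming the model from an async web app at startup. This is an
# illustrative sketch only: FastAPI is an assumption and is not imported or
# required by this module.
#
#     from fastapi import FastAPI
#     app = FastAPI()
#
#     @app.on_event("startup")
#     async def _warm_model():
#         await load_model_async()
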

def get_model():
    """Return the loaded (tokenizer, model) pair.

    Raises ValueError if neither load_model_sync() nor load_model_async()
    has completed yet.
    """
    if not _model_loaded:
        raise ValueError("Model not loaded yet; call load_model_sync() "
                         "or await load_model_async() first.")
    return _tokenizer, _model


def is_model_loaded():
    """Return True once the model has finished loading."""
    return _model_loaded


def get_model_info():
    """Return lightweight status information without loading the model."""
    return {
        "model_id": get_model_config()["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
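

# Minimal usage sketch, runnable as `python model.py`. The prompt and
# generation settings below are illustrative assumptions, not values used
# elsewhere in the project.
if __name__ == "__main__":
    import torch

    load_model_sync()
    tokenizer, model = get_model()

    # CodeT5+ is an encoder-decoder model, so generation works like T5:
    # encode a prompt, then let the decoder produce the completion.
    inputs = tokenizer("def greet(name):", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=48)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))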