Sushwetabm committed on
Commit
aff0b1f
·
1 Parent(s): ce3ac0e

updated model.py

Files changed (1):
  1. model.py +17 -47
model.py CHANGED
@@ -1,4 +1,4 @@
-# model.py - Optimized version
+# model.py - Fixed for CodeT5+
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 from functools import lru_cache
@@ -9,117 +9,87 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-# Global variables to store loaded model
 _tokenizer = None
 _model = None
 _model_loading = False
 _model_loaded = False
 
 @lru_cache(maxsize=1)
-# def get_model_config():
-#     """Cache model configuration"""
-#     return {
-#         "model_id": "deepseek-ai/deepseek-coder-1.3b-instruct",
-#         "torch_dtype": torch.bfloat16,
-#         "device_map": "auto",
-#         "trust_remote_code": True,
-#         # Add these optimizations
-#         "low_cpu_mem_usage": True,
-#         "use_cache": True,
-#     }
 def get_model_config():
     return {
         "model_id": "Salesforce/codet5p-220m",
         "trust_remote_code": True
     }
+
 def load_model_sync():
-    """Synchronous model loading with optimizations"""
     global _tokenizer, _model, _model_loaded
-
+
     if _model_loaded:
         return _tokenizer, _model
-
+
     config = get_model_config()
     model_id = config["model_id"]
-
+
     logger.info(f"🔧 Loading model {model_id}...")
-
+
     try:
-        # Set cache directory to avoid re-downloading
         cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
         os.makedirs(cache_dir, exist_ok=True)
-
-        # Load tokenizer first (faster)
+
         logger.info("📝 Loading tokenizer...")
         _tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             trust_remote_code=config["trust_remote_code"],
             cache_dir=cache_dir,
-            use_fast=True,  # Use fast tokenizer if available
+            use_fast=True,
         )
-
-        # Load model with optimizations
+
         logger.info("🧠 Loading model...")
         _model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             trust_remote_code=config["trust_remote_code"],
-            torch_dtype=config["torch_dtype"],
-            device_map=config["device_map"],
-            low_cpu_mem_usage=config["low_cpu_mem_usage"],
-            cache_dir=cache_dir,
-            offload_folder="offload",
-            offload_state_dict=True
+            cache_dir=cache_dir
        )
-
-        # Set to evaluation mode
+
         _model.eval()
-
         _model_loaded = True
         logger.info("✅ Model loaded successfully!")
         return _tokenizer, _model
-
+
     except Exception as e:
         logger.error(f"❌ Failed to load model: {e}")
         raise
 
 async def load_model_async():
-    """Asynchronous model loading"""
     global _model_loading
-
+
     if _model_loaded:
         return _tokenizer, _model
-
+
     if _model_loading:
-        # Wait for ongoing loading to complete
         while _model_loading and not _model_loaded:
             await asyncio.sleep(0.1)
         return _tokenizer, _model
-
+
     _model_loading = True
-
+
     try:
-        # Run model loading in thread pool to avoid blocking
         loop = asyncio.get_event_loop()
         with ThreadPoolExecutor(max_workers=1) as executor:
-            tokenizer, model = await loop.run_in_executor(
-                executor, load_model_sync
-            )
+            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
         return tokenizer, model
     finally:
         _model_loading = False
 
 def get_model():
-    """Get the loaded model (for synchronous access)"""
     if not _model_loaded:
         return load_model_sync()
     return _tokenizer, _model
 
 def is_model_loaded():
-    """Check if model is loaded"""
     return _model_loaded
 
 def get_model_info():
-    """Get model information without loading"""
     config = get_model_config()
     return {
         "model_id": config["model_id"],
 
1
+ # model.py - Fixed for CodeT5+
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  from functools import lru_cache
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
12
  _tokenizer = None
13
  _model = None
14
  _model_loading = False
15
  _model_loaded = False
16
 
17
  @lru_cache(maxsize=1)
 
 
 
 
 
 
 
 
 
 
 
18
  def get_model_config():
19
  return {
20
  "model_id": "Salesforce/codet5p-220m",
21
  "trust_remote_code": True
22
  }
23
+
24
  def load_model_sync():
 
25
  global _tokenizer, _model, _model_loaded
26
+
27
  if _model_loaded:
28
  return _tokenizer, _model
29
+
30
  config = get_model_config()
31
  model_id = config["model_id"]
32
+
33
  logger.info(f"πŸ”§ Loading model {model_id}...")
34
+
35
  try:
 
36
  cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
37
  os.makedirs(cache_dir, exist_ok=True)
38
+
 
39
  logger.info("πŸ“ Loading tokenizer...")
40
  _tokenizer = AutoTokenizer.from_pretrained(
41
  model_id,
42
  trust_remote_code=config["trust_remote_code"],
43
  cache_dir=cache_dir,
44
+ use_fast=True,
45
  )
46
+
 
47
  logger.info("🧠 Loading model...")
48
  _model = AutoModelForSeq2SeqLM.from_pretrained(
49
  model_id,
50
  trust_remote_code=config["trust_remote_code"],
51
+ cache_dir=cache_dir
 
 
 
 
 
52
  )
53
+
 
54
  _model.eval()
 
55
  _model_loaded = True
56
  logger.info("βœ… Model loaded successfully!")
57
  return _tokenizer, _model
58
+
59
  except Exception as e:
60
  logger.error(f"❌ Failed to load model: {e}")
61
  raise
62
 
63
  async def load_model_async():
 
64
  global _model_loading
65
+
66
  if _model_loaded:
67
  return _tokenizer, _model
68
+
69
  if _model_loading:
 
70
  while _model_loading and not _model_loaded:
71
  await asyncio.sleep(0.1)
72
  return _tokenizer, _model
73
+
74
  _model_loading = True
75
+
76
  try:
 
77
  loop = asyncio.get_event_loop()
78
  with ThreadPoolExecutor(max_workers=1) as executor:
79
+ tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
 
 
80
  return tokenizer, model
81
  finally:
82
  _model_loading = False
83
 
84
  def get_model():
 
85
  if not _model_loaded:
86
  return load_model_sync()
87
  return _tokenizer, _model
88
 
89
  def is_model_loaded():
 
90
  return _model_loaded
91
 
92
  def get_model_info():
 
93
  config = get_model_config()
94
  return {
95
  "model_id": config["model_id"],