Hjgugugjhuhjggg committed
Commit b394a28 (verified) · Parent: db6c4e2

Update app.py

Files changed (1): app.py (+45, -2)
app.py CHANGED
@@ -14,6 +14,9 @@ from nltk.corpus import stopwords
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import nltk
+import uvicorn
+import psutil
+import torch
 
 nltk.download('punkt')
 nltk.download('stopwords')
@@ -52,7 +55,7 @@ class ModelManager:
         if model_name not in self.models:
             try:
                 model_path = hf_hub_download(repo_id=config['repo_id'], use_auth_token=HUGGINGFACE_TOKEN)
-                model = Llama.from_file(model_path)
+                model = Llama.from_file(model_path, n_ctx=512, n_gpu=1)
                 self.models[model_name] = model
             except Exception as e:
                 self.models[model_name] = None
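A note on this hunk: assuming `Llama` comes from llama-cpp-python, `Llama.from_file` and the `n_gpu` keyword are not part of that library's API (loading goes through the `Llama` constructor, and GPU offload is controlled by `n_gpu_layers`), and `hf_hub_download` also requires a `filename` argument naming the file inside the repo. A minimal sketch of the load step against the documented interfaces, using `config` and `HUGGINGFACE_TOKEN` from the commit's own context; the `config['filename']` key is an assumption, not something the commit defines:

    from huggingface_hub import hf_hub_download
    from llama_cpp import Llama

    # hf_hub_download requires the file name inside the repo; the committed
    # call omits it. config['filename'] is a hypothetical key for the GGUF file.
    model_path = hf_hub_download(
        repo_id=config['repo_id'],
        filename=config['filename'],
        token=HUGGINGFACE_TOKEN,  # `use_auth_token` is the deprecated spelling
    )

    # The constructor is the documented loader; n_gpu_layers is the likely
    # intent behind the committed n_gpu=1 (offload one layer to the GPU).
    model = Llama(model_path=model_path, n_ctx=512, n_gpu_layers=1)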
@@ -82,7 +85,7 @@ async def process_message(message: str) -> dict:
     inputs = message.strip()
     responses = {}
 
-    with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
+    with ThreadPoolExecutor(max_workers=min(len(global_data['model_configs']), 4)) as executor:
         futures = [executor.submit(generate_model_response, model_manager.get_model(config['name']), inputs) for config in global_data['model_configs'] if model_manager.get_model(config['name'])]
         for i, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Generating responses")):
             try:
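Capping the pool with `min(..., 4)` bounds concurrency when many models are configured, which is the substance of this change. A separate wrinkle the hunk leaves untouched: `as_completed()` yields futures in completion order, not submission order, so if the loop index `i` is used to pair results with model configs, responses can end up attributed to the wrong model. A sketch of the loop with an explicit future-to-name mapping, under the same names the commit uses:

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from tqdm import tqdm

    with ThreadPoolExecutor(max_workers=min(len(global_data['model_configs']), 4)) as executor:
        # Key each future by its model name so completion order cannot
        # scramble the pairing.
        future_to_name = {
            executor.submit(generate_model_response, model_manager.get_model(config['name']), inputs): config['name']
            for config in global_data['model_configs']
            if model_manager.get_model(config['name'])
        }
        for future in tqdm(as_completed(future_to_name), total=len(future_to_name), desc="Generating responses"):
            name = future_to_name[future]
            try:
                responses[name] = future.result()
            except Exception as e:
                responses[name] = f"Error: {e}"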
@@ -127,6 +130,46 @@ async def startup_event():
 async def shutdown_event():
     gc.collect()
 
+def release_resources():
+    try:
+        torch.cuda.empty_cache()
+        gc.collect()
+    except Exception as e:
+        print(f"Failed to release resources: {e}")
+
+def resource_manager():
+    MAX_RAM_PERCENT = 1
+    MAX_CPU_PERCENT = 1
+    MAX_GPU_PERCENT = 1
+    MAX_RAM_MB = 1
+
+    while True:
+        try:
+            virtual_mem = psutil.virtual_memory()
+            current_ram_percent = virtual_mem.percent
+            current_ram_mb = virtual_mem.used / (1024 * 1024)
+
+            if current_ram_percent > MAX_RAM_PERCENT or current_ram_mb > MAX_RAM_MB:
+                release_resources()
+
+            current_cpu_percent = psutil.cpu_percent()
+            if current_cpu_percent > MAX_CPU_PERCENT:
+                psutil.Process(os.getpid()).nice()
+
+            if torch.cuda.is_available():
+                gpu = torch.cuda.current_device()
+                gpu_mem = torch.cuda.memory_percent(gpu)
+
+                if gpu_mem > MAX_GPU_PERCENT:
+                    release_resources()
+
+        except Exception as e:
+            print(f"Error in resource manager: {e}")
+
 if __name__ == "__main__":
+    import threading
+    resource_thread = threading.Thread(target=resource_manager)
+    resource_thread.daemon = True
+    resource_thread.start()
     port = int(os.environ.get("PORT", 7860))
     uvicorn.run(app, host="0.0.0.0", port=port)
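The new monitor contains several calls that would fail or misbehave as committed: `torch.cuda.memory_percent` does not exist in PyTorch; `psutil.Process(...).nice()` with no argument only *reads* the niceness rather than lowering it; the `while True` loop never sleeps, so it busy-spins on a core; and thresholds of 1 (1% RAM, 1 MB, 1% GPU) would fire on essentially every pass. A hedged rework addressing those points; the 90/85 thresholds and the 5-second interval are illustrative placeholders, not values from the commit:

    import gc
    import os
    import time

    import psutil
    import torch

    MAX_RAM_PERCENT = 90  # illustrative; the committed value of 1 fires constantly
    MAX_CPU_PERCENT = 85
    MAX_GPU_PERCENT = 90

    def release_resources():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # return cached blocks to the CUDA driver
        gc.collect()

    def resource_manager():
        while True:
            try:
                if psutil.virtual_memory().percent > MAX_RAM_PERCENT:
                    release_resources()

                if psutil.cpu_percent(interval=None) > MAX_CPU_PERCENT:
                    # nice() with no argument only reads the priority;
                    # passing a value lowers it (on Unix, larger is nicer).
                    psutil.Process(os.getpid()).nice(10)

                if torch.cuda.is_available():
                    device = torch.cuda.current_device()
                    total = torch.cuda.get_device_properties(device).total_memory
                    used_percent = 100 * torch.cuda.memory_allocated(device) / total
                    # Note: this sees only PyTorch's allocator, not GPU memory
                    # held by llama.cpp or other CUDA libraries.
                    if used_percent > MAX_GPU_PERCENT:
                        release_resources()
            except Exception as e:
                print(f"Error in resource manager: {e}")
            time.sleep(5)  # the committed loop has no sleep and busy-spins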
 
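Finally, the monitor thread starts only under `if __name__ == "__main__":`, so it will not run when a process manager imports the app and launches uvicorn itself. Since the diff context shows the app already has a `startup_event` handler, registering the thread from a FastAPI startup hook would cover both launch paths; the hook below is an assumed addition, not part of the commit:

    import threading

    @app.on_event("startup")
    async def start_resource_monitor():
        # A daemon thread dies with the process, mirroring the __main__ block.
        threading.Thread(target=resource_manager, daemon=True).start()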