Uhhy committed on
Commit
3964343
·
verified ·
1 Parent(s): 544fd0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -50,32 +50,26 @@ model_configs = [
50
class ModelManager:
    """Loads the configured llama.cpp models in parallel and caches them
    in the module-level ``global_data['models']`` mapping."""

    def __init__(self):
        # Set to True once load_all_models() has populated the cache.
        self.loaded = False

    def load_model(self, model_config):
        """Load one model described by *model_config*.

        Returns ``{"model": <Llama>, "name": <str>}`` on success, or
        ``None`` if loading fails (the error is reported on stdout).
        """
        try:
            instance = Llama.from_pretrained(
                repo_id=model_config['repo_id'],
                filename=model_config['filename'],
            )
        except Exception as e:
            print(f"Error loading model {model_config['name']}: {e}")
            return None
        return {"model": instance, "name": model_config['name']}

    def load_all_models(self):
        """Load every entry of ``model_configs`` concurrently (idempotent).

        Returns the name -> model mapping stored in ``global_data['models']``.
        """
        if self.loaded:
            return global_data['models']

        successful = []
        with ThreadPoolExecutor() as executor:
            pending = [executor.submit(self.load_model, cfg) for cfg in model_configs]
            for finished in as_completed(pending):
                outcome = finished.result()
                if outcome:
                    successful.append(outcome)

        global_data['models'] = {entry['name']: entry['model'] for entry in successful}
        self.loaded = True
        return global_data['models']
76
 
77
# Module-level singleton: eagerly load every configured model at import time.
model_manager = ModelManager()
model_manager.load_all_models()
79
 
80
  class ChatRequest(BaseModel):
81
  message: str
@@ -103,7 +97,7 @@ def remove_duplicates(text):
103
def generate_model_response(model, inputs, top_k, top_p, temperature):
    """Run *model* on *inputs* and return the de-duplicated completion text.

    Returns "" if generation fails for any reason (the error is reported
    on stdout rather than propagated).
    """
    try:
        response = model(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
        # Bug fix: remove_duplicates() expects the generated text, not the
        # whole llama.cpp response dict ({"choices": [{"text": ...}, ...]}).
        return remove_duplicates(response['choices'][0]['text'])
    except Exception as e:
        print(f"Error generating model response: {e}")
        return ""
 
50
class ModelManager:
    """Owns a name -> Llama cache and loads each configured model at most once."""

    def __init__(self):
        # Becomes True after load_all_models() has run once.
        self.loaded = False
        # Maps model name to its loaded Llama instance.
        self.models = {}

    def load_model(self, model_config):
        """Load the model described by *model_config* unless it is already cached.

        Failures are reported on stdout and leave the cache untouched.
        """
        name = model_config['name']
        if name in self.models:
            return
        try:
            self.models[name] = Llama.from_pretrained(
                repo_id=model_config['repo_id'],
                filename=model_config['filename'],
            )
        except Exception as e:
            print(f"Error loading model {model_config['name']}: {e}")

    def load_all_models(self):
        """Load every entry of ``model_configs`` in parallel; idempotent.

        Returns the internal name -> model cache.
        """
        if not self.loaded:
            # The executor's context exit waits for all submitted loads.
            with ThreadPoolExecutor() as executor:
                for cfg in model_configs:
                    executor.submit(self.load_model, cfg)
            self.loaded = True
        return self.models
 
 
 
 
 
 
 
 
 
 
70
 
71
# Module-level singleton; populate the shared model cache once at import time.
model_manager = ModelManager()
global_data['models'] = model_manager.load_all_models()
73
 
74
  class ChatRequest(BaseModel):
75
  message: str
 
97
def generate_model_response(model, inputs, top_k, top_p, temperature):
    """Return the model's de-duplicated completion text for *inputs*.

    Any failure during generation or post-processing is reported on
    stdout and yields "".
    """
    try:
        completion = model(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
        # llama.cpp completions have shape {"choices": [{"text": ...}, ...]}.
        generated_text = completion['choices'][0]['text']
        return remove_duplicates(generated_text)
    except Exception as e:
        print(f"Error generating model response: {e}")
        return ""