Uhhy committed on
Commit
50f9f62
·
verified ·
1 Parent(s): 50a95c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -7,14 +7,23 @@ from dotenv import load_dotenv
7
  import re
8
  import huggingface_hub
9
  import spaces
10
- import httpx
11
 
12
  load_dotenv()
13
 
14
  app = FastAPI()
15
 
16
  global_data = {
17
- 'models': []
 
 
 
 
 
 
 
 
 
 
18
  }
19
 
20
  model_configs = [
@@ -48,7 +57,6 @@ class ModelManager:
48
  def load_model(self, model_config):
49
  return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
50
 
51
- @spaces.GPU(duration=0)
52
  def load_all_models(self):
53
  if self.loaded:
54
  return self.models
@@ -109,22 +117,18 @@ def remove_repetitive_responses(responses):
109
  for response in responses:
110
  normalized_response = remove_duplicates(response['response'])
111
  if normalized_response not in seen:
112
- seen.add(normalized_response)
113
  unique_responses.append(response)
 
114
  return unique_responses
115
 
116
  def select_best_response(responses):
 
 
117
  responses = remove_repetitive_responses(responses)
118
- responses = [remove_duplicates(response['response']) for response in responses]
119
- unique_responses = list(dict.fromkeys(responses))
120
- sorted_responses = sorted(unique_responses, key=lambda r: len(r), reverse=True)
121
- return sorted_responses[0]
122
 
123
- @app.post("/generate_chat")
124
- async def generate_chat(request: ChatRequest):
125
- if not request.message.strip():
126
- raise HTTPException(status_code=400, detail="Error: No message provided.")
127
-
128
  responses = []
129
  num_models = len(global_data['models'])
130
 
 
7
  import re
8
  import huggingface_hub
9
  import spaces
 
10
 
11
  load_dotenv()
12
 
13
  app = FastAPI()
14
 
15
  global_data = {
16
+ 'models': [],
17
+ 'tokens': {
18
+ 'eos': 'eos_token',
19
+ 'pad': 'pad_token',
20
+ 'padding': 'padding_token',
21
+ 'unk': 'unk_token',
22
+ 'bos': 'bos_token',
23
+ 'sep': 'sep_token',
24
+ 'cls': 'cls_token',
25
+ 'mask': 'mask_token'
26
+ }
27
  }
28
 
29
  model_configs = [
 
57
  def load_model(self, model_config):
58
  return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
59
 
 
60
  def load_all_models(self):
61
  if self.loaded:
62
  return self.models
 
117
  for response in responses:
118
  normalized_response = remove_duplicates(response['response'])
119
  if normalized_response not in seen:
 
120
  unique_responses.append(response)
121
+ seen.add(normalized_response)
122
  return unique_responses
123
 
124
  def select_best_response(responses):
125
+ if not responses:
126
+ return ""
127
  responses = remove_repetitive_responses(responses)
128
+ return max(set(responses), key=lambda x: x['response'].count("user"))
 
 
 
129
 
130
+ @app.post("/generate")
131
+ def generate_chat(request: ChatRequest):
 
 
 
132
  responses = []
133
  num_models = len(global_data['models'])
134