Uhhy committed
Commit 2aca525 · verified · Parent: 4c925e3

Update app.py

Files changed (1): app.py +30 -45
app.py CHANGED
@@ -5,13 +5,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import uvicorn
 import re
 from dotenv import load_dotenv
+from spaces.zero import ZeroGPU
 import spaces
 
 load_dotenv()
 
 app = FastAPI()
 
-# Global data storage
+try:
+    ZeroGPU.initialize()
+except Exception:
+    pass
+
 global_data = {
     'models': {},
     'tokens': {
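Review note: the new try/except guards only `ZeroGPU.initialize()`. If `spaces.zero` does not expose `ZeroGPU` (the import path comes from this commit and is not verified against the `spaces` package, whose documented entry point is the `spaces.GPU` decorator), the `from spaces.zero import ZeroGPU` line itself raises ImportError before the guard runs. A minimal sketch of a guarded import under that assumption:

```python
# Hypothetical guard: tolerate the import failing, not just initialize().
# The spaces.zero.ZeroGPU path is taken from this commit, unverified.
try:
    from spaces.zero import ZeroGPU
except ImportError:
    ZeroGPU = None

if ZeroGPU is not None:
    try:
        ZeroGPU.initialize()
    except Exception:
        pass  # GPU warm-up is optional; keep booting on CPU-only hosts
```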
@@ -55,8 +60,7 @@ class ModelManager:
     def load_model(self, model_config):
         try:
             return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
-        except Exception as e:
-            print(f"Error loading model {model_config['name']}: {e}")
+        except Exception:
             pass
 
     def load_all_models(self):
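Review note: with the print removed, a model that fails to load is now skipped silently; the `except` falls through and `load_model` implicitly returns None, which is what the `if model:` filter in `load_all_models` relies on. A sketch that keeps the skip-on-failure contract but records why a model was skipped (stdlib `logging`, not part of this commit):

```python
import logging

logger = logging.getLogger(__name__)

class ModelManager:
    def load_model(self, model_config):
        try:
            return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'],
                                                   filename=model_config['filename']),
                    "name": model_config['name']}
        except Exception:
            # Same contract as the committed version: None signals "skip me",
            # but the stack trace survives in the logs.
            logger.exception("Failed to load model %s", model_config.get('name'))
            return None
```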
@@ -72,13 +76,12 @@ class ModelManager:
                 if model:
                     models.append(model)
 
-            global_data['models'] = models
+            global_data['models'] = {model['name']: model['model'] for model in models}
             self.loaded = True
-            return models
-        except Exception as e:
-            print(f"Error loading models: {e}")
+            return global_data['models']
+        except Exception:
             pass
-        return []
+        return {}
 
 model_manager = ModelManager()
 model_manager.load_all_models()
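Review note: this changes the shape of `global_data['models']` from a list of `{"model", "name"}` dicts to a plain name-to-model mapping, so every consumer has to switch access patterns. A quick illustration (model names are placeholders, not from the repo's config):

```python
# Before this commit: a list of wrapper dicts.
# for entry in global_data['models']:
#     use(entry['name'], entry['model'])

# After this commit: a name -> model mapping.
for name, model in global_data['models'].items():
    print(f"loaded {name}: {type(model).__name__}")
# The rewritten /generate/ endpoint below iterates .values() for the same reason.
```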
@@ -112,48 +115,30 @@ def remove_repetitive_responses(responses):
         normalized_response = remove_duplicates(response['response'])
         if normalized_response not in seen:
             seen.add(normalized_response)
-            unique_responses.append(response)
+            unique_responses.append({'model': response['model'], 'response': normalized_response})
     return unique_responses
 
+@app.post("/generate/")
 @spaces.GPU(duration=0)
-async def generate_model_response(model, inputs, top_k, top_p, temperature):
-    try:
-        responses = model.generate(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
-        return responses
-    except Exception as e:
-        print(f"Error generating model response: {e}")
-        pass
-        return []
-
-@app.post("/generate")
 async def generate(request: ChatRequest):
     try:
-        if not global_data['models']:
-            raise HTTPException(status_code=500, detail="Models not loaded")
-
-        model = global_data['models'][0]['model']
-        inputs = normalize_input(request.message)
-        responses = await generate_model_response(model, inputs, request.top_k, request.top_p, request.temperature)
-        best_response = responses[0] if responses else {}
-        unique_responses = remove_repetitive_responses(responses)
-        return {
-            "best_response": best_response,
-            "all_responses": unique_responses
-        }
-    except Exception as e:
-        print(f"Error in generate endpoint: {e}")
-        pass
-        return {"error": str(e)}
-
-@app.api_route("/{method_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
-async def handle_request(method_name: str, request: Request):
-    try:
-        body = await request.json()
-        return {"message": "Request handled successfully", "body": body}
+        normalized_message = normalize_input(request.message)
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(model.generate, f"<s>[INST]{normalized_message} [/INST]",
+                                       top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
+                       for model in global_data['models'].values()]
+            responses = []
+            for future, model_name in zip(as_completed(futures), global_data['models']):
+                generated_text = future.result()
+                responses.append({'model': model_name, 'response': generated_text})
+
+        return remove_repetitive_responses(responses)
+    except NotImplementedError as nie:
+        raise HTTPException(status_code=500, detail=str(nie))
+    except ZeroGPU.ZeroGPUException as gpu_exc:
+        raise HTTPException(status_code=500, detail=f"ZeroGPU Error: {gpu_exc}")
     except Exception as e:
-        print(f"Error handling request: {e}")
-        pass
-        return {"error": str(e)}
+        raise HTTPException(status_code=500, detail=str(e))
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
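Review note: the rewritten endpoint fans out to every loaded model, but `zip(as_completed(futures), global_data['models'])` pairs futures in completion order with names in dict-insertion order, so a response can be attributed to the wrong model whenever the models finish out of order. A sketch of a stable pairing, using the same names as the committed code:

```python
# Key each future by the model name it was submitted for, so attribution
# no longer depends on which model finishes first.
with ThreadPoolExecutor() as executor:
    future_to_name = {
        executor.submit(model.generate, f"<s>[INST]{normalized_message} [/INST]",
                        top_k=request.top_k, top_p=request.top_p,
                        temperature=request.temperature): name
        for name, model in global_data['models'].items()
    }
    responses = [{'model': future_to_name[future], 'response': future.result()}
                 for future in as_completed(future_to_name)]
```

Two smaller observations: calling `future.result()` inside an `async def` endpoint blocks the event loop while the models run, and the diff also drops the catch-all `/{method_name:path}` route, so unmatched paths now get FastAPI's default 404 instead of a JSON echo.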
 
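A quick smoke test against the updated route; note the trailing slash in `/generate/` and the port change from 7860 to 8000. Field names follow `ChatRequest` as used in this diff (whether any of them have defaults is not visible here), and the `requests` package is assumed to be installed:

```python
import requests

# Hypothetical client call for local testing; adjust host/port to your deployment.
resp = requests.post(
    "http://localhost:8000/generate/",
    json={"message": "Hello", "top_k": 40, "top_p": 0.95, "temperature": 0.7},
)
print(resp.status_code)
print(resp.json())  # list of {'model': ..., 'response': ...} after dedup
```

If this app runs as a Docker Space, keep in mind that Spaces route traffic to port 7860 by default, so the move to 8000 needs a matching `app_port` override.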