Update app.py
app.py CHANGED
@@ -5,13 +5,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import uvicorn
 import re
 from dotenv import load_dotenv
+from spaces.zero import ZeroGPU
 import spaces
 
 load_dotenv()
 
 app = FastAPI()
 
-
+try:
+    ZeroGPU.initialize()
+except Exception:
+    pass
+
 global_data = {
     'models': {},
     'tokens': {
@@ -55,8 +60,7 @@ class ModelManager:
     def load_model(self, model_config):
         try:
             return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
-        except Exception
-            print(f"Error loading model {model_config['name']}: {e}")
+        except Exception:
             pass
 
     def load_all_models(self):
@@ -72,13 +76,12 @@ class ModelManager:
                 if model:
                     models.append(model)
 
-            global_data['models'] = models
+            global_data['models'] = {model['name']: model['model'] for model in models}
             self.loaded = True
-            return models
-        except Exception
-            print(f"Error loading models: {e}")
+            return global_data['models']
+        except Exception:
             pass
-            return
+            return {}
 
 model_manager = ModelManager()
 model_manager.load_all_models()
@@ -112,48 +115,30 @@ def remove_repetitive_responses(responses):
         normalized_response = remove_duplicates(response['response'])
         if normalized_response not in seen:
             seen.add(normalized_response)
-            unique_responses.append(response)
+            unique_responses.append({'model': response['model'], 'response': normalized_response})
     return unique_responses
 
+@app.post("/generate/")
 @spaces.GPU(duration=0)
-async def generate_model_response(model, inputs, top_k, top_p, temperature):
-    try:
-        responses = model.generate(inputs, top_k=top_k, top_p=top_p, temperature=temperature)
-        return responses
-    except Exception as e:
-        print(f"Error generating model response: {e}")
-        pass
-        return []
-
-@app.post("/generate")
 async def generate(request: ChatRequest):
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return {"error": str(e)}
-
-@app.api_route("/{method_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
-async def handle_request(method_name: str, request: Request):
-    try:
-        body = await request.json()
-        return {"message": "Request handled successfully", "body": body}
+        normalized_message = normalize_input(request.message)
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(model.generate, f"<s>[INST]{normalized_message} [/INST]",
+                                       top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
+                       for model in global_data['models'].values()]
+            responses = []
+            for future, model_name in zip(as_completed(futures), global_data['models']):
+                generated_text = future.result()
+                responses.append({'model': model_name, 'response': generated_text})
+
+        return remove_repetitive_responses(responses)
+    except NotImplementedError as nie:
+        raise HTTPException(status_code=500, detail=str(nie))
+    except ZeroGPU.ZeroGPUException as gpu_exc:
+        raise HTTPException(status_code=500, detail=f"ZeroGPU Error: {gpu_exc}")
     except Exception as e:
-
-        pass
-        return {"error": str(e)}
+        raise HTTPException(status_code=500, detail=str(e))
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=
+    uvicorn.run(app, host="0.0.0.0", port=8000)