Uhhy committed on
Commit
da3119b
·
verified ·
1 Parent(s): cf56a9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -30
app.py CHANGED
@@ -11,6 +11,7 @@ load_dotenv()
11
 
12
  app = FastAPI()
13
 
 
14
  global_data = {
15
  'models': {},
16
  'tokens': {
@@ -75,7 +76,7 @@ class ModelManager:
75
  self.loaded = True
76
  return models
77
  except Exception as e:
78
- print(f"Error loading all models: {e}")
79
  return []
80
 
81
  model_manager = ModelManager()
@@ -113,49 +114,34 @@ def remove_repetitive_responses(responses):
113
  unique_responses.append(response)
114
  return unique_responses
115
 
116
- def generate_chat_response(request, model_data):
117
- model = model_data['model']
118
- try:
119
- user_input = normalize_input(request.message)
120
- response = model(user_input, top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
121
- return response
122
- except Exception as e:
123
- print(f"Error generating response with model {model_data['name']}: {e}")
124
- return None
125
-
126
- @app.post("/generate")
127
  @spaces.GPU(duration=0)
 
128
  async def generate(request: ChatRequest):
129
  try:
130
- responses = []
131
- models = global_data['models']
132
- for model_data in models:
133
- response = generate_chat_response(request, model_data)
134
- if response:
135
- responses.append({
136
- "model": model_data['name'],
137
- "response": response
138
- })
139
-
140
- if not responses:
141
- raise HTTPException(status_code=500, detail="Error: No responses generated.")
142
 
143
- responses = remove_repetitive_responses(responses)
 
 
144
  best_response = responses[0] if responses else {}
 
145
  return {
146
  "best_response": best_response,
147
- "all_responses": responses
148
  }
149
- except Exception:
150
- pass
 
151
 
152
  @app.api_route("/{method_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
153
  async def handle_request(method_name: str, request: Request):
154
  try:
155
  body = await request.json()
156
  return {"message": "Request handled successfully", "body": body}
157
- except Exception:
158
- pass
 
159
 
160
  if __name__ == "__main__":
161
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
11
 
12
  app = FastAPI()
13
 
14
+ # Global data storage
15
  global_data = {
16
  'models': {},
17
  'tokens': {
 
76
  self.loaded = True
77
  return models
78
  except Exception as e:
79
+ print(f"Error loading models: {e}")
80
  return []
81
 
82
  model_manager = ModelManager()
 
114
  unique_responses.append(response)
115
  return unique_responses
116
 
 
 
 
 
 
 
 
 
 
 
 
@app.post("/generate")
@spaces.GPU(duration=0)
async def generate(request: ChatRequest):
    """Generate a chat response with the first loaded model.

    Returns a dict containing the best response and all de-duplicated
    responses; unexpected failures are reported as an ``error`` payload.

    Raises:
        HTTPException: 500 if no models were loaded at startup.
    """
    try:
        # Fail fast when startup model loading produced nothing.
        if not global_data['models']:
            raise HTTPException(status_code=500, detail="Models not loaded")

        model = global_data['models'][0]['model']
        inputs = normalize_input(request.message)
        responses = model.generate(
            inputs,
            top_k=request.top_k,
            top_p=request.top_p,
            temperature=request.temperature,
        )
        unique_responses = remove_repetitive_responses(responses)
        # Pick the best entry from the de-duplicated list so it can never
        # be a duplicate that the filter removed from "all_responses".
        best_response = unique_responses[0] if unique_responses else {}
        return {
            "best_response": best_response,
            "all_responses": unique_responses
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors so FastAPI emits the intended
        # status code instead of a 200 response wrapping the error text.
        raise
    except Exception as e:
        print(f"Error in generate endpoint: {e}")
        return {"error": str(e)}
136
 
137
@app.api_route("/{method_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def handle_request(method_name: str, request: Request):
    """Catch-all route: echo the JSON body of any request back to the caller.

    Best-effort by design — any failure (e.g. a body that is not valid
    JSON) is reported in an ``error`` payload rather than propagated.
    """
    try:
        payload = await request.json()
    except Exception as exc:
        print(f"Error handling request: {exc}")
        return {"error": str(exc)}
    return {"message": "Request handled successfully", "body": payload}
145
 
146
if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces at the Spaces default port.
    serve_host, serve_port = "0.0.0.0", 7860
    uvicorn.run(app, host=serve_host, port=serve_port)