Ibraaheem commited on
Commit
5f65a30
·
1 Parent(s): af78253

Update private_gpt/launcher.py

Browse files
Files changed (1) hide show
  1. private_gpt/launcher.py +5 -46
private_gpt/launcher.py CHANGED
@@ -2,7 +2,6 @@
2
  import logging
3
  from typing import Any
4
 
5
-
6
  from fastapi import Depends, FastAPI, Request, status, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.openapi.utils import get_openapi
@@ -27,10 +26,6 @@ from private_gpt.server.utils.authentication import get_current_user
27
 
28
  from fastapi import Depends, HTTPException, status
29
  from fastapi.security import OAuth2AuthorizationCodeBearer
30
- from fastapi import FastAPI, Body, HTTPException
31
- from pydantic import BaseModel, validator
32
- import yaml
33
-
34
 
35
 
36
 
@@ -96,7 +91,7 @@ def create_app(root_injector: Injector) -> FastAPI:
96
 
97
  @model_router.post("/switch_model")
98
  async def switch_model(
99
- new_model: str, settings: Settings, current_user: dict = Depends(get_current_user)
100
  ):
101
  # Check if the user has either "admin" or "user" role
102
  if "user" not in current_user.get("role", []):
@@ -104,7 +99,7 @@ def create_app(root_injector: Injector) -> FastAPI:
104
  status_code=status.HTTP_403_FORBIDDEN,
105
  detail="You are not authorized to use this API.",
106
  )
107
-
108
  # Retrieve model information and validate access
109
  model_info = next((m for m in models_data if m["name"] == new_model), None)
110
  if not model_info or current_user.get("role") not in model_info["access"]:
@@ -112,16 +107,15 @@ def create_app(root_injector: Injector) -> FastAPI:
112
  status_code=status.HTTP_403_FORBIDDEN,
113
  detail="You are not authorized to access this model.",
114
  )
115
-
116
  # Switch the model using the LLMComponent
117
  llm_component = root_injector.get(LLMComponent)
118
- llm_component.switch_model(new_model, settings)
119
-
120
  # Return a success message
121
  return {"message": f"Model switched to {new_model}"}
122
 
123
 
124
-
125
  # Define a new APIRouter for the model_list
126
  model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
127
 
@@ -132,41 +126,6 @@ def create_app(root_injector: Injector) -> FastAPI:
132
  """
133
  return models_data # Directly return the global variable
134
 
135
- # class NewLLMMode(BaseModel):
136
- # mode: str
137
-
138
- # @validator("mode")
139
- # def validate_mode(cls, v):
140
- # valid_modes = ["local", "openai", "sagemaker", "mock"] # Adjust as needed
141
- # if v not in valid_modes:
142
- # raise ValueError("Invalid LLM mode")
143
- # return v
144
-
145
- # @app.post("/settings/llm_mode")
146
- # async def update_llm_mode(new_mode: NewLLMMode = Body(...)):
147
- # try:
148
- # # Load settings
149
- # with open("settings.yaml", "r") as f:
150
- # settings = yaml.safe_load(f)
151
-
152
- # # Update llm.mode
153
- # settings["llm"]["mode"] = new_mode.mode
154
-
155
- # # Persist changes
156
- # with open("settings.yaml", "w") as f:
157
- # yaml.dump(settings, f)
158
-
159
- # return {"message": "LLM mode updated successfully"}
160
-
161
- # except Exception as e:
162
- # raise HTTPException(status_code=500, detail=f"Failed to update LLM mode: {e}")
163
-
164
-
165
-
166
- # Configuration module:
167
- allowed_modes = ["local", "openai", "sagemaker", "mock"] # Load from external source
168
-
169
-
170
  def custom_openapi() -> dict[str, Any]:
171
  if app.openapi_schema:
172
  return app.openapi_schema
 
2
  import logging
3
  from typing import Any
4
 
 
5
  from fastapi import Depends, FastAPI, Request, status, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.openapi.utils import get_openapi
 
26
 
27
  from fastapi import Depends, HTTPException, status
28
  from fastapi.security import OAuth2AuthorizationCodeBearer
 
 
 
 
29
 
30
 
31
 
 
91
 
92
  @model_router.post("/switch_model")
93
  async def switch_model(
94
+ new_model: str, current_user: dict = Depends(get_current_user)
95
  ):
96
  # Check if the user has either "admin" or "user" role
97
  if "user" not in current_user.get("role", []):
 
99
  status_code=status.HTTP_403_FORBIDDEN,
100
  detail="You are not authorized to use this API.",
101
  )
102
+
103
  # Retrieve model information and validate access
104
  model_info = next((m for m in models_data if m["name"] == new_model), None)
105
  if not model_info or current_user.get("role") not in model_info["access"]:
 
107
  status_code=status.HTTP_403_FORBIDDEN,
108
  detail="You are not authorized to access this model.",
109
  )
110
+
111
  # Switch the model using the LLMComponent
112
  llm_component = root_injector.get(LLMComponent)
113
+ llm_component.switch_model(new_model, settings=settings)
114
+
115
  # Return a success message
116
  return {"message": f"Model switched to {new_model}"}
117
 
118
 
 
119
  # Define a new APIRouter for the model_list
120
  model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
121
 
 
126
  """
127
  return models_data # Directly return the global variable
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def custom_openapi() -> dict[str, Any]:
130
  if app.openapi_schema:
131
  return app.openapi_schema