Ibraaheem commited on
Commit
ba65734
·
1 Parent(s): c5b591f

Update private_gpt/launcher.py

Browse files
Files changed (1) hide show
  1. private_gpt/launcher.py +25 -41
private_gpt/launcher.py CHANGED
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.openapi.utils import get_openapi
8
  from injector import Injector
9
  from fastapi import APIRouter
10
- #from fastapi import JSONResponse
11
 
12
  from private_gpt.paths import docs_path
13
  from private_gpt.server.chat.chat_router import chat_router
@@ -20,9 +20,6 @@ from private_gpt.server.utils import authentication
20
  from private_gpt.settings.settings import Settings
21
  from private_gpt.components.llm.llm_component import LLMComponent
22
 
23
- #databse for user authentication integration
24
- #from private_gpt.server.utils import models
25
- #from private_gpt.server.utils.database import engine, SessionLocal
26
  from typing import Annotated
27
  from sqlalchemy.orm import Session
28
  from private_gpt.server.utils.authentication import get_current_user
@@ -82,28 +79,43 @@ def create_app(root_injector: Injector) -> FastAPI:
82
  request.state.injector = root_injector
83
 
84
  app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
85
-
86
-
87
  model_router = APIRouter(prefix ='/v1',dependencies=[Depends(get_current_user)])
88
 
 
 
 
 
 
 
 
 
89
  @model_router.post("/switch_model")
90
  async def switch_model(
91
  new_model: str, current_user: dict = Depends(get_current_user)
92
  ):
93
- # Check if the user has the "admin" role
94
- if "admin" not in current_user.get("role", []):
 
 
 
 
 
 
 
 
95
  raise HTTPException(
96
  status_code=status.HTTP_403_FORBIDDEN,
97
- detail="You are not an admin and cannot use this API.",
98
  )
99
 
100
- #logic to switch the LLM model based on the user's request
101
-
102
  llm_component = root_injector.get(LLMComponent)
103
  llm_component.switch_model(new_model, settings=settings)
104
 
 
105
  return {"message": f"Model switched to {new_model}"}
106
 
 
107
  # Define a new APIRouter for the model_list
108
  model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
109
 
@@ -112,24 +124,7 @@ def create_app(root_injector: Injector) -> FastAPI:
112
  """
113
  Get a list of models with their details.
114
  """
115
- # In this example, hardcoding some sample model data
116
- models_data = [
117
- {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
118
- {"id": 2, "name": "gpt-4", "access": ["admin"]},
119
- # Add more models as needed
120
- ]
121
-
122
- return models_data
123
-
124
- # @model_router.post("/switch_model")
125
- # async def switch_model(new_model: str):
126
- # # Implement logic to switch the LLM model based on the user's request
127
- # # Example: Change the LLM model in LLMComponent based on the new_model parameter
128
- # llm_component = root_injector.get(LLMComponent)
129
- # llm_component.switch_model(new_model, settings=settings)
130
-
131
- # return {"message": f"Model switched to {new_model}"}
132
-
133
 
134
  def custom_openapi() -> dict[str, Any]:
135
  if app.openapi_schema:
@@ -160,17 +155,6 @@ def create_app(root_injector: Injector) -> FastAPI:
160
 
161
  app.openapi = custom_openapi # type: ignore[method-assign]
162
 
163
- #models.Base.metadata.create_all(bind=engine)
164
-
165
- # def get_db():
166
- # db = SessionLocal()
167
- # try:
168
- # yield db
169
- # finally:
170
- # db.close()
171
-
172
- #db_dependency = Annotated[Session, Depends(get_db)]
173
- #user_dependency = Annotated[dict, Depends(get_current_user)]
174
 
175
  @app.get("/v1/me", status_code=status.HTTP_200_OK)
176
  async def user(current_user: dict = Depends(get_current_user)):
@@ -209,4 +193,4 @@ def create_app(root_injector: Injector) -> FastAPI:
209
  ui = root_injector.get(PrivateGptUi)
210
  ui.mount_in_app(app, settings.ui.path)
211
 
212
- return app
 
7
  from fastapi.openapi.utils import get_openapi
8
  from injector import Injector
9
  from fastapi import APIRouter
10
+
11
 
12
  from private_gpt.paths import docs_path
13
  from private_gpt.server.chat.chat_router import chat_router
 
20
  from private_gpt.settings.settings import Settings
21
  from private_gpt.components.llm.llm_component import LLMComponent
22
 
 
 
 
23
  from typing import Annotated
24
  from sqlalchemy.orm import Session
25
  from private_gpt.server.utils.authentication import get_current_user
 
79
  request.state.injector = root_injector
80
 
81
  app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
 
 
82
  model_router = APIRouter(prefix ='/v1',dependencies=[Depends(get_current_user)])
83
 
84
+ #global hardcoded dictionary for LLMs
85
+ models_data = [
86
+ {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
87
+ {"id": 2, "name": "gpt-4", "access": ["admin"]},
88
+ {"id": 3, "name": "mistral-7B", "access": ["user", "admin"]},
89
+
90
+ ]
91
+
92
  @model_router.post("/switch_model")
93
  async def switch_model(
94
  new_model: str, current_user: dict = Depends(get_current_user)
95
  ):
96
+ # Check if the user has either "admin" or "user" role
97
+ if "user" not in current_user.get("role", []):
98
+ raise HTTPException(
99
+ status_code=status.HTTP_403_FORBIDDEN,
100
+ detail="You are not authorized to use this API.",
101
+ )
102
+
103
+ # Retrieve model information and validate access
104
+ model_info = next((m for m in models_data if m["name"] == new_model), None)
105
+ if not model_info or current_user.get("role") not in model_info["access"]:
106
  raise HTTPException(
107
  status_code=status.HTTP_403_FORBIDDEN,
108
+ detail="You are not authorized to access this model.",
109
  )
110
 
111
+ # Switch the model using the LLMComponent
 
112
  llm_component = root_injector.get(LLMComponent)
113
  llm_component.switch_model(new_model, settings=settings)
114
 
115
+ # Return a success message
116
  return {"message": f"Model switched to {new_model}"}
117
 
118
+
119
  # Define a new APIRouter for the model_list
120
  model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
121
 
 
124
  """
125
  Get a list of models with their details.
126
  """
127
+ return models_data # Directly return the global variable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  def custom_openapi() -> dict[str, Any]:
130
  if app.openapi_schema:
 
155
 
156
  app.openapi = custom_openapi # type: ignore[method-assign]
157
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  @app.get("/v1/me", status_code=status.HTTP_200_OK)
160
  async def user(current_user: dict = Depends(get_current_user)):
 
193
  ui = root_injector.get(PrivateGptUi)
194
  ui.mount_in_app(app, settings.ui.path)
195
 
196
+ return app