Spaces:
Sleeping
Sleeping
Update private_gpt/launcher.py
Browse files- private_gpt/launcher.py +25 -41
private_gpt/launcher.py
CHANGED
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
7 |
from fastapi.openapi.utils import get_openapi
|
8 |
from injector import Injector
|
9 |
from fastapi import APIRouter
|
10 |
-
|
11 |
|
12 |
from private_gpt.paths import docs_path
|
13 |
from private_gpt.server.chat.chat_router import chat_router
|
@@ -20,9 +20,6 @@ from private_gpt.server.utils import authentication
|
|
20 |
from private_gpt.settings.settings import Settings
|
21 |
from private_gpt.components.llm.llm_component import LLMComponent
|
22 |
|
23 |
-
#databse for user authentication integration
|
24 |
-
#from private_gpt.server.utils import models
|
25 |
-
#from private_gpt.server.utils.database import engine, SessionLocal
|
26 |
from typing import Annotated
|
27 |
from sqlalchemy.orm import Session
|
28 |
from private_gpt.server.utils.authentication import get_current_user
|
@@ -82,28 +79,43 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
82 |
request.state.injector = root_injector
|
83 |
|
84 |
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
85 |
-
|
86 |
-
|
87 |
model_router = APIRouter(prefix ='/v1',dependencies=[Depends(get_current_user)])
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
@model_router.post("/switch_model")
|
90 |
async def switch_model(
|
91 |
new_model: str, current_user: dict = Depends(get_current_user)
|
92 |
):
|
93 |
-
# Check if the user has
|
94 |
-
if "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
raise HTTPException(
|
96 |
status_code=status.HTTP_403_FORBIDDEN,
|
97 |
-
detail="You are not
|
98 |
)
|
99 |
|
100 |
-
#
|
101 |
-
|
102 |
llm_component = root_injector.get(LLMComponent)
|
103 |
llm_component.switch_model(new_model, settings=settings)
|
104 |
|
|
|
105 |
return {"message": f"Model switched to {new_model}"}
|
106 |
|
|
|
107 |
# Define a new APIRouter for the model_list
|
108 |
model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
|
109 |
|
@@ -112,24 +124,7 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
112 |
"""
|
113 |
Get a list of models with their details.
|
114 |
"""
|
115 |
-
#
|
116 |
-
models_data = [
|
117 |
-
{"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
|
118 |
-
{"id": 2, "name": "gpt-4", "access": ["admin"]},
|
119 |
-
# Add more models as needed
|
120 |
-
]
|
121 |
-
|
122 |
-
return models_data
|
123 |
-
|
124 |
-
# @model_router.post("/switch_model")
|
125 |
-
# async def switch_model(new_model: str):
|
126 |
-
# # Implement logic to switch the LLM model based on the user's request
|
127 |
-
# # Example: Change the LLM model in LLMComponent based on the new_model parameter
|
128 |
-
# llm_component = root_injector.get(LLMComponent)
|
129 |
-
# llm_component.switch_model(new_model, settings=settings)
|
130 |
-
|
131 |
-
# return {"message": f"Model switched to {new_model}"}
|
132 |
-
|
133 |
|
134 |
def custom_openapi() -> dict[str, Any]:
|
135 |
if app.openapi_schema:
|
@@ -160,17 +155,6 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
160 |
|
161 |
app.openapi = custom_openapi # type: ignore[method-assign]
|
162 |
|
163 |
-
#models.Base.metadata.create_all(bind=engine)
|
164 |
-
|
165 |
-
# def get_db():
|
166 |
-
# db = SessionLocal()
|
167 |
-
# try:
|
168 |
-
# yield db
|
169 |
-
# finally:
|
170 |
-
# db.close()
|
171 |
-
|
172 |
-
#db_dependency = Annotated[Session, Depends(get_db)]
|
173 |
-
#user_dependency = Annotated[dict, Depends(get_current_user)]
|
174 |
|
175 |
@app.get("/v1/me", status_code=status.HTTP_200_OK)
|
176 |
async def user(current_user: dict = Depends(get_current_user)):
|
@@ -209,4 +193,4 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
209 |
ui = root_injector.get(PrivateGptUi)
|
210 |
ui.mount_in_app(app, settings.ui.path)
|
211 |
|
212 |
-
return app
|
|
|
7 |
from fastapi.openapi.utils import get_openapi
|
8 |
from injector import Injector
|
9 |
from fastapi import APIRouter
|
10 |
+
|
11 |
|
12 |
from private_gpt.paths import docs_path
|
13 |
from private_gpt.server.chat.chat_router import chat_router
|
|
|
20 |
from private_gpt.settings.settings import Settings
|
21 |
from private_gpt.components.llm.llm_component import LLMComponent
|
22 |
|
|
|
|
|
|
|
23 |
from typing import Annotated
|
24 |
from sqlalchemy.orm import Session
|
25 |
from private_gpt.server.utils.authentication import get_current_user
|
|
|
79 |
request.state.injector = root_injector
|
80 |
|
81 |
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
|
|
|
|
82 |
model_router = APIRouter(prefix ='/v1',dependencies=[Depends(get_current_user)])
|
83 |
|
84 |
+
#global hardcoded dictionary for LLMs
|
85 |
+
models_data = [
|
86 |
+
{"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
|
87 |
+
{"id": 2, "name": "gpt-4", "access": ["admin"]},
|
88 |
+
{"id": 3, "name": "mistral-7B", "access": ["user", "admin"]},
|
89 |
+
|
90 |
+
]
|
91 |
+
|
92 |
@model_router.post("/switch_model")
|
93 |
async def switch_model(
|
94 |
new_model: str, current_user: dict = Depends(get_current_user)
|
95 |
):
|
96 |
+
# Check if the user has either "admin" or "user" role
|
97 |
+
if "user" not in current_user.get("role", []):
|
98 |
+
raise HTTPException(
|
99 |
+
status_code=status.HTTP_403_FORBIDDEN,
|
100 |
+
detail="You are not authorized to use this API.",
|
101 |
+
)
|
102 |
+
|
103 |
+
# Retrieve model information and validate access
|
104 |
+
model_info = next((m for m in models_data if m["name"] == new_model), None)
|
105 |
+
if not model_info or current_user.get("role") not in model_info["access"]:
|
106 |
raise HTTPException(
|
107 |
status_code=status.HTTP_403_FORBIDDEN,
|
108 |
+
detail="You are not authorized to access this model.",
|
109 |
)
|
110 |
|
111 |
+
# Switch the model using the LLMComponent
|
|
|
112 |
llm_component = root_injector.get(LLMComponent)
|
113 |
llm_component.switch_model(new_model, settings=settings)
|
114 |
|
115 |
+
# Return a success message
|
116 |
return {"message": f"Model switched to {new_model}"}
|
117 |
|
118 |
+
|
119 |
# Define a new APIRouter for the model_list
|
120 |
model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
|
121 |
|
|
|
124 |
"""
|
125 |
Get a list of models with their details.
|
126 |
"""
|
127 |
+
return models_data # Directly return the global variable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
def custom_openapi() -> dict[str, Any]:
|
130 |
if app.openapi_schema:
|
|
|
155 |
|
156 |
app.openapi = custom_openapi # type: ignore[method-assign]
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
@app.get("/v1/me", status_code=status.HTTP_200_OK)
|
160 |
async def user(current_user: dict = Depends(get_current_user)):
|
|
|
193 |
ui = root_injector.get(PrivateGptUi)
|
194 |
ui.mount_in_app(app, settings.ui.path)
|
195 |
|
196 |
+
return app
|