Spaces:
Sleeping
Sleeping
Update private_gpt/launcher.py
Browse files- private_gpt/launcher.py +5 -46
private_gpt/launcher.py
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
import logging
|
3 |
from typing import Any
|
4 |
|
5 |
-
|
6 |
from fastapi import Depends, FastAPI, Request, status, HTTPException
|
7 |
from fastapi.middleware.cors import CORSMiddleware
|
8 |
from fastapi.openapi.utils import get_openapi
|
@@ -27,10 +26,6 @@ from private_gpt.server.utils.authentication import get_current_user
|
|
27 |
|
28 |
from fastapi import Depends, HTTPException, status
|
29 |
from fastapi.security import OAuth2AuthorizationCodeBearer
|
30 |
-
from fastapi import FastAPI, Body, HTTPException
|
31 |
-
from pydantic import BaseModel, validator
|
32 |
-
import yaml
|
33 |
-
|
34 |
|
35 |
|
36 |
|
@@ -96,7 +91,7 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
96 |
|
97 |
@model_router.post("/switch_model")
|
98 |
async def switch_model(
|
99 |
-
new_model: str,
|
100 |
):
|
101 |
# Check if the user has either "admin" or "user" role
|
102 |
if "user" not in current_user.get("role", []):
|
@@ -104,7 +99,7 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
104 |
status_code=status.HTTP_403_FORBIDDEN,
|
105 |
detail="You are not authorized to use this API.",
|
106 |
)
|
107 |
-
|
108 |
# Retrieve model information and validate access
|
109 |
model_info = next((m for m in models_data if m["name"] == new_model), None)
|
110 |
if not model_info or current_user.get("role") not in model_info["access"]:
|
@@ -112,16 +107,15 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
112 |
status_code=status.HTTP_403_FORBIDDEN,
|
113 |
detail="You are not authorized to access this model.",
|
114 |
)
|
115 |
-
|
116 |
# Switch the model using the LLMComponent
|
117 |
llm_component = root_injector.get(LLMComponent)
|
118 |
-
llm_component.switch_model(new_model, settings)
|
119 |
-
|
120 |
# Return a success message
|
121 |
return {"message": f"Model switched to {new_model}"}
|
122 |
|
123 |
|
124 |
-
|
125 |
# Define a new APIRouter for the model_list
|
126 |
model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
|
127 |
|
@@ -132,41 +126,6 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|
132 |
"""
|
133 |
return models_data # Directly return the global variable
|
134 |
|
135 |
-
# class NewLLMMode(BaseModel):
|
136 |
-
# mode: str
|
137 |
-
|
138 |
-
# @validator("mode")
|
139 |
-
# def validate_mode(cls, v):
|
140 |
-
# valid_modes = ["local", "openai", "sagemaker", "mock"] # Adjust as needed
|
141 |
-
# if v not in valid_modes:
|
142 |
-
# raise ValueError("Invalid LLM mode")
|
143 |
-
# return v
|
144 |
-
|
145 |
-
# @app.post("/settings/llm_mode")
|
146 |
-
# async def update_llm_mode(new_mode: NewLLMMode = Body(...)):
|
147 |
-
# try:
|
148 |
-
# # Load settings
|
149 |
-
# with open("settings.yaml", "r") as f:
|
150 |
-
# settings = yaml.safe_load(f)
|
151 |
-
|
152 |
-
# # Update llm.mode
|
153 |
-
# settings["llm"]["mode"] = new_mode.mode
|
154 |
-
|
155 |
-
# # Persist changes
|
156 |
-
# with open("settings.yaml", "w") as f:
|
157 |
-
# yaml.dump(settings, f)
|
158 |
-
|
159 |
-
# return {"message": "LLM mode updated successfully"}
|
160 |
-
|
161 |
-
# except Exception as e:
|
162 |
-
# raise HTTPException(status_code=500, detail=f"Failed to update LLM mode: {e}")
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
# Configuration module:
|
167 |
-
allowed_modes = ["local", "openai", "sagemaker", "mock"] # Load from external source
|
168 |
-
|
169 |
-
|
170 |
def custom_openapi() -> dict[str, Any]:
|
171 |
if app.openapi_schema:
|
172 |
return app.openapi_schema
|
|
|
2 |
import logging
|
3 |
from typing import Any
|
4 |
|
|
|
5 |
from fastapi import Depends, FastAPI, Request, status, HTTPException
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
from fastapi.openapi.utils import get_openapi
|
|
|
26 |
|
27 |
from fastapi import Depends, HTTPException, status
|
28 |
from fastapi.security import OAuth2AuthorizationCodeBearer
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
|
|
|
91 |
|
92 |
@model_router.post("/switch_model")
|
93 |
async def switch_model(
|
94 |
+
new_model: str, current_user: dict = Depends(get_current_user)
|
95 |
):
|
96 |
# Check if the user has either "admin" or "user" role
|
97 |
if "user" not in current_user.get("role", []):
|
|
|
99 |
status_code=status.HTTP_403_FORBIDDEN,
|
100 |
detail="You are not authorized to use this API.",
|
101 |
)
|
102 |
+
|
103 |
# Retrieve model information and validate access
|
104 |
model_info = next((m for m in models_data if m["name"] == new_model), None)
|
105 |
if not model_info or current_user.get("role") not in model_info["access"]:
|
|
|
107 |
status_code=status.HTTP_403_FORBIDDEN,
|
108 |
detail="You are not authorized to access this model.",
|
109 |
)
|
110 |
+
|
111 |
# Switch the model using the LLMComponent
|
112 |
llm_component = root_injector.get(LLMComponent)
|
113 |
+
llm_component.switch_model(new_model, settings=settings)
|
114 |
+
|
115 |
# Return a success message
|
116 |
return {"message": f"Model switched to {new_model}"}
|
117 |
|
118 |
|
|
|
119 |
# Define a new APIRouter for the model_list
|
120 |
model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
|
121 |
|
|
|
126 |
"""
|
127 |
return models_data # Directly return the global variable
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
def custom_openapi() -> dict[str, Any]:
|
130 |
if app.openapi_schema:
|
131 |
return app.openapi_schema
|