Spaces:
Sleeping
Sleeping
File size: 8,314 Bytes
bf6d237 fcb3282 cd15232 bf6d237 cd15232 ba65734 bf6d237 cd15232 bf6d237 cd15232 bf6d237 bd74fef bf6d237 cd15232 ba65734 cd15232 766dd24 cd15232 ba65734 5f65a30 ba65734 cd15232 ba65734 cd15232 5f65a30 ba65734 cd15232 5f65a30 766dd24 5f65a30 ba65734 cd15232 ba65734 2de46cd b64e185 2de46cd b64e185 2de46cd ba65734 bf6d237 cd15232 bf6d237 cd15232 bf6d237 cd15232 b64e185 cd15232 bf6d237 cd15232 2de46cd 28fc0f3 bf6d237 aeb24f5 ba65734 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any
from fastapi import Depends, FastAPI, Request, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector
from fastapi import APIRouter
from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.utils import authentication
from private_gpt.settings.settings import Settings
from private_gpt.components.llm.llm_component import LLMComponent
from typing import Annotated
from sqlalchemy.orm import Session
from private_gpt.server.utils.authentication import get_current_user
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2AuthorizationCodeBearer
logger = logging.getLogger(__name__)
def create_app(root_injector: Injector) -> FastAPI:
    """Build and fully configure the FastAPI application.

    Loads the API description from docs, registers every router
    (authentication, completions, chat, chunks, ingest, embeddings, health,
    model switching/listing), installs a cached custom OpenAPI schema, and
    conditionally enables CORS middleware and the Gradio UI.

    Args:
        root_injector: Dependency-injection container used to resolve
            ``Settings``, ``LLMComponent`` and (when enabled) the UI.

    Returns:
        The configured ``FastAPI`` instance, ready to be served.
    """
    # Long-form API description rendered in the generated docs page.
    with open(docs_path / "description.md") as description_file:
        description = description_file.read()

    tags_metadata = [
        {
            "name": "Ingestion",
            "description": "High-level APIs covering document ingestion -internally "
            "managing document parsing, splitting,"
            "metadata extraction, embedding generation and storage- and ingested "
            "documents CRUD."
            "Each ingested document is identified by an ID that can be used to filter the "
            "context"
            "used in *Contextual Completions* and *Context Chunks* APIs.",
        },
        {
            "name": "Contextual Completions",
            "description": "High-level APIs covering contextual Chat and Completions. They "
            "follow OpenAI's format, extending it to "
            "allow using the context coming from ingested documents to create the "
            "response. Internally"
            "manage context retrieval, prompt engineering and the response generation.",
        },
        {
            "name": "Context Chunks",
            "description": "Low-level API that given a query return relevant chunks of "
            "text coming from the ingested"
            "documents.",
        },
        {
            "name": "Embeddings",
            "description": "Low-level API to obtain the vector representation of a given "
            "text, using an Embeddings model."
            "Follows OpenAI's embeddings API format.",
        },
        {
            "name": "Health",
            "description": "Simple health API to make sure the server is up and running.",
        },
    ]

    # Resolve settings up-front: `switch_model` below closes over this name,
    # and CORS/UI setup at the bottom reads it as well.
    settings = root_injector.get(Settings)

    async def bind_injector_to_request(request: Request) -> None:
        # Expose the injector to every handler via request.state.injector.
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    model_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    # Hard-coded catalogue of available LLMs and the roles allowed to use each.
    models_data = [
        {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
        {"id": 2, "name": "gpt-4", "access": ["admin"]},
        {"id": 3, "name": "mistral-7B", "access": ["user", "admin"]},
    ]

    @model_router.post("/switch_model")
    async def switch_model(
        new_model: str, current_user: dict = Depends(get_current_user)
    ):
        """Switch the active LLM, enforcing per-model role-based access."""
        role = current_user.get("role")
        # Only known roles may use this API. NOTE: the previous check
        # (`"user" not in role`) was a substring test that wrongly rejected
        # admins, since "user" is not a substring of "admin".
        if role not in ("user", "admin"):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to use this API.",
            )
        # Look up the requested model and verify this role may access it.
        model_info = next((m for m in models_data if m["name"] == new_model), None)
        if not model_info or role not in model_info["access"]:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to access this model.",
            )
        # Delegate the actual switch to the LLM component.
        llm_component = root_injector.get(LLMComponent)
        llm_component.switch_model(new_model, settings=settings)
        return {"message": f"Model switched to {new_model}"}

    # Separate router for listing the available models.
    model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_list_router.get("/models_list", response_model=list[dict])
    async def model_list(current_user: dict = Depends(get_current_user)):
        """Get a list of models with their details."""
        return models_data

    def custom_openapi() -> dict[str, Any]:
        # Build the OpenAPI schema once and cache it on the app instance.
        if app.openapi_schema:
            return app.openapi_schema
        openapi_schema = get_openapi(
            title="Invenxion-Chatbot",
            description=description,
            version="0.1.0",
            summary="This is a production-ready AI project that allows you to "
            "ask questions to your documents using the power of Large Language "
            "Models (LLMs), even in scenarios without Internet connection. "
            "100% private, no data leaves your execution environment at any point.",
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            routes=app.routes,
            tags=tags_metadata,
        )
        openapi_schema["info"]["x-logo"] = {
            "url": "https://lh3.googleusercontent.com/drive-viewer"
            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
        }
        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore[method-assign]

    @app.get("/v1/me", status_code=status.HTTP_200_OK)
    async def user(current_user: dict = Depends(get_current_user)):
        """Return the authenticated user's details."""
        if current_user is None:
            raise HTTPException(status_code=401, detail="Authentication Failed")
        return {"User": current_user}

    # Register all application routers.
    app.include_router(authentication.router)
    app.include_router(completions_router)
    app.include_router(chat_router)
    app.include_router(chunks_router)
    app.include_router(ingest_router)
    app.include_router(embeddings_router)
    app.include_router(health_router)
    app.include_router(model_router)
    app.include_router(model_list_router)

    if settings.server.cors.enabled:
        logger.debug("Setting up CORS middleware")
        app.add_middleware(
            CORSMiddleware,
            allow_credentials=settings.server.cors.allow_credentials,
            allow_origins=settings.server.cors.allow_origins,
            allow_origin_regex=settings.server.cors.allow_origin_regex,
            allow_methods=settings.server.cors.allow_methods,
            allow_headers=settings.server.cors.allow_headers,
        )

    if settings.ui.enabled:
        logger.debug("Importing the UI module")
        # Imported lazily so the UI stack is only loaded when enabled.
        from private_gpt.ui.ui import PrivateGptUi

        ui = root_injector.get(PrivateGptUi)
        ui.mount_in_app(app, settings.ui.path)

    return app