"""FastAPI app creation, logger configuration and main API routes."""

import logging
from typing import Annotated, Any

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.security import OAuth2AuthorizationCodeBearer
from injector import Injector
from sqlalchemy.orm import Session

from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.utils import authentication
from private_gpt.settings.settings import Settings
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.server.utils.authentication import get_current_user

logger = logging.getLogger(__name__)

# Roles that may call the model-switching endpoint at all; which models each
# role may actually select is further restricted by the per-model "access" list.
_SWITCH_MODEL_ROLES = frozenset({"user", "admin"})


def _user_roles(current_user: dict[str, Any]) -> frozenset[str]:
    """Return the caller's roles as a set.

    The shape of ``current_user["role"]`` is decided by ``get_current_user``,
    which is not visible here, so it is handled defensively: a single role
    string, an iterable of role strings, or a missing/empty value (no roles).
    """
    role = current_user.get("role") if current_user else None
    if not role:
        return frozenset()
    if isinstance(role, str):
        # A plain string is ONE role. Testing `"user" in role` directly would
        # be a substring match ("superuser" passes, "admin" fails).
        return frozenset({role})
    return frozenset(role)


def create_app(root_injector: Injector) -> FastAPI:
    """Build the FastAPI application.

    Registers the API routers, the model-selection endpoints, a customised
    OpenAPI schema, optional CORS middleware and the optional UI.

    Args:
        root_injector: Dependency-injection container used to resolve
            ``Settings``, ``LLMComponent`` and (when enabled) the UI. It is
            also attached to every request as ``request.state.injector``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    # Resolve settings up front: route handlers below (switch_model) close over
    # this name, so it must be bound before any of them can run.
    settings = root_injector.get(Settings)

    # Start the API
    with open(docs_path / "description.md", encoding="utf-8") as description_file:
        description = description_file.read()

    tags_metadata = [
        {
            "name": "Ingestion",
            "description": "High-level APIs covering document ingestion -internally "
            "managing document parsing, splitting, "
            "metadata extraction, embedding generation and storage- and ingested "
            "documents CRUD. "
            "Each ingested document is identified by an ID that can be used to filter the "
            "context "
            "used in *Contextual Completions* and *Context Chunks* APIs.",
        },
        {
            "name": "Contextual Completions",
            "description": "High-level APIs covering contextual Chat and Completions. They "
            "follow OpenAI's format, extending it to "
            "allow using the context coming from ingested documents to create the "
            "response. Internally "
            "manage context retrieval, prompt engineering and the response generation.",
        },
        {
            "name": "Context Chunks",
            "description": "Low-level API that given a query returns relevant chunks of "
            "text coming from the ingested "
            "documents.",
        },
        {
            "name": "Embeddings",
            "description": "Low-level API to obtain the vector representation of a given "
            "text, using an Embeddings model. "
            "Follows OpenAI's embeddings API format.",
        },
        {
            "name": "Health",
            "description": "Simple health API to make sure the server is up and running.",
        },
    ]

    async def bind_injector_to_request(request: Request) -> None:
        """Expose the root injector to handlers via ``request.state.injector``."""
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    # Hard-coded catalogue of selectable LLMs (local to this app instance).
    # "access" lists the roles allowed to switch to each model.
    models_data = [
        {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
        {"id": 2, "name": "gpt-4", "access": ["admin"]},
        {"id": 3, "name": "mistral-7B", "access": ["user", "admin"]},
    ]

    model_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_router.post("/switch_model")
    async def switch_model(
        new_model: str, current_user: dict = Depends(get_current_user)
    ):
        """Switch the active LLM to ``new_model`` if the caller may use it.

        Raises:
            HTTPException: 403 if the caller has neither the "user" nor the
                "admin" role, or if ``new_model`` is unknown or not permitted
                for any of the caller's roles.
        """
        roles = _user_roles(current_user)

        # The caller must hold at least one of the roles allowed to switch models.
        if roles.isdisjoint(_SWITCH_MODEL_ROLES):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to use this API.",
            )

        # Unknown models are reported as 403 (not 404), as before, so the
        # endpoint does not reveal which model names exist.
        model_info = next((m for m in models_data if m["name"] == new_model), None)
        if model_info is None or roles.isdisjoint(model_info["access"]):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to access this model.",
            )

        # NOTE(review): if LLMComponent.switch_model loads weights synchronously,
        # it blocks the event loop inside this async handler; consider a plain
        # `def` handler or run_in_threadpool — confirm against LLMComponent.
        llm_component = root_injector.get(LLMComponent)
        llm_component.switch_model(new_model, settings=settings)

        return {"message": f"Model switched to {new_model}"}

    # Separate router for listing the model catalogue.
    model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_list_router.get("/models_list", response_model=list[dict])
    async def model_list(current_user: dict = Depends(get_current_user)):
        """
        Get a list of models with their details.
        """
        # The full catalogue is returned regardless of the caller's role.
        return models_data

    def custom_openapi() -> dict[str, Any]:
        """Build the OpenAPI schema once, then serve the cached copy."""
        if app.openapi_schema:
            return app.openapi_schema
        openapi_schema = get_openapi(
            title="Invenxion-Chatbot",
            description=description,
            version="0.1.0",
            summary="This is a production-ready AI project that allows you to "
            "ask questions to your documents using the power of Large Language "
            "Models (LLMs), even in scenarios without Internet connection. "
            "100% private, no data leaves your execution environment at any point.",
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            routes=app.routes,
            tags=tags_metadata,
        )
        openapi_schema["info"]["x-logo"] = {
            "url": "https://lh3.googleusercontent.com/drive-viewer"
            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
        }

        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore[method-assign]

    @app.get("/v1/me", status_code=status.HTTP_200_OK)
    async def user(current_user: dict = Depends(get_current_user)):
        """Return the authenticated caller's user record."""
        if current_user is None:
            raise HTTPException(status_code=401, detail="Authentication Failed")
        return {"User": current_user}

    app.include_router(authentication.router)
    app.include_router(completions_router)
    app.include_router(chat_router)
    app.include_router(chunks_router)
    app.include_router(ingest_router)
    app.include_router(embeddings_router)
    app.include_router(health_router)
    app.include_router(model_router)
    app.include_router(model_list_router)

    if settings.server.cors.enabled:
        logger.debug("Setting up CORS middleware")
        app.add_middleware(
            CORSMiddleware,
            allow_credentials=settings.server.cors.allow_credentials,
            allow_origins=settings.server.cors.allow_origins,
            allow_origin_regex=settings.server.cors.allow_origin_regex,
            allow_methods=settings.server.cors.allow_methods,
            allow_headers=settings.server.cors.allow_headers,
        )

    if settings.ui.enabled:
        logger.debug("Importing the UI module")
        # Imported lazily so the UI dependencies are only loaded when enabled.
        from private_gpt.ui.ui import PrivateGptUi

        ui = root_injector.get(PrivateGptUi)
        ui.mount_in_app(app, settings.ui.path)

    return app