File size: 8,314 Bytes
bf6d237
 
 
 
fcb3282
cd15232
bf6d237
 
 
cd15232
ba65734
bf6d237
 
 
 
 
 
 
 
cd15232
bf6d237
cd15232
 
 
 
 
 
 
 
 
 
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd74fef
 
 
 
bf6d237
 
 
 
cd15232
 
ba65734
 
 
 
 
 
 
 
cd15232
 
766dd24
cd15232
ba65734
 
 
 
 
 
5f65a30
ba65734
 
 
cd15232
 
ba65734
cd15232
5f65a30
ba65734
cd15232
5f65a30
766dd24
5f65a30
ba65734
cd15232
 
ba65734
2de46cd
b64e185
2de46cd
b64e185
2de46cd
 
 
 
ba65734
bf6d237
 
 
 
 
cd15232
bf6d237
 
cd15232
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd15232
b64e185
cd15232
 
 
 
 
 
 
 
bf6d237
 
 
 
 
 
cd15232
2de46cd
28fc0f3
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeb24f5
ba65734
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any


from fastapi import Depends, FastAPI, Request, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector
from fastapi import APIRouter


from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.utils import authentication
from private_gpt.settings.settings import Settings
from private_gpt.components.llm.llm_component import LLMComponent

from typing import Annotated
from sqlalchemy.orm import Session
from private_gpt.server.utils.authentication import get_current_user

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2AuthorizationCodeBearer



logger = logging.getLogger(__name__)


# OpenAPI tag descriptions shown in the generated docs. Python concatenates
# adjacent string literals, so each fragment carries the space that
# separates it from the next one.
_TAGS_METADATA: list[dict[str, str]] = [
    {
        "name": "Ingestion",
        "description": "High-level APIs covering document ingestion -internally "
        "managing document parsing, splitting, "
        "metadata extraction, embedding generation and storage- and ingested "
        "documents CRUD. "
        "Each ingested document is identified by an ID that can be used to filter the "
        "context "
        "used in *Contextual Completions* and *Context Chunks* APIs.",
    },
    {
        "name": "Contextual Completions",
        "description": "High-level APIs covering contextual Chat and Completions. They "
        "follow OpenAI's format, extending it to "
        "allow using the context coming from ingested documents to create the "
        "response. Internally "
        "manage context retrieval, prompt engineering and the response generation.",
    },
    {
        "name": "Context Chunks",
        "description": "Low-level API that given a query return relevant chunks of "
        "text coming from the ingested "
        "documents.",
    },
    {
        "name": "Embeddings",
        "description": "Low-level API to obtain the vector representation of a given "
        "text, using an Embeddings model. "
        "Follows OpenAI's embeddings API format.",
    },
    {
        "name": "Health",
        "description": "Simple health API to make sure the server is up and running.",
    },
]

# Roles allowed to call the model-switching API at all. Per-model access is
# further restricted by each model's "access" list inside create_app.
_API_ROLES = frozenset({"user", "admin"})


def _user_roles(current_user: dict[str, Any]) -> set[str]:
    """Return the role names carried by an authenticated user.

    The shape of the ``role`` claim comes from ``get_current_user``, which is
    defined elsewhere. The previous code treated it both as a list (``.get("role",
    [])``) and as a single string (``role not in access``), so both shapes are
    accepted here. Anything else yields no roles.
    NOTE(review): confirm the claim shape against authentication.get_current_user.
    """
    role = current_user.get("role")
    if isinstance(role, str):
        return {role}
    if isinstance(role, (list, tuple, set, frozenset)):
        return set(role)
    return set()


def create_app(root_injector: Injector) -> FastAPI:
    """Build the FastAPI application with its routers, OpenAPI schema and middleware.

    Args:
        root_injector: Dependency-injection container. It is attached to every
            request as ``request.state.injector`` and used here to resolve
            ``Settings``, ``LLMComponent`` and, when enabled, the UI.

    Returns:
        The fully configured ``FastAPI`` instance.
    """
    # Read the long-form API description and release the file handle right
    # away, instead of keeping it open while the whole app is assembled.
    with open(docs_path / "description.md", encoding="utf-8") as description_file:
        description = description_file.read()

    # Resolved up front because the switch_model endpoint below depends on it.
    settings = root_injector.get(Settings)

    async def bind_injector_to_request(request: Request) -> None:
        # Expose the container to request-scoped dependencies.
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    # Hardcoded catalogue of selectable LLMs and the roles allowed to use each.
    models_data = [
        {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
        {"id": 2, "name": "gpt-4", "access": ["admin"]},
        {"id": 3, "name": "mistral-7B", "access": ["user", "admin"]},
    ]
    models_by_name = {model["name"]: model for model in models_data}

    # All routes on this router require an authenticated user.
    model_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_router.post("/switch_model")
    async def switch_model(
        new_model: str, current_user: dict = Depends(get_current_user)
    ):
        """Switch the active LLM to `new_model`.

        Requires the "user" or "admin" role, plus access to the requested model.
        """
        if current_user is None:
            # Same response as /v1/me when no user could be resolved.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication Failed",
            )

        roles = _user_roles(current_user)

        # The caller must hold at least one of the "user"/"admin" roles. The
        # previous substring test ("user" in role) rejected admins outright.
        if roles.isdisjoint(_API_ROLES):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to use this API.",
            )

        # Unknown models are rejected with the same 403 as forbidden ones.
        model_info = models_by_name.get(new_model)
        if model_info is None or roles.isdisjoint(model_info["access"]):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not authorized to access this model.",
            )

        llm_component = root_injector.get(LLMComponent)
        llm_component.switch_model(new_model, settings=settings)

        return {"message": f"Model switched to {new_model}"}

    # Separate authenticated router exposing the model catalogue.
    model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_list_router.get("/models_list", response_model=list[dict])
    async def model_list(current_user: dict = Depends(get_current_user)):
        """
        Get a list of models with their details.
        """
        return models_data

    def custom_openapi() -> dict[str, Any]:
        # Generate the schema once and cache it on the app.
        if app.openapi_schema:
            return app.openapi_schema
        openapi_schema = get_openapi(
            title="Invenxion-Chatbot",
            description=description,
            version="0.1.0",
            summary="This is a production-ready AI project that allows you to "
            "ask questions to your documents using the power of Large Language "
            "Models (LLMs), even in scenarios without Internet connection. "
            "100% private, no data leaves your execution environment at any point.",
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            routes=app.routes,
            tags=_TAGS_METADATA,
        )
        openapi_schema["info"]["x-logo"] = {
            "url": "https://lh3.googleusercontent.com/drive-viewer"
            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
        }

        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore[method-assign]

    @app.get("/v1/me", status_code=status.HTTP_200_OK)
    async def user(current_user: dict = Depends(get_current_user)):
        # Echo back the authenticated user, or 401 if none was resolved.
        if current_user is None:
            raise HTTPException(status_code=401, detail="Authentication Failed")
        return {"User": current_user}

    app.include_router(authentication.router)

    app.include_router(completions_router)
    app.include_router(chat_router)
    app.include_router(chunks_router)
    app.include_router(ingest_router)
    app.include_router(embeddings_router)
    app.include_router(health_router)
    app.include_router(model_router)
    app.include_router(model_list_router)

    if settings.server.cors.enabled:
        logger.debug("Setting up CORS middleware")
        app.add_middleware(
            CORSMiddleware,
            allow_credentials=settings.server.cors.allow_credentials,
            allow_origins=settings.server.cors.allow_origins,
            allow_origin_regex=settings.server.cors.allow_origin_regex,
            allow_methods=settings.server.cors.allow_methods,
            allow_headers=settings.server.cors.allow_headers,
        )

    if settings.ui.enabled:
        logger.debug("Importing the UI module")
        # Imported lazily so the UI dependencies are only loaded when enabled.
        from private_gpt.ui.ui import PrivateGptUi

        ui = root_injector.get(PrivateGptUi)
        ui.mount_in_app(app, settings.ui.path)

    return app