File size: 8,752 Bytes
bf6d237
 
 
 
cd15232
bf6d237
 
 
cd15232
8d6a4d0
bf6d237
 
 
 
 
 
 
 
cd15232
bf6d237
cd15232
 
 
987aa8b
cd15232
 
 
 
 
 
 
 
 
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd15232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de46cd
b64e185
2de46cd
b64e185
2de46cd
 
 
 
 
 
 
 
 
 
 
36ff373
2de46cd
cd15232
 
 
 
 
 
 
 
 
bf6d237
 
 
 
 
cd15232
bf6d237
 
cd15232
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd15232
 
 
 
 
 
 
 
 
 
 
 
b64e185
cd15232
 
 
 
 
 
 
 
bf6d237
 
 
 
 
 
cd15232
2de46cd
28fc0f3
bf6d237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeb24f5
bf6d237
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any

from fastapi import Depends, FastAPI, Request, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector
from fastapi import APIRouter
#from fastapi import JSONResponse

from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.utils import authentication
from private_gpt.settings.settings import Settings
from private_gpt.components.llm.llm_component import LLMComponent

# Database for user authentication integration (currently disabled; imports below are commented out)
#from private_gpt.server.utils import models 
#from private_gpt.server.utils.database import engine, SessionLocal
from typing import Annotated
from sqlalchemy.orm import Session
from private_gpt.server.utils.authentication import get_current_user

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2AuthorizationCodeBearer



# Module-level logger, named after this module so its output can be filtered
# and configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)


def _build_tags_metadata() -> list[dict[str, str]]:
    """Return the OpenAPI tag metadata used to group routes in the API docs.

    Adjacent string literals are joined by Python with no separator, so each
    fragment that continues a sentence ends with an explicit space.
    """
    return [
        {
            "name": "Ingestion",
            "description": "High-level APIs covering document ingestion -internally "
            "managing document parsing, splitting, "
            "metadata extraction, embedding generation and storage- and ingested "
            "documents CRUD. "
            "Each ingested document is identified by an ID that can be used to filter the "
            "context "
            "used in *Contextual Completions* and *Context Chunks* APIs.",
        },
        {
            "name": "Contextual Completions",
            "description": "High-level APIs covering contextual Chat and Completions. They "
            "follow OpenAI's format, extending it to "
            "allow using the context coming from ingested documents to create the "
            "response. Internally "
            "manage context retrieval, prompt engineering and the response generation.",
        },
        {
            "name": "Context Chunks",
            "description": "Low-level API that given a query return relevant chunks of "
            "text coming from the ingested "
            "documents.",
        },
        {
            "name": "Embeddings",
            "description": "Low-level API to obtain the vector representation of a given "
            "text, using an Embeddings model. "
            "Follows OpenAI's embeddings API format.",
        },
        {
            "name": "Health",
            "description": "Simple health API to make sure the server is up and running.",
        },
    ]


def _has_admin_role(current_user: Any) -> bool:
    """Return True when the user's ``role`` entry grants the "admin" role.

    ``role`` may be a single string or a collection of role strings. A bare
    string is compared exactly: a plain ``in`` test on a ``str`` is a
    substring match and would wrongly accept values like "superadmin".
    """
    if not current_user:
        return False
    roles = current_user.get("role", [])
    if not roles:
        return False
    if isinstance(roles, str):
        return roles == "admin"
    return "admin" in roles


def _build_model_routers(
    root_injector: Injector, settings: Settings
) -> tuple[APIRouter, APIRouter]:
    """Build the authenticated ``/v1`` routers for switching and listing models.

    Args:
        root_injector: Container used to resolve ``LLMComponent`` per request.
        settings: Application settings forwarded to ``LLMComponent.switch_model``.

    Returns:
        ``(model_router, model_list_router)``, both requiring ``get_current_user``.
    """
    model_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_router.post("/switch_model")
    async def switch_model(
        new_model: str, current_user: dict = Depends(get_current_user)
    ):
        """Switch the active LLM to ``new_model``. Restricted to admin users."""
        # get_current_user may yield None; answer 401 rather than crash with 500.
        if current_user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication Failed",
            )
        if not _has_admin_role(current_user):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="You are not an admin and cannot use this API.",
            )

        llm_component = root_injector.get(LLMComponent)
        llm_component.switch_model(new_model, settings=settings)

        return {"message": f"Model switched to {new_model}"}

    model_list_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

    @model_list_router.get("/models_list", response_model=list[dict])
    async def model_list(current_user: dict = Depends(get_current_user)):
        """
        Get a list of models with their details.
        """
        # NOTE(review): sample data is hard-coded; replace with a real model registry.
        models_data = [
            {"id": 1, "name": "gpt-3.5-turbo", "access": ["user", "admin"]},
            {"id": 2, "name": "gpt-4", "access": ["admin"]},
            # Add more models as needed
        ]

        return models_data

    return model_router, model_list_router


def create_app(root_injector: Injector) -> FastAPI:
    """Build the FastAPI application with its routers, OpenAPI docs, CORS and UI.

    Args:
        root_injector: Dependency-injection container. It is bound to every
            request as ``request.state.injector`` and is used here to resolve
            ``Settings``, ``LLMComponent`` and, when enabled, the UI.

    Returns:
        The fully configured ``FastAPI`` instance.
    """
    # Read the long-form API description up front so the file handle is
    # closed immediately instead of staying open during the whole setup.
    with open(docs_path / "description.md", encoding="utf-8") as description_file:
        description = description_file.read()

    tags_metadata = _build_tags_metadata()
    # Resolved before the routes are built because switch_model needs it.
    settings = root_injector.get(Settings)

    async def bind_injector_to_request(request: Request) -> None:
        # Expose the root injector to route handlers via request.state.
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    model_router, model_list_router = _build_model_routers(root_injector, settings)

    def custom_openapi() -> dict[str, Any]:
        # Build the schema lazily once, then serve the cached copy.
        if app.openapi_schema:
            return app.openapi_schema
        openapi_schema = get_openapi(
            title="Invenxion-Chatbot",
            description=description,
            version="0.1.0",
            summary="This is a production-ready AI project that allows you to "
            "ask questions to your documents using the power of Large Language "
            "Models (LLMs), even in scenarios without Internet connection. "
            "100% private, no data leaves your execution environment at any point.",
            license_info={
                "name": "Apache 2.0",
                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
            },
            routes=app.routes,
            tags=tags_metadata,
        )
        openapi_schema["info"]["x-logo"] = {
            "url": "https://lh3.googleusercontent.com/drive-viewer"
            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
        }

        app.openapi_schema = openapi_schema
        return app.openapi_schema

    app.openapi = custom_openapi  # type: ignore[method-assign]

    # TODO: database-backed user storage (models.Base.metadata.create_all,
    # SessionLocal-based get_db dependency) is not wired in yet.

    @app.get("/v1/me", status_code=status.HTTP_200_OK)
    async def user(current_user: dict = Depends(get_current_user)):
        if current_user is None:
            raise HTTPException(status_code=401, detail="Authentication Failed")
        return {"User": current_user}

    # Registration order is preserved from the original wiring.
    app.include_router(authentication.router)

    app.include_router(completions_router)
    app.include_router(chat_router)
    app.include_router(chunks_router)
    app.include_router(ingest_router)
    app.include_router(embeddings_router)
    app.include_router(health_router)
    app.include_router(model_router)
    app.include_router(model_list_router)

    if settings.server.cors.enabled:
        logger.debug("Setting up CORS middleware")
        app.add_middleware(
            CORSMiddleware,
            allow_credentials=settings.server.cors.allow_credentials,
            allow_origins=settings.server.cors.allow_origins,
            allow_origin_regex=settings.server.cors.allow_origin_regex,
            allow_methods=settings.server.cors.allow_methods,
            allow_headers=settings.server.cors.allow_headers,
        )

    if settings.ui.enabled:
        logger.debug("Importing the UI module")
        # Imported lazily so the UI dependencies load only when the UI is enabled.
        from private_gpt.ui.ui import PrivateGptUi

        ui = root_injector.get(PrivateGptUi)
        ui.mount_in_app(app, settings.ui.path)

    return app