import logging

from fastapi import APIRouter, Depends, HTTPException, Request
from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
    OpenAICompletion,
    OpenAIMessage,
    to_openai_response,
    to_openai_sse_stream,
)
from private_gpt.server.chat.chat_service import ChatService
# from private_gpt.server.utils.auth import authenticated
from private_gpt.server.utils.authentication import get_current_user

chat_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])
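# NOTE: every route on this router resolves the get_current_user dependency, so
# requests must carry valid credentials; the exact scheme is whatever
# private_gpt.server.utils.authentication implements.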


class ChatBody(BaseModel):
    messages: list[OpenAIMessage]
    use_context: bool = False
    context_filter: ContextFilter | None = None
    include_sources: bool = True
    stream: bool = False

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a rapper. Always answer with a rap.",
                        },
                        {
                            "role": "user",
                            "content": "How do you fry an egg?",
                        },
                    ],
                    "stream": False,
                    "use_context": True,
                    "include_sources": True,
                    "context_filter": {
                        "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
                    },
                }
            ]
        }
    }
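
# Example request to the endpoint below, matching the schema example above.
# This is a sketch only: the host, port, and Authorization header are
# assumptions, not taken from this repository's configuration.
#
#   curl -X POST http://localhost:8001/v1/chat/completions \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{
#           "messages": [{"role": "user", "content": "How do you fry an egg?"}],
#           "use_context": true,
#           "include_sources": true,
#           "stream": false
#         }'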


logger = logging.getLogger(__name__)


@chat_router.post(
    "/chat/completions",
    response_model=None,
    responses={200: {"model": OpenAICompletion}},
    tags=["Contextual Completions"],
)
def chat_completion(
    request: Request, body: ChatBody
) -> OpenAICompletion | StreamingResponse:
    """Given a list of messages comprising a conversation, return a response.

    Optionally include an initial `role: system` message to influence the way
    the LLM answers.

    If `use_context` is set to `true`, the model will use context coming from
    the ingested documents to create the response. The documents being used can
    be filtered by passing their IDs in `context_filter`; the IDs of ingested
    documents can be found using the `/ingest/list` endpoint. If you want all
    ingested documents to be used, omit `context_filter` altogether.

    When using `'include_sources': true`, the API will return the source Chunks
    used to create the response, which come from the context provided.

    When using `'stream': true`, the API will return data chunks following
    [OpenAI's streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
    {"id":"12345","object":"completion.chunk","created":1694268190,
    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
    "finish_reason":null}]}
    ```
    """
    try:
        logger.info("Received chat completion request with body: %s", body.json())
        service = request.state.injector.get(ChatService)
        all_messages = [
            ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
        ]
        logger.info("Constructed all_messages: %s", all_messages)
        if body.stream:
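            # Streaming path: send the completion back as Server-Sent Events,
            # one OpenAI-style delta chunk per event.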
            completion_gen = service.stream_chat(
                messages=all_messages,
                use_context=body.use_context,
                context_filter=body.context_filter,
            )
            logger.info("Streaming response initialized")
            return StreamingResponse(
                to_openai_sse_stream(
                    completion_gen.response,
                    completion_gen.sources if body.include_sources else None,
                ),
                media_type="text/event-stream",
            )
        else:
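            # Non-streaming path: wait for the full completion and return a
            # single OpenAI-style response object.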
            completion = service.chat(
                messages=all_messages,
                use_context=body.use_context,
                context_filter=body.context_filter,
            )
            logger.info("Completed chat request: %s", completion.response)
            return to_openai_response(
                completion.response, completion.sources if body.include_sources else None
            )
    except Exception as e:
        logger.error("Error processing chat completion: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error") from e