import logging

from fastapi import APIRouter, Depends, HTTPException, Request
from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
    OpenAICompletion,
    OpenAIMessage,
    to_openai_response,
    to_openai_sse_stream,
)
from private_gpt.server.chat.chat_service import ChatService

# Default `authenticated` dependency replaced by a custom `get_current_user`.
# from private_gpt.server.utils.auth import authenticated
from private_gpt.server.utils.authentication import get_current_user

chat_router = APIRouter(prefix="/v1", dependencies=[Depends(get_current_user)])

class ChatBody(BaseModel):
    messages: list[OpenAIMessage]
    use_context: bool = False
    context_filter: ContextFilter | None = None
    include_sources: bool = True
    stream: bool = False

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a rapper. Always answer with a rap.",
                        },
                        {
                            "role": "user",
                            "content": "How do you fry an egg?",
                        },
                    ],
                    "stream": False,
                    "use_context": True,
                    "include_sources": True,
                    "context_filter": {
                        "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
                    },
                }
            ]
        }
    }
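
# Quick sanity check of the schema above (a sketch, not part of the API surface):
# the documented example payload should parse cleanly with Pydantic v2's
# `model_validate`. The literal below simply mirrors the example fields.
#
#   body = ChatBody.model_validate(
#       {
#           "messages": [{"role": "user", "content": "How do you fry an egg?"}],
#           "use_context": True,
#           "include_sources": True,
#           "stream": False,
#       }
#   )
#   assert body.use_context and not body.stream
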
logger = logging.getLogger(__name__)
@chat_router.post(
    "/chat/completions",
    response_model=None,
    responses={200: {"model": OpenAICompletion}},
    tags=["Contextual Completions"],
)
def chat_completion(
    request: Request, body: ChatBody
) -> OpenAICompletion | StreamingResponse:
    """Given a list of messages comprising a conversation, return a response.

    Optionally include an initial `role: system` message to influence the way
    the LLM answers.

    If `use_context` is set to `true`, the model will use context coming
    from the ingested documents to create the response. The documents being used can
    be filtered using the `context_filter` and passing the document IDs to be used.
    Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
    all ingested documents to be used, remove `context_filter` altogether.

    When using `'include_sources': true`, the API will return the source Chunks used
    to create the response, which come from the context provided.

    When using `'stream': true`, the API will return data chunks following [OpenAI's
    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
    ```
    {"id":"12345","object":"completion.chunk","created":1694268190,
    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
    "finish_reason":null}]}
    ```
    """
    try:
        logger.info(
            "Received chat completion request with body: %s", body.model_dump_json()
        )
        service = request.state.injector.get(ChatService)
        all_messages = [
            ChatMessage(content=m.content, role=MessageRole(m.role))
            for m in body.messages
        ]
        logger.info("Constructed all_messages: %s", all_messages)
        if body.stream:
            completion_gen = service.stream_chat(
                messages=all_messages,
                use_context=body.use_context,
                context_filter=body.context_filter,
            )
            logger.info("Streaming response initialized")
            return StreamingResponse(
                to_openai_sse_stream(
                    completion_gen.response,
                    completion_gen.sources if body.include_sources else None,
                ),
                media_type="text/event-stream",
            )
        else:
            completion = service.chat(
                messages=all_messages,
                use_context=body.use_context,
                context_filter=body.context_filter,
            )
            logger.info("Completed chat request: %s", completion.response)
            return to_openai_response(
                completion.response,
                completion.sources if body.include_sources else None,
            )
    except Exception as e:
        # Returning a bare dict would bypass the declared response types and still
        # report HTTP 200; surface the failure as a proper 500 instead.
        logger.error("Error processing chat completion: %s", str(e), exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error") from e