Update apis/chat_api.py
Browse files- apis/chat_api.py +21 -26
apis/chat_api.py
CHANGED
@@ -77,33 +77,28 @@ class ChatAPIApp:
|
|
77 |
def chat_completions(
|
78 |
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
79 |
):
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
)
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
media_type="text/event-stream",
|
96 |
-
ping=2000,
|
97 |
-
ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
|
98 |
-
)
|
99 |
-
return event_source_response
|
100 |
-
else:
|
101 |
-
data_response = streamer.chat_return_dict(stream_response)
|
102 |
-
return data_response
|
103 |
-
except Exception as e:
|
104 |
-
logger.error(f"Error in chat_completions: {str(e)}")
|
105 |
-
logger.error(traceback.format_exc())
|
106 |
-
return {"error": "Internal server error"}
|
107 |
|
108 |
def setup_routes(self):
|
109 |
for prefix in ["", "/v1"]:
|
|
|
77 |
def chat_completions(
|
78 |
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
79 |
):
|
80 |
+
streamer = MessageStreamer(model=item.model)
|
81 |
+
composer = MessageComposer(model=item.model)
|
82 |
+
composer.merge(messages=item.messages)
|
83 |
+
# streamer.chat = stream_chat_mock
|
84 |
+
|
85 |
+
stream_response = streamer.chat_response(
|
86 |
+
prompt=composer.merged_str,
|
87 |
+
temperature=item.temperature,
|
88 |
+
max_new_tokens=item.max_tokens,
|
89 |
+
api_key=api_key,
|
90 |
+
)
|
91 |
+
if item.stream:
|
92 |
+
event_source_response = EventSourceResponse(
|
93 |
+
streamer.chat_return_generator(stream_response),
|
94 |
+
media_type="text/event-stream",
|
95 |
+
ping=2000,
|
96 |
+
ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
|
97 |
)
|
98 |
+
return event_source_response
|
99 |
+
else:
|
100 |
+
data_response = streamer.chat_return_dict(stream_response)
|
101 |
+
return data_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def setup_routes(self):
|
104 |
for prefix in ["", "/v1"]:
|