Update apis/chat_api.py
Browse files- apis/chat_api.py +27 -21
apis/chat_api.py
CHANGED
@@ -2,6 +2,7 @@ import argparse
|
|
2 |
import os
|
3 |
import sys
|
4 |
import uvicorn
|
|
|
5 |
|
6 |
from fastapi import FastAPI, Depends
|
7 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
@@ -76,28 +77,33 @@ class ChatAPIApp:
|
|
76 |
def chat_completions(
|
77 |
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
78 |
):
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
)
|
90 |
-
if item.stream:
|
91 |
-
event_source_response = EventSourceResponse(
|
92 |
-
streamer.chat_return_generator(stream_response),
|
93 |
-
media_type="text/event-stream",
|
94 |
-
ping=2000,
|
95 |
-
ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
|
96 |
)
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
def setup_routes(self):
|
103 |
for prefix in ["", "/v1"]:
|
|
|
2 |
import os
|
3 |
import sys
|
4 |
import uvicorn
|
5 |
+
import traceback
|
6 |
|
7 |
from fastapi import FastAPI, Depends
|
8 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
77 |
def chat_completions(
    self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
):
    """Handle a chat-completions request, either streamed or as one payload.

    Merges ``item.messages`` with a ``MessageComposer``, passes the merged
    prompt to ``MessageStreamer.chat_response`` and then wraps the result
    according to ``item.stream``.

    Args:
        item: Request body; this method reads ``model``, ``messages``,
            ``temperature``, ``max_tokens`` and ``stream`` from it.
        api_key: Value produced by the ``extract_api_key`` dependency and
            forwarded unchanged to ``streamer.chat_response``.

    Returns:
        An ``EventSourceResponse`` wrapping
        ``streamer.chat_return_generator(...)`` when ``item.stream`` is
        truthy; otherwise the value of ``streamer.chat_return_dict(...)``.
        If any step raises, a ``JSONResponse`` with status 500 and body
        ``{"error": "Internal server error"}``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module (FastAPI / Depends are imported from it).
    from fastapi.responses import JSONResponse

    try:
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)

        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
        )

        if item.stream:
            # NOTE(review): if chat_return_generator is lazy, errors raised
            # while the stream is being consumed occur after this method has
            # returned and are not handled by the except below -- confirm.
            return EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                # Keep-alive pings are sent as empty SSE comments.
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )

        return streamer.chat_return_dict(stream_response)
    except Exception as e:
        # Top-level boundary for the route: log full details server-side and
        # give the client only a generic message.
        logger.error(f"Error in chat_completions: {str(e)}")
        logger.error(traceback.format_exc())
        # A bare dict would be sent with HTTP 200; use an explicit 500 so
        # clients can detect the failure from the status code. The body is
        # unchanged from the previous behaviour.
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error"},
        )
|
107 |
|
108 |
def setup_routes(self):
|
109 |
for prefix in ["", "/v1"]:
|