ka1kuk commited on
Commit
6cede14
·
verified ·
1 Parent(s): e9fe12b

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +27 -21
apis/chat_api.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import os
3
  import sys
4
  import uvicorn
 
5
 
6
  from fastapi import FastAPI, Depends
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -76,28 +77,33 @@ class ChatAPIApp:
76
  def chat_completions(
77
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
78
  ):
79
- streamer = MessageStreamer(model=item.model)
80
- composer = MessageComposer(model=item.model)
81
- composer.merge(messages=item.messages)
82
- # streamer.chat = stream_chat_mock
83
-
84
- stream_response = streamer.chat_response(
85
- prompt=composer.merged_str,
86
- temperature=item.temperature,
87
- max_new_tokens=item.max_tokens,
88
- api_key=api_key,
89
- )
90
- if item.stream:
91
- event_source_response = EventSourceResponse(
92
- streamer.chat_return_generator(stream_response),
93
- media_type="text/event-stream",
94
- ping=2000,
95
- ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
96
  )
97
- return event_source_response
98
- else:
99
- data_response = streamer.chat_return_dict(stream_response)
100
- return data_response
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def setup_routes(self):
103
  for prefix in ["", "/v1"]:
 
2
  import os
3
  import sys
4
  import uvicorn
5
+ import traceback
6
 
7
  from fastapi import FastAPI, Depends
8
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
77
  def chat_completions(
78
  self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
79
  ):
80
+ try:
81
+ streamer = MessageStreamer(model=item.model)
82
+ composer = MessageComposer(model=item.model)
83
+ composer.merge(messages=item.messages)
84
+
85
+ stream_response = streamer.chat_response(
86
+ prompt=composer.merged_str,
87
+ temperature=item.temperature,
88
+ max_new_tokens=item.max_tokens,
89
+ api_key=api_key,
 
 
 
 
 
 
 
90
  )
91
+
92
+ if item.stream:
93
+ event_source_response = EventSourceResponse(
94
+ streamer.chat_return_generator(stream_response),
95
+ media_type="text/event-stream",
96
+ ping=2000,
97
+ ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
98
+ )
99
+ return event_source_response
100
+ else:
101
+ data_response = streamer.chat_return_dict(stream_response)
102
+ return data_response
103
+ except Exception as e:
104
+ logger.error(f"Error in chat_completions: {str(e)}")
105
+ logger.error(traceback.format_exc())
106
+ return {"error": "Internal server error"}
107
 
108
  def setup_routes(self):
109
  for prefix in ["", "/v1"]: