jonathanjordan21 committed on
Commit
9c416cb
·
verified ·
1 Parent(s): e37b4b3

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +30 -20
apis/chat_api.py CHANGED
@@ -151,26 +151,36 @@ class ChatAPIApp:
151
  default="Hello, who are you?",
152
  description="(str) Prompt",
153
  )
154
- temperature: Union[float, None] = Field(
155
- default=0.5,
156
- description="(float) Temperature",
157
- )
158
- top_p: Union[float, None] = Field(
159
- default=0.95,
160
- description="(float) top p",
161
- )
162
- max_tokens: Union[int, None] = Field(
163
- default=-1,
164
- description="(int) Max tokens",
165
- )
166
- use_cache: bool = Field(
167
- default=False,
168
- description="(bool) Use cache",
169
- )
170
  stream: bool = Field(
171
  default=True,
172
  description="(bool) Stream",
173
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  def generate_text(
176
  self, item: GenerateRequest, api_key: str = Depends(extract_api_key)
@@ -190,11 +200,11 @@ class ChatAPIApp:
190
  streamer = HuggingfaceStreamer(model=item.model)
191
  stream_response = streamer.chat_response(
192
  prompt=item.prompt,
193
- temperature=item.temperature,
194
- top_p=item.top_p,
195
- max_new_tokens=item.max_tokens,
196
  api_key=api_key,
197
- use_cache=item.use_cache,
198
  )
199
 
200
  if item.stream:
 
151
  default="Hello, who are you?",
152
  description="(str) Prompt",
153
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  stream: bool = Field(
155
  default=True,
156
  description="(bool) Stream",
157
  )
158
+ options: dict = Field(
159
+ default={
160
+ "temperature":0.5,
161
+ "top_p":0.95,
162
+ "max_tokens":-1,
163
+ "use_cache":False
164
+ },
165
+ description="(dict) Options"
166
+ )
167
+ # temperature: Union[float, None] = Field(
168
+ # default=0.5,
169
+ # description="(float) Temperature",
170
+ # )
171
+ # top_p: Union[float, None] = Field(
172
+ # default=0.95,
173
+ # description="(float) top p",
174
+ # )
175
+ # max_tokens: Union[int, None] = Field(
176
+ # default=-1,
177
+ # description="(int) Max tokens",
178
+ # )
179
+ # use_cache: bool = Field(
180
+ # default=False,
181
+ # description="(bool) Use cache",
182
+ # )
183
+
184
 
185
  def generate_text(
186
  self, item: GenerateRequest, api_key: str = Depends(extract_api_key)
 
200
  streamer = HuggingfaceStreamer(model=item.model)
201
  stream_response = streamer.chat_response(
202
  prompt=item.prompt,
203
+ temperature=item.options.get('temperature'),
204
+ top_p=item.options.get('top_p'),
205
+ max_new_tokens=item.options.get('max_tokens'),
206
  api_key=api_key,
207
+ use_cache=item.options.get('use_cache'),
208
  )
209
 
210
  if item.stream: