mateoluksenberg commited on
Commit
0063488
·
verified ·
1 Parent(s): b633491

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -36
app.py CHANGED
@@ -260,45 +260,22 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
260
  gen_kwargs = {**input_ids, **generate_kwargs}
261
 
262
  with torch.no_grad():
263
- thread = Thread(target=model.generate, kwargs=gen_kwargs)
264
- thread.start()
265
- buffer = ""
266
- for new_text in streamer:
267
- buffer += new_text
268
 
269
- return buffer
270
 
 
 
 
 
 
 
 
 
271
  except Exception as e:
272
- return f"Error: {str(e)}"
273
-
274
- @app.post("/chat/")
275
- async def test_endpoint(
276
- text: str = Form(...),
277
- file: UploadFile = File(None)
278
- ):
279
- if file:
280
- file_content = BytesIO(await file.read())
281
- file_name = file.filename
282
-
283
- message = {
284
- "text": text,
285
- "file_content": file_content,
286
- "file_name": file_name
287
- }
288
- else:
289
- message = {
290
- "text": text,
291
- "file_content": None,
292
- "file_name": None
293
- }
294
-
295
- print(message)
296
- response = simple_chat(message)
297
-
298
- if isinstance(response, str) and response.startswith("Error:"):
299
- return PlainTextResponse(response, status_code=500)
300
- else:
301
- return StreamingResponse(BytesIO(response.encode()), media_type="text/plain")
302
 
303
  # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
304
  # try:
 
260
  gen_kwargs = {**input_ids, **generate_kwargs}
261
 
262
  with torch.no_grad():
263
+ generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
264
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
 
 
265
 
 
266
 
267
+ print("---------")
268
+ print(generated_ids[0])
269
+ print("---------")
270
+ print(generated_text)
271
+ print("---------")
272
+
273
+
274
+ return PlainTextResponse(generated_text)
275
  except Exception as e:
276
+ return PlainTextResponse(f"Error: {str(e)}")
277
+
278
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
281
  # try: