mateoluksenberg committed
Commit 08271ae · verified · 1 Parent(s): f759069

Update app.py

Files changed (1): app.py (+5 -21)
app.py CHANGED
@@ -241,7 +241,6 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
     conversation = [{"role": "user", "content": input_text}]
     input_ids = tokenizer(conversation[-1]['content'], return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         max_length=max_length,
@@ -250,30 +249,16 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         top_k=top_k,
         temperature=temperature,
         repetition_penalty=penalty,
-        eos_token_id=[151329, 151336, 151338],
-        streamer=streamer
+        eos_token_id=[151329, 151336, 151338]
     )
 
-    buffer = ""
-
-    def generate_text():
-        with torch.no_grad():
-            model.generate(input_ids['input_ids'], **generate_kwargs)
-
-    # Run generation in a separate thread
-    thread = Thread(target=generate_text)
-    thread.start()
-
-    # Collect generated text in real-time
-    for new_text in streamer:
-        buffer += new_text
-
-    # Wait for the generation thread to finish
-    thread.join()
+    with torch.no_grad():
+        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
     # Process to remove any prefix or unwanted prompt
     text_original = input_text.strip()
-    results_text = buffer[len(text_original):].strip()
+    results_text = generated_text[len(text_original):].strip()
 
     print(" ")
     print("------")
@@ -289,7 +274,6 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
 
 
 
-
 # @spaces.GPU()
 # def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
 # try:
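
In short, the commit replaces the TextIteratorStreamer plus background Thread streaming pipeline with a single blocking model.generate() call decoded via tokenizer.decode(). Below is a minimal, self-contained sketch of the resulting path; the checkpoint name, the do_sample flag, and the example prompt are assumptions not present in the diff (the eos_token_id values appear to match GLM-4-style chat checkpoints), so treat it as an illustration rather than the Space's exact code.

# Minimal sketch of the non-streaming path introduced by this commit.
# ASSUMPTIONS: "THUDM/glm-4-9b-chat" and do_sample=True are placeholders
# not confirmed by the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "THUDM/glm-4-9b-chat"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

input_text = "Explain what a tokenizer does in one sentence."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Blocking call: returns only after the whole sequence is generated,
    # so no Thread or streamer bookkeeping is needed.
    generated_ids = model.generate(
        inputs["input_ids"],
        max_length=4096,
        do_sample=True,               # assumed; the sampling knobs below need it
        top_p=1.0,
        top_k=10,
        temperature=0.8,
        repetition_penalty=1.0,
        eos_token_id=[151329, 151336, 151338],  # stop tokens from the diff
    )

# generate() returns prompt + completion, so the handler strips the prompt
# by character offset. This assumes decode() reproduces the prompt text
# verbatim, which can break if the tokenizer normalizes whitespace.
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
results_text = generated_text[len(input_text.strip()):].strip()
print(results_text)

The tradeoff is that the UI no longer receives partial tokens as they are produced; in exchange the handler sheds the thread, the 60-second streamer timeout, and the incremental buffer.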