Akjava committed · verified
Commit 43cc94e · 1 Parent(s): 8ecfb8d

Update app.py

Files changed (1):
  app.py (+5 -3)
app.py CHANGED
@@ -2,6 +2,7 @@ import spaces
 import os
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import TextStreamer
 import gradio as gr
 
 text_generator = None
@@ -36,7 +37,7 @@ if not is_hugging_face:
     model = AutoModelForCausalLM.from_pretrained(
         model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
     )
-    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device)  # the pipeline has no .to(device) method
+    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device, stream=True)  # the pipeline has no .to(device) method
 
     if next(model.parameters()).is_cuda:
         print("The model is on a GPU")
@@ -57,9 +58,10 @@ def generate_text(messages):
     model = AutoModelForCausalLM.from_pretrained(
         model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
     )
-    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device, stream=True)  # the pipeline has no .to(device) method
+    streamer = TextStreamer(tokenizer, skip_prompt=True)
+    text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=dtype, device_map=device, streamer=streamer)  # the pipeline has no .to(device) method
     result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
-
+    print(result)
     generated_output = ""
     for token in result:
         generated_output += token["generated_token"]
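
For reference on what the added code does: `TextStreamer` writes decoded tokens to stdout as they are generated, while the pipeline call itself still blocks and returns the usual completed result, a list of dicts keyed by `generated_text`. To consume tokens incrementally in Python, for example to yield partial text to a Gradio interface, the companion class `TextIteratorStreamer` with generation in a background thread is the usual pattern. A minimal sketch under those assumptions; the model id, prompt, and sampling parameters below are placeholders standing in for the ones set up in this file:

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# skip_prompt=True keeps the input prompt out of the streamed output,
# mirroring TextStreamer(tokenizer, skip_prompt=True) in this commit.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")

# generate() blocks until finished, so it runs in a background thread
# while the main thread drains the streamer as chunks arrive.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=256, do_sample=True,
                temperature=0.7, streamer=streamer),
)
thread.start()

generated_output = ""
for chunk in streamer:  # yields decoded text pieces, not dicts
    generated_output += chunk
thread.join()
print(generated_output)

With a plain `TextStreamer`, by contrast, the printed tokens are a side effect and `result` remains the full `[{"generated_text": ...}]` list, so a loop reading `token["generated_token"]` would not observe per-token output.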