Akjava committed on
Commit
8ecfb8d
·
verified ·
1 Parent(s): 786d1a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -57,20 +57,13 @@ def generate_text(messages):
57
  model = AutoModelForCausalLM.from_pretrained(
58
  model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
59
  )
60
- text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device ) #pipeline has not to(device)
61
  result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
62
 
63
- generated_output = result[0]["generated_text"]
64
- if isinstance(generated_output, list):
65
- for message in reversed(generated_output):
66
- if message.get("role") == "assistant":
67
- content= message.get("content", "No content found.")
68
- return content
69
-
70
- return "No assistant response found."
71
- else:
72
- return "Unexpected output format."
73
-
74
 
75
 
76
  def call_generate_text(message, history):
@@ -80,12 +73,12 @@ def call_generate_text(message, history):
80
 
81
  messages = history+[{"role":"user","content":message}]
82
  try:
83
- text = generate_text(messages)
84
- return text
 
85
  except RuntimeError as e:
86
  print(f"An unexpected error occurred: {e}")
87
-
88
- return ""
89
 
90
  demo = gr.ChatInterface(call_generate_text,type="messages")
91
 
 
57
  model = AutoModelForCausalLM.from_pretrained(
58
  model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
59
  )
60
+ text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device ,stream=True) #pipeline has not to(device)
61
  result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
62
 
63
+ generated_output = ""
64
+ for token in result:
65
+ generated_output += token["generated_token"]
66
+ yield generated_output
 
 
 
 
 
 
 
67
 
68
 
69
  def call_generate_text(message, history):
 
73
 
74
  messages = history+[{"role":"user","content":message}]
75
  try:
76
+
77
+ for text in generate_text(messages):
78
+ yield text
79
  except RuntimeError as e:
80
  print(f"An unexpected error occurred: {e}")
81
+ yield ""
 
82
 
83
  demo = gr.ChatInterface(call_generate_text,type="messages")
84