Spaces:
Runtime error
Runtime error
File size: 4,816 Bytes
3997444 81c24b6 3997444 e586fb9 81c24b6 e586fb9 3997444 466a76b 3997444 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import gradio as gr
import mdtex2html
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# Load the tokenizer, the model (in eval mode) and its generation settings,
# all from the same checkpoint id.
_CHECKPOINT = "Qwen/Qwen-14B-Chat"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    device_map="auto",
    trust_remote_code=True,
    use_flash_attn=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
)
# Postprocess function
def postprocess(self, y):
    """Render every (message, response) pair of ``y`` through mdtex2html.

    ``y`` is updated in place and returned; each entry is replaced by a
    2-tuple whose non-``None`` members have been passed to
    ``mdtex2html.convert``. A ``None`` history yields an empty list.
    """
    if y is None:
        return []

    def _render(text):
        # None marks an absent side of the exchange and is passed through.
        return text if text is None else mdtex2html.convert(text)

    for idx, pair in enumerate(y):
        user_msg, bot_msg = pair
        y[idx] = (_render(user_msg), _render(bot_msg))
    return y
# Replace Chatbot's postprocess method with the mdtex2html-based function
# defined above (monkey-patch on the class, so it applies to all instances).
gr.Chatbot.postprocess = postprocess
# Text parsing function
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "<br>" + line
text = "".join(lines)
return text
# Demo launching function
def _launch_demo(args, model, tokenizer, config):
    """Build the Gradio chat UI around ``model`` and start serving it.

    Args:
        args: Not used by this function; kept for interface compatibility.
        model: Object providing ``chat_stream(tokenizer, query, history=...,
            generation_config=...)``.
        tokenizer: Tokenizer handed through to ``model.chat_stream``.
        config: Generation config handed through to ``model.chat_stream``.
    """

    def predict(_query, _chatbot, _task_history):
        """Stream the model's answer to ``_query`` into the chatbot display.

        Yields the updated chatbot list after every streamed chunk and, once
        the stream ends, appends ``(_query, answer)`` to ``_task_history``.
        """
        rendered_query = _parse_text(_query)
        print(f"User: {rendered_query}")
        _chatbot.append((rendered_query, ""))
        raw_response = ""
        for response in model.chat_stream(
            tokenizer, _query, history=_task_history, generation_config=config
        ):
            # Each streamed value is treated as the complete answer so far,
            # so the last one seen is the final answer.
            raw_response = response
            _chatbot[-1] = (rendered_query, _parse_text(response))
            yield _chatbot
        print(f"History: {_task_history}")
        # Keep the raw model output in the history: it is passed back to the
        # model on the next turn, so it must not contain display-only HTML.
        _task_history.append((_query, raw_response))
        print(f"Qwen-Chat: {_parse_text(raw_response)}")

    def regenerate(_chatbot, _task_history):
        """Drop the last exchange and ask the model the same query again."""
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        """Empty the history and the display, then release cached memory."""
        _task_history.clear()
        _chatbot.clear()
        import gc

        gc.collect()
        torch.cuda.empty_cache()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""
## Qwen-14B-Chat: A Large Language Model by Alibaba Cloud
**Space created by [@artificialguybr](https://twitter.com/artificialguybr) based on QWEN Code. Thanks HF for GPU!**
**Qwen is currently SOTA in the benchmarks for 14B models.**
### Performance Metrics:
- **MMLU Accuracy**:
- 0-shot: 64.6
- 5-shot: 66.5
- **HumanEval Pass@1**: 43.9
- **GSM8K Accuracy**:
- 0-shot: 60.1
- 8-shot: 59.3
""")
        chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height", queue=True)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])
        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History")
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue
        # Second handler on the same button: empties the textbox on submit.
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue
    demo.queue(max_size=20)
    demo.launch(share=True)
# Main execution
if __name__ == "__main__":
    # Script entry point: serve the demo using the module-level model,
    # tokenizer and the generation config attached to the model.
    _launch_demo(
        args=None,
        model=model,
        tokenizer=tokenizer,
        config=model.generation_config,
    )
|