Elliot4AI commited on
Commit
089c48d
·
1 Parent(s): 11b5b33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -2
app.py CHANGED
@@ -1,3 +1,79 @@
1
- import gradio as gr
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/Elliot4AI/Dugong-Llama2-7b-chinese").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir_merge = "Elliot4AI/Dugong-Llama2-7b-chinese"
2
+ # load base LLM model and tokenizer
3
+ model = AutoModelForCausalLM.from_pretrained(
4
+ output_dir_merge,
5
+ low_cpu_mem_usage=True,
6
+ torch_dtype=torch.float16,
7
+ load_in_8bit=True,
8
+ )
9
 
10
+ tokenizer = AutoTokenizer.from_pretrained(output_dir_merge)
11
+
12
+ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
13
+ # Get the model and tokenizer, and tokenize the user text.
14
+ model_inputs = tokenizer([user_text], return_tensors="pt").input_ids.cuda()
15
+
16
+ # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
17
+ # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
18
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
19
+ generate_kwargs = dict(
20
+ inputs=model_inputs,
21
+ streamer=streamer,
22
+ max_new_tokens=max_new_tokens,
23
+ do_sample=True,
24
+ top_p=top_p,
25
+ temperature=float(temperature),
26
+ top_k=top_k
27
+ # repetition_penalty=2.0
28
+ )
29
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
30
+ t.start()
31
+
32
+ # Pull the generated text from the streamer, and update the model output.
33
+ model_output = ""
34
+ for new_text in streamer:
35
+ model_output += new_text
36
+ yield model_output
37
+ return model_output
38
+
39
+
40
+ def reset_textbox():
41
+ return gr.update(value='')
42
+ with gr.Blocks() as demo:
43
+ with gr.Tab("PatentQA-Dugong-Llama2-7b-chinese Agent"):
44
+ gr.Markdown(
45
+ "# 🤗 PatentQA_Dugong 🔥PatentQA_Dugong Agent🔥 \n"
46
+ "Dugong是一个用中文微调的Llama2-7b的模型, 微调后中文回答更顺畅 "
47
+ "目前采用流式输出"
48
+ "🤗💛"
49
+ )
50
+ # gr.Markdown("PatentQA_Dugong Agent: Dugong是一个用中文微调的Llama2-7b的模型, 微调后中文回答更顺畅,并且具有丰富英业达专利知识的人工智能助手,可以回答专利的相关信息,目前恢复速度稍慢")
51
+ with gr.Row():
52
+ with gr.Column(scale=4):
53
+ user_text = gr.Textbox(
54
+ placeholder="请输入你的问题",
55
+ label="问题"
56
+ )
57
+ model_output = gr.Textbox(label="回答", lines=10, interactive=False)
58
+ button_submit = gr.Button(value="提交")
59
+ clear = gr.ClearButton([user_text, model_output])
60
+
61
+ with gr.Column(scale=1):
62
+ max_new_tokens = gr.Slider(
63
+ minimum=1, maximum=1000, value=250, step=1, interactive=True, label="最大输出token数量",
64
+ )
65
+ top_p = gr.Slider(
66
+ minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
67
+ )
68
+ top_k = gr.Slider(
69
+ minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
70
+ )
71
+ temperature = gr.Slider(
72
+ minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="温度",
73
+ )
74
+
75
+ user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
76
+ button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
77
+
78
+ demo.queue(max_size=32)
79
+ demo.launch(enable_queue=True)