greyfoss commited on
Commit
88c96b4
·
verified ·
1 Parent(s): 6552f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -1
app.py CHANGED
@@ -1,3 +1,149 @@
 
 
 
1
  import gradio as gr
 
2
 
3
- gr.load("models/greyfoss/gpt2-chatbot-chinese").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+ from collections import defaultdict
4
  import gradio as gr
5
+ from optimum.onnxruntime import ORTModelForCausalLM
6
 
7
+
8
+ import itertools
9
+ import re
10
+
11
+
12
# Special tokens used by the greyfoss/gpt2-chatbot-chinese checkpoint to mark
# dialogue structure. They must match the tokens the model was trained with.
user_token = "<User>"        # precedes each user turn in the prompt
eos_token = "<EOS>"          # terminates a user turn
bos_token = "<BOS>"          # start-of-conversation marker
bot_token = "<Assistant>"    # precedes each assistant turn / generation cue
16
+
17
+
18
def is_english_word(tested_string):
    """Return True when *tested_string* consists solely of ASCII letters (a-z, A-Z)."""
    # The re module caches compiled patterns, so matching the pattern directly
    # behaves identically to compiling it first on every call.
    return re.match(r"^[a-zA-Z]+$", tested_string) is not None
21
+
22
+
23
+
24
def format(history):
    """Render an alternating [user, bot, user, ...] turn list into one prompt string.

    Even-indexed entries are treated as user messages (wrapped in the user/EOS
    special tokens); odd-indexed entries as assistant replies. A trailing
    assistant token is appended so the model continues speaking as the bot.
    """
    pieces = [bos_token]
    for turn_index, message in enumerate(history):
        if turn_index % 2 == 0:
            pieces.append(f"{user_token}{message}{eos_token}")
        else:
            pieces.append(f"{bot_token}{message}")
    pieces.append(bot_token)
    prompt = "".join(pieces)
    print(prompt)  # debug: log the assembled prompt to stdout
    return prompt
35
+
36
def gradio(model, tokenizer):
    """Build and launch the Gradio chat UI around *model*/*tokenizer*.

    NOTE(review): indentation was reconstructed from a diff view; the nesting of
    the UI context managers follows the conventional gradio layout — confirm
    against the original file.
    """

    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        # Flatten [[user, bot], ...] message pairs into one alternating list
        # and append the newest user message.
        history = list(itertools.chain(*chat_history))
        history.append(user_input)

        prompt = format(history)

        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            # The special tokens are already embedded in the prompt by format().
            add_special_tokens=False,
        )

        # Prompt length in tokens; used below to slice off the echoed prompt.
        prompt_length = input_ids.shape[1]

        beam_output = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=255,
            # num_beams=3,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            early_stopping=True,
            # do_sample=True,
            # NOTE(review): with do_sample left commented out, decoding is
            # greedy and top_k/top_p/temperature have no effect despite being
            # exposed as sliders — confirm whether sampling should be enabled.
        )
        # Keep only the newly generated continuation, not the echoed prompt.
        output = beam_output[0][prompt_length:]

        tokens = tokenizer.convert_ids_to_tokens(output)
        # Re-insert a space between adjacent all-ASCII-letter tokens, since the
        # tokenizer joins tokens without whitespace.
        for i, token in enumerate(tokens[:-1]):
            if is_english_word(token) and is_english_word(tokens[i + 1]):
                tokens[i] = token + " "
        # Strip WordPiece continuation markers ("##") and unknown-token
        # placeholders before returning the reply text.
        text = "".join(tokens).replace("##", "").replace("<UNK>", "").strip()

        return text

    # Chatbot component created outside the Blocks context; gradio renders it
    # where it is handed to ChatInterface below.
    bot = gr.Chatbot(scale=8)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")

        # Generation hyper-parameters exposed as collapsible sliders; they are
        # forwarded to response() via additional_inputs.
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit the number of candidate tokens considered during decoding.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Control the diversity of the output by selecting tokens with cumulative probabilities up to the Top-P threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text. A higher temperature results in more diverse and unpredictable outputs, while a lower temperature produces more conservative and coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens in a sequence.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from generating sequences of n consecutive tokens that have already been generated in the context. ",
                )

        # NOTE(review): fill_vertical_space does not appear in current gradio
        # ChatInterface signatures (fill_height is the documented name) —
        # verify against the gradio version pinned for this Space.
        gr.ChatInterface(
            response,
            chatbot=bot,
            fill_vertical_space=True,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
        )

    # Enable request queuing and start the server (blocks until shutdown).
    demo.queue().launch()
141
+
142
+
143
+
144
+
145
# Application entry: load the tokenizer and an ONNX Runtime build of the chat
# model, then start the Gradio app. Runs at import time (standard for a
# HuggingFace Space app.py).
tokenizer = AutoTokenizer.from_pretrained("greyfoss/gpt2-chatbot-chinese")

# export=True asks optimum to convert the PyTorch checkpoint to ONNX on load.
model = ORTModelForCausalLM.from_pretrained("greyfoss/gpt2-chatbot-chinese", export=True)

gradio(model, tokenizer)