wannaphong committed
Commit 1c793ed · 1 Parent(s): 10f78b6

Create app.py

Files changed (1): app.py (+423, −0)
app.py ADDED
@@ -0,0 +1,423 @@
from abc import ABC, abstractmethod
from typing import List

import numpy as np
import tensorflow_hub as hub
import tensorflow_text  # registers the SentencePiece ops the MUSE model needs


class Encoder(ABC):
    @abstractmethod
    def encode(self, texts: List[str]) -> np.array:
        """
        Each row of the output is expected to be a unit-normalized (L2 norm = 1)
        embedding, one per input text.
        """
        ...


class MUSEEncoder(Encoder):
    def __init__(self, model_url: str = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"):
        self.embed = hub.load(model_url)

    def encode(self, texts: List[str]) -> np.array:
        embeds = self.embed(texts).numpy()
        # Normalize each row so that dot products become cosine similarities.
        embeds = embeds / np.linalg.norm(embeds, axis=1).reshape(embeds.shape[0], -1)
        return embeds

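# A minimal usage sketch (assuming the TF Hub download succeeds; this MUSE
# model produces 512-dim vectors):
#   encoder = MUSEEncoder()
#   vecs = encoder.encode(["hello", "สวัสดี"])  # shape [2, 512], rows unit-normalized
# Because rows are unit vectors, vecs @ vecs.T is a matrix of cosine similarities.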

from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import tensorflow as tf


@dataclass
class SensitiveTopic:
    name: str
    respond_message: str
    sensitivity: float = None  # range from 0 to 1
    demonstrations: List[str] = None
    adhoc_embeded_demonstrations: np.array = None  # dimension = [N_ADHOC, DIM]; expected to be unit-normalized, like Encoder.encode output


DEFAULT_SENSITIVITY = 0.7


class SensitiveTopicProtector:
    def __init__(
        self,
        sensitive_topics: List[SensitiveTopic],
        encoder: Encoder = MUSEEncoder(),
        default_sensitivity: float = DEFAULT_SENSITIVITY
    ):
        self.sensitive_topics = sensitive_topics
        self.default_sensitivity = default_sensitivity
        self.encoder = encoder
        self.topic_embeddings = self._get_topic_embeddings()

    def _get_topic_embeddings(self) -> Dict[str, np.array]:
        topic_embeddings = {}
        for topic in self.sensitive_topics:
            current_topic_embeddings = None
            if topic.demonstrations is not None:
                current_topic_embeddings = self.encoder.encode(texts=topic.demonstrations) if current_topic_embeddings is None \
                    else np.concatenate((current_topic_embeddings, self.encoder.encode(texts=topic.demonstrations)), axis=0)
            if topic.adhoc_embeded_demonstrations is not None:
                current_topic_embeddings = topic.adhoc_embeded_demonstrations if current_topic_embeddings is None \
                    else np.concatenate((current_topic_embeddings, topic.adhoc_embeded_demonstrations), axis=0)
            topic_embeddings[topic.name] = current_topic_embeddings
        return topic_embeddings

    def filter(self, text: str) -> Tuple[bool, str]:
        is_sensitive, respond_message = False, None
        text_embedding = self.encoder.encode([text])
        for topic in self.sensitive_topics:
            # Dot products of unit vectors = cosine similarities to every demonstration.
            risk_scores = np.einsum('ik,jk->j', text_embedding, self.topic_embeddings[topic.name])
            max_risk_score = np.max(risk_scores)
            if topic.sensitivity:
                if max_risk_score > (1.0 - topic.sensitivity):
                    return True, topic.respond_message
                continue
            if max_risk_score > (1.0 - self.default_sensitivity):
                return True, topic.respond_message
        return is_sensitive, respond_message

    @classmethod
    def fromRaw(cls, raw_sensitive_topics: List[Dict], encoder: Encoder = MUSEEncoder(), default_sensitivity: float = DEFAULT_SENSITIVITY):
        sensitive_topics = [SensitiveTopic(**topic) for topic in raw_sensitive_topics]
        return cls(sensitive_topics=sensitive_topics, encoder=encoder, default_sensitivity=default_sensitivity)

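# Scoring rule: a message is flagged for a topic when its maximum cosine
# similarity to that topic's demonstrations exceeds 1.0 - sensitivity, i.e. a
# sensitivity of 0.7 flags anything with similarity above 0.3.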

import pickle

# Load the pre-built sensitive-topic definitions.
with open("sensitive_topics.pkl", "rb") as f:
    sensitive_topics = pickle.load(f)

guardian = SensitiveTopicProtector.fromRaw(sensitive_topics)

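# Assumption: sensitive_topics.pkl holds a list of dicts whose keys match the
# SensitiveTopic fields (name, respond_message, sensitivity, demonstrations, ...),
# since fromRaw unpacks each dict directly via SensitiveTopic(**topic).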
import warnings
warnings.filterwarnings("ignore")
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, Union, List, Dict, Any
import random
import time
import datetime
import os
import re
import pandas as pd

name_model = "pythainlp/wangchanglm-7.5B-sft-adapter-merged-sharded"
model = AutoModelForCausalLM.from_pretrained(
    name_model,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    offload_folder="./",
    low_cpu_mem_usage=True,
)
# WangChanGLM is fine-tuned from facebook/xglm-7.5B, so the base tokenizer is reused.
tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-7.5B")


# "Yes" restricts the start of generation to Thai-script tokens (see exclude_ids below).
Thai = "Yes"

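# device_map="auto" with offload_folder lets accelerate place the ~7.5B
# parameters across GPU, CPU and disk as memory allows; bfloat16 halves the
# footprint relative to float32.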
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

from langchain import ConversationChain, LLMChain, PromptTemplate
from langchain.memory import ConversationBufferWindowMemory


# Chat prompt in the WangChanGLM turn format.
template = """
{history}
<human>: {human_input}
<bot>:"""

prompt = PromptTemplate(
    input_variables=["history", "human_input"],
    template=template
)

# A vocabulary token is excluded if it contains any character outside the Thai block ก-๙.
exclude_pattern = re.compile(r'[^ก-๙]+')  # |[^0-9a-zA-Z]+
def is_exclude(text):
    return bool(exclude_pattern.search(text))

df = pd.DataFrame(tokenizer.vocab.items(), columns=['text', 'idx'])
df['is_exclude'] = df.text.map(is_exclude)
exclude_ids = df[df.is_exclude == True].idx.tolist()

if Thai == "Yes":
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        begin_suppress_tokens=exclude_ids,
        no_repeat_ngram_size=2,
    )
else:
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        no_repeat_ngram_size=2,
    )
hf_pipeline = HuggingFacePipeline(pipeline=pipe)

chatgpt_chain = LLMChain(
    llm=hf_pipeline,
    prompt=prompt,
    verbose=True,
    memory=ConversationBufferWindowMemory(k=2),
)

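# begin_suppress_tokens only blocks the listed ids as the *first* generated
# token, so together with exclude_ids it nudges replies to start in Thai script.
# no_repeat_ngram_size=2 forbids repeating any bigram within a generation.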

api_url = "https://wangchanglm.numfa.com/api.php"  # Don't open this url!!!
import requests
from urllib.request import urlopen
from urllib.parse import urlencode
from urllib.error import HTTPError, URLError
from urllib.request import Request
import copy

def sumbit_data(save, prompt, vote, feedback=None, max_len=None, temp=None, top_p=None, name_model=name_model):
    myobj = {
        'save': save,
        'prompt': prompt,
        'vote': vote,
        'feedback': feedback,
        'max_len': max_len,
        'temp': temp,
        'top_p': top_p,
        'model': name_model
    }
    _temp_url = api_url + "?" + urlencode(myobj, doseq=True, safe="/")
    html = urlopen(_temp_url).read().decode('utf-8')
    return True

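# Feedback logging: votes are sent as plain GET parameters; the UI handlers
# below encode Good=1, Bad=0 and Report=3.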

def gen_instruct(text, max_new_tokens=512, top_p=0.95, temperature=0.9, top_k=50):
    batch = tokenizer(text, return_tensors="pt")
    with torch.cuda.amp.autocast():  # use torch.cpu.amp.autocast() when running on CPU
        if Thai == "Yes":
            output_tokens = model.generate(
                input_ids=batch["input_ids"],
                max_new_tokens=max_new_tokens,  # 512
                begin_suppress_tokens=exclude_ids,
                no_repeat_ngram_size=2,
                do_sample=True,  # required for top_k/top_p/temperature to take effect
                # oasst k50
                top_k=top_k,
                top_p=top_p,  # 0.95
                typical_p=1.,
                temperature=temperature,  # 0.9
            )
        else:
            output_tokens = model.generate(
                input_ids=batch["input_ids"],
                max_new_tokens=max_new_tokens,  # 512
                no_repeat_ngram_size=2,
                do_sample=True,  # required for top_k/top_p/temperature to take effect
                # oasst k50
                top_k=top_k,
                top_p=top_p,  # 0.95
                typical_p=1.,
                temperature=temperature,  # 0.9
            )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    return tokenizer.decode(output_tokens[0][len(batch["input_ids"][0]):], skip_special_tokens=True)

def gen_chatbot_old(text):
    is_sensitive, respond_message = guardian.filter(text)
    if is_sensitive:
        return respond_message

    batch = tokenizer(text, return_tensors="pt")
    # context_tokens = tokenizer(text, add_special_tokens=False)['input_ids']
    # logits_processor = FocusContextProcessor(context_tokens, model.config.vocab_size, scaling_factor=1.5)
    with torch.cpu.amp.autocast():  # use torch.cuda.amp.autocast() when on GPU
        output_tokens = model.generate(
            input_ids=batch["input_ids"],
            max_new_tokens=512,
            begin_suppress_tokens=exclude_ids,
            no_repeat_ngram_size=2,
        )
    # Keep only the text after the last "<bot>: " marker.
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True).split("<bot>: ")[-1]

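# This legacy path keeps no memory object: every turn re-generates from the
# whole flattened history produced by list2prompt below, after the guardian check.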
def list2prompt(history):
    _text = ""
    for user, bot in history:
        _text += "<human>: " + user + "\n<bot>: "
        if bot != None:
            _text += bot + "\n"
    return _text

PROMPT_DICT = {
    "prompt_input": (
        "<context>: {input}\n<human>: {instruction}\n<bot>: "
    ),
    "prompt_no_input": (
        "<human>: {instruction}\n<bot>: "
    ),
}
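# These <context>/<human>/<bot> tags are the turn format WangChanGLM was
# instruction-tuned on; "prompt_input" is used only when the user supplies
# context, otherwise "prompt_no_input".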

def instruct_generate(
    instruct: str,
    input: str = 'none',
    max_gen_len=512,
    temperature: float = 0.1,
    top_p: float = 0.75,
):
    is_sensitive, respond_message = guardian.filter(instruct)
    if is_sensitive:
        return respond_message

    if input == 'none' or len(input) < 2:
        prompt = PROMPT_DICT['prompt_no_input'].format_map(
            {'instruction': instruct, 'input': ''})
    else:
        prompt = PROMPT_DICT['prompt_input'].format_map(
            {'instruction': instruct, 'input': input})
    result = gen_instruct(prompt, max_gen_len, top_p, temperature)
    return result

with gr.Blocks(height=900) as demo:
    chatgpt_chain = LLMChain(
        llm=hf_pipeline,
        prompt=prompt,
        verbose=True,
        memory=ConversationBufferWindowMemory(k=2),
    )
    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(lines=2, label="Instruction", max_lines=10)
                input = gr.Textbox(
                    lines=2, label="Context input", placeholder='none', max_lines=5)
                max_len = gr.Slider(minimum=1, maximum=1024,
                                    value=512, label="Max new tokens")
                with gr.Accordion(label='Advanced options', open=False):
                    temp = gr.Slider(minimum=0, maximum=1,
                                     value=0.9, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1,
                                      value=0.95, label="Top p")

                run_botton = gr.Button("Run")

            with gr.Column():
                outputs = gr.Textbox(lines=10, label="Output")
                with gr.Column(visible=False) as feedback_gen_box:
                    gen_radio = gr.Radio(
                        ["Good", "Bad", "Report"], label="What do you think about the generated text?")
                    feedback_gen = gr.Textbox(placeholder="Feedback chatbot", show_label=False, lines=4)
                    feedback_gen_submit = gr.Button("Submit Feedback")
                with gr.Row(visible=False) as feedback_gen_ok:
                    gr.Markdown("Thank you for the feedback.")

        def save_up2(instruction, input, prompt, max_len, temp, top_p, choice, feedback):
            save = "gen"
            if input == 'none' or len(input) < 2:
                _prompt = PROMPT_DICT['prompt_no_input'].format_map(
                    {'instruction': instruction, 'input': ''})
            else:
                _prompt = PROMPT_DICT['prompt_input'].format_map(
                    {'instruction': instruction, 'input': input})
            prompt = _prompt + prompt
            if choice == "Good":
                sumbit_data(save=save, prompt=prompt, vote=1, feedback=feedback, max_len=max_len, temp=temp, top_p=top_p)
            elif choice == "Bad":
                sumbit_data(save=save, prompt=prompt, vote=0, feedback=feedback, max_len=max_len, temp=temp, top_p=top_p)
            else:
                sumbit_data(save=save, prompt=prompt, vote=3, feedback=feedback, max_len=max_len, temp=temp, top_p=top_p)
            return {feedback_gen_box: gr.update(visible=False), feedback_gen_ok: gr.update(visible=True)}

        def gen(instruct: str, input: str = 'none', max_gen_len=512, temperature: float = 0.1, top_p: float = 0.75):
            _temp = instruct_generate(instruct, input, max_gen_len, temperature, top_p)
            return {outputs: _temp, feedback_gen_box: gr.update(visible=True), feedback_gen_ok: gr.update(visible=False)}

        feedback_gen_submit.click(fn=save_up2, inputs=[instruction, input, outputs, max_len, temp, top_p, gen_radio, feedback_gen], outputs=[feedback_gen_box, feedback_gen_ok], queue=True)
        inputs = [instruction, input, max_len, temp, top_p]
        run_botton.click(fn=gen, inputs=inputs, outputs=[outputs, feedback_gen_box, feedback_gen_ok])
        # Thai example prompts: "compose a Mother's Day poem", "compose a Mother's Day
        # poem in klon paet verse", "how do I lose weight?", and "write an essay on the
        # new generation's dreams for Thailand".
        examples = gr.Examples(examples=["แต่งกลอนวันแม่", "แต่งกลอนแปดวันแม่", 'อยากลดความอ้วนทำไง', 'จงแต่งเรียงความเรื่องความฝันของคนรุ่นใหม่ต่อประเทศไทย'], inputs=[instruction])
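    # Two chat tabs follow: one driven by the LangChain LLMChain above (windowed
    # memory of the last k=2 exchanges) and a "without LangChain" variant that
    # re-generates from the full flattened history each turn.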
    with gr.Tab("ChatBot"):
        with gr.Column():
            chatbot = gr.Chatbot(label="Chat Message Box", placeholder="Chat Message Box", show_label=False).style(container=False)
            with gr.Row():
                with gr.Column(scale=0.85):
                    # Placeholder (Thai): "Type your question here... (press enter or submit when done)"
                    msg = gr.Textbox(placeholder="พิมพ์คำถามของคุณที่นี่... (กด enter หรือ submit หลังพิมพ์เสร็จ)", show_label=False)
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit")
            with gr.Column():
                with gr.Column(visible=False) as feedback_chatbot_box:
                    chatbot_radio = gr.Radio(
                        ["Good", "Bad", "Report"], label="What do you think about the chat?"
                    )
                    feedback_chatbot = gr.Textbox(placeholder="Feedback chatbot", show_label=False, lines=4)
                    feedback_chatbot_submit = gr.Button("Submit Feedback")
                with gr.Row(visible=False) as feedback_chatbot_ok:
                    gr.Markdown("Thank you for the feedback.")
                clear = gr.Button("Clear")

        def save_up(history, choice, feedback):
            _bot = list2prompt(history)
            if choice == "Good":
                sumbit_data(save="chat", prompt=_bot, vote=1, feedback=feedback)
            elif choice == "Bad":
                sumbit_data(save="chat", prompt=_bot, vote=0, feedback=feedback)
            else:
                sumbit_data(save="chat", prompt=_bot, vote=3, feedback=feedback)
            return {feedback_chatbot_ok: gr.update(visible=True), feedback_chatbot_box: gr.update(visible=False)}

        def user(user_message, history):
            # Guard first; only safe messages reach the LLM chain.
            is_sensitive, respond_message = guardian.filter(user_message)
            if is_sensitive:
                bot_message = respond_message
            else:
                bot_message = chatgpt_chain.predict(human_input=user_message)
            history.append((user_message, bot_message))
            return "", history, gr.update(visible=True)

        def reset():
            # Clearing the chain's window memory also resets the visible chat.
            chatgpt_chain.memory.clear()

        feedback_chatbot_submit.click(fn=save_up, inputs=[chatbot, chatbot_radio, feedback_chatbot], outputs=[feedback_chatbot_ok, feedback_chatbot_box], queue=True)
        clear.click(reset, None, chatbot, queue=False)
        submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot, feedback_chatbot_box], queue=True)
        submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot, feedback_chatbot_box], queue=True)
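        # Both the enter key and the Submit button route through user(), so the
        # guardian check and memory update behave identically for either path.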
    with gr.Tab("ChatBot without LangChain"):
        chatbot2 = gr.Chatbot()
        msg2 = gr.Textbox(label="Your sentence here... (press enter to submit)")
        with gr.Column():
            with gr.Column(visible=False) as feedback_chatbot_box2:
                chatbot_radio2 = gr.Radio(
                    ["Good", "Bad", "Report"], label="What do you think about the chat?"
                )
                feedback_chatbot2 = gr.Textbox(placeholder="Feedback chatbot", show_label=False, lines=4)
                feedback_chatbot_submit2 = gr.Button("Submit Feedback")
            with gr.Row(visible=False) as feedback_chatbot_ok2:
                gr.Markdown("Thank you for the feedback.")

        def user2(user_message, history):
            # Append the user turn immediately; bot2 fills in the reply afterwards.
            return "", history + [[user_message, None]]

        def bot2(history):
            _bot = list2prompt(history)
            bot_message = gen_chatbot_old(_bot)
            history[-1][1] = bot_message
            return history, gr.update(visible=True)

        def save_up_old(history, choice, feedback):
            _bot = list2prompt(history)
            if choice == "Good":
                sumbit_data(save="chat", prompt=_bot, vote=1, feedback=feedback, name_model=name_model + "-chat_old")
            elif choice == "Bad":
                sumbit_data(save="chat", prompt=_bot, vote=0, feedback=feedback, name_model=name_model + "-chat_old")
            else:
                sumbit_data(save="chat", prompt=_bot, vote=3, feedback=feedback, name_model=name_model + "-chat_old")
            return {feedback_chatbot_ok2: gr.update(visible=True), feedback_chatbot_box2: gr.update(visible=False)}

        msg2.submit(user2, [msg2, chatbot2], [msg2, chatbot2], queue=False).then(bot2, chatbot2, [chatbot2, feedback_chatbot_box2])
        feedback_chatbot_submit2.click(fn=save_up_old, inputs=[chatbot2, chatbot_radio2, feedback_chatbot2], outputs=[feedback_chatbot_ok2, feedback_chatbot_box2], queue=True)
        clear2 = gr.Button("Clear")
        clear2.click(lambda: None, None, chatbot2, queue=False)

demo.queue()
demo.launch()