wannaphong committed
Commit 2816564 · 1 Parent(s): d7b794d
Files changed (1): app.py (+2 -101)
app.py CHANGED
@@ -35,104 +35,15 @@ import tensorflow_text
 from dataclasses import dataclass
 
 import numpy as np
-import tensorflow as tf
 
 
-class Encoder(ABC):
-    @abstractmethod
-    def encode(self, texts: List[str]) -> np.array:
-        """
-        output dimension expected to be one dimension and normalized (unit vector)
-        """
-        ...
-
-
-class MUSEEncoder(Encoder):
-    def __init__(self, model_url: str = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"):
-        self.embed = hub.load(model_url)
-
-    def encode(self, texts: List[str]) -> np.array:
-        embeds = self.embed(texts).numpy()
-        embeds = embeds / np.linalg.norm(embeds, axis=1).reshape(embeds.shape[0], -1)
-        return embeds
-
-
-
-
-
-@dataclass
-class SensitiveTopic:
-    name: str
-    respond_message: str
-    sensitivity: float = None  # range from 0 to 1
-    demonstrations: List[str] = None
-    adhoc_embeded_demonstrations: np.array = None  # dimension = [N_ADHOC, DIM]. Please kindly note that this suppose to
-
-
-DEFAULT_SENSITIVITY = 0.4
-
-
-class SensitiveTopicProtector:
-    def __init__(
-        self,
-        sensitive_topics: List[SensitiveTopic],
-        encoder: Encoder = MUSEEncoder(),
-        default_sensitivity: float = DEFAULT_SENSITIVITY
-    ):
-        self.sensitive_topics = sensitive_topics
-        self.default_sensitivity = default_sensitivity
-        self.encoder = encoder
-        self.topic_embeddings = self._get_topic_embeddings()
-
-    def _get_topic_embeddings(self) -> Dict[str, List[np.array]]:
-        topic_embeddings = {}
-        for topic in self.sensitive_topics:
-            current_topic_embeddings = None
-            if topic.demonstrations is not None:
-                current_topic_embeddings = self.encoder.encode(texts=topic.demonstrations) if current_topic_embeddings is None \
-                    else np.concatenate((current_topic_embeddings, self.encoder.encode(texts=topic.demonstrations)), axis=0)
-            if topic.adhoc_embeded_demonstrations is not None:
-                current_topic_embeddings = topic.adhoc_embeded_demonstrations if current_topic_embeddings is None \
-                    else np.concatenate((current_topic_embeddings, topic.adhoc_embeded_demonstrations), axis=0)
-            topic_embeddings[topic.name] = current_topic_embeddings
-        return topic_embeddings
-
-    def filter(self, text: str) -> Tuple[bool, str]:
-        is_sensitive, respond_message = False, None
-        text_embedding = self.encoder.encode([text,])
-        for topic in self.sensitive_topics:
-            risk_scores = np.einsum('ik,jk->j', text_embedding, self.topic_embeddings[topic.name])
-            max_risk_score = np.max(risk_scores)
-            if topic.sensitivity:
-                if max_risk_score > (1.0 - topic.sensitivity):
-                    return True, topic.respond_message
-                continue
-            if max_risk_score > (1.0 - self.default_sensitivity):
-                return True, topic.respond_message
-        return is_sensitive, respond_message
-
-    @classmethod
-    def fromRaw(cls, raw_sensitive_topics: List[Dict], encoder: Encoder = MUSEEncoder(), default_sensitivity: float = DEFAULT_SENSITIVITY):
-        sensitive_topics = [SensitiveTopic(**topic) for topic in raw_sensitive_topics]
-        return cls(sensitive_topics=sensitive_topics, encoder=encoder, default_sensitivity=default_sensitivity)
-
-
-
-f = open("sensitive_topics.pkl", "rb")
-sensitive_topics = pickle.load(f)
-f.close()
-
-guardian = SensitiveTopicProtector.fromRaw(sensitive_topics)
-
-
-name_model = "pythainlp/wangchanglm-7.5B-sft-en-8bit-sharded"
+name_model = "pythainlp/wangchanglm-7.5B-sft-en-sharded"
 model = AutoModelForCausalLM.from_pretrained(
     name_model,
     device_map="auto",
     torch_dtype=torch.bfloat16,
     offload_folder="./",
     low_cpu_mem_usage=True,
-    load_in_8bit=False
 )
 tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-7.5B")
 
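Note: the block removed above is an embedding-similarity guard. MUSEEncoder returns unit-normalized sentence embeddings, so the einsum in filter() yields one cosine similarity per demonstration, and a topic fires when the maximum similarity exceeds 1 - sensitivity (0.6 under the default sensitivity of 0.4). A minimal sketch of that scoring step, with made-up 2-D unit vectors standing in for real sentence embeddings:

import numpy as np

# Hypothetical stand-ins for unit-normalized MUSE embeddings.
text_embedding = np.array([[0.6, 0.8]])               # shape [1, DIM]
demo_embeddings = np.array([[1.0, 0.0],
                            [0.0, 1.0]])              # shape [N_DEMOS, DIM]

# 'ik,jk->j' contracts the embedding axis: one dot product (= cosine
# similarity for unit vectors) per demonstration row.
risk_scores = np.einsum('ik,jk->j', text_embedding, demo_embeddings)
max_risk_score = np.max(risk_scores)                  # 0.8

# A topic fires when similarity exceeds 1 - sensitivity: 0.8 > 0.6 here.
print(max_risk_score > (1.0 - 0.4))                   # True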
@@ -238,9 +149,6 @@ def gen_instruct(text,max_new_tokens=512,top_p=0.95,temperature=0.9,top_k=50):
     return tokenizer.decode(output_tokens[0][len(batch["input_ids"][0]):], skip_special_tokens=True)
 
 def gen_chatbot_old(text):
-    is_sensitive, respond_message = guardian.filter(text)
-    if is_sensitive:
-        return respond_message
 
     batch = tokenizer(text, return_tensors="pt")
     #context_tokens = tokenizer(text, add_special_tokens=False)['input_ids']
@@ -278,9 +186,6 @@ def instruct_generate(
     temperature: float = 0.1,
     top_p: float = 0.75,
 ):
-    is_sensitive, respond_message = guardian.filter(instruct)
-    if is_sensitive:
-        return respond_message
 
     if input == 'none' or len(input)<2:
         prompt = PROMPT_DICT['prompt_no_input'].format_map(
@@ -391,11 +296,7 @@ with gr.Blocks(height=900) as demo:
         x=sumbit_data(save="chat",prompt=_bot,vote=3,feedback=feedback)
         return {feedback_chatbot_ok: gr.update(visible=True),feedback_chatbot_box: gr.update(visible=False)}
     def user(user_message, history):
-        is_sensitive, respond_message = guardian.filter(user_message)
-        if is_sensitive:
-            bot_message = respond_message
-        else:
-            bot_message = chatgpt_chain.predict(human_input=user_message)
+        bot_message = chatgpt_chain.predict(human_input=user_message)
         history.append((user_message, bot_message))
         return "", history,gr.update(visible=True)
     def reset():
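For reference, the deleted guard was driven by plain dicts; a hypothetical usage sketch (topic name, reply, and demonstration invented for illustration; the app's real topics came from sensitive_topics.pkl), assuming the removed classes above are still defined:

# Hypothetical topic definition; every string below is invented.
raw_topics = [
    {
        "name": "violence",
        "respond_message": "Sorry, I can't discuss that topic.",
        "sensitivity": 0.5,  # fires when max similarity > 1 - 0.5 = 0.5
        "demonstrations": ["an example of a request to refuse"],
    },
]

guardian = SensitiveTopicProtector.fromRaw(raw_topics)
is_sensitive, respond_message = guardian.filter("hello world")
# -> (False, None) when no topic matches; (True, respond_message) otherwise.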