p1atdev commited on
Commit
56adc5c
·
1 Parent(s): ff91c77

chore: add bad words

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import (
7
  AutoModelForPreTraining,
8
  AutoProcessor,
9
  AutoConfig,
 
10
  )
11
  from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file
@@ -18,10 +19,16 @@ assert MODEL_NAME is not None
18
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
19
  DEVICE = torch.device("cuda")
20
 
 
21
 
22
  def fix_compiled_state_dict(state_dict: dict):
23
  return {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}
24
 
 
 
 
 
 
25
 
26
  def prepare_models():
27
  config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -44,6 +51,7 @@ def prepare_models():
44
 
45
  def demo():
46
  model, processor = prepare_models()
 
47
 
48
  @spaces.GPU(duration=5)
49
  @torch.inference_mode()
@@ -83,6 +91,7 @@ def demo():
83
  top_p=top_p,
84
  eos_token_id=processor.decoder_tokenizer.eos_token_id,
85
  pad_token_id=processor.decoder_tokenizer.pad_token_id,
 
86
  )
87
  elapsed = time.time() - start_time
88
 
 
7
  AutoModelForPreTraining,
8
  AutoProcessor,
9
  AutoConfig,
10
+ PreTrainedTokenizerFast,
11
  )
12
  from huggingface_hub import hf_hub_download
13
  from safetensors.torch import load_file
 
19
  MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
20
  DEVICE = torch.device("cuda")
21
 
22
+ BAD_WORD_KEYWORDS = ["(medium)"]
23
 
24
  def fix_compiled_state_dict(state_dict: dict):
25
  return {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}
26
 
27
+ def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
28
+ ids = [
29
+ [id] for token, id in tokenizer.vocab.items() if any(word in token for word in BAD_WORD_KEYWORDS)
30
+ ]
31
+ return ids
32
 
33
  def prepare_models():
34
  config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
51
 
52
  def demo():
53
  model, processor = prepare_models()
54
+ ban_ids = get_bad_words_ids(processor.decoder_tokenizer)
55
 
56
  @spaces.GPU(duration=5)
57
  @torch.inference_mode()
 
91
  top_p=top_p,
92
  eos_token_id=processor.decoder_tokenizer.eos_token_id,
93
  pad_token_id=processor.decoder_tokenizer.pad_token_id,
94
+ bad_words_ids=ban_ids,
95
  )
96
  elapsed = time.time() - start_time
97