Spaces:
Build error
Build error
chore: add bad words
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from transformers import (
|
|
7 |
AutoModelForPreTraining,
|
8 |
AutoProcessor,
|
9 |
AutoConfig,
|
|
|
10 |
)
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
from safetensors.torch import load_file
|
@@ -18,10 +19,16 @@ assert MODEL_NAME is not None
|
|
18 |
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
|
19 |
DEVICE = torch.device("cuda")
|
20 |
|
|
|
21 |
|
22 |
def fix_compiled_state_dict(state_dict: dict):
    """Return a new dict whose keys have every "._orig_mod." collapsed to ".".

    Values are carried over untouched. If two keys become identical after the
    replacement, the one that appears later in ``state_dict`` wins.
    """
    # NOTE(review): "._orig_mod." is presumably the infix left in parameter
    # names by a compiled-module wrapper -- confirm against the checkpoint.
    renamed = {}
    for key, value in state_dict.items():
        renamed[key.replace("._orig_mod.", ".")] = value
    return renamed
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def prepare_models():
|
27 |
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
@@ -44,6 +51,7 @@ def prepare_models():
|
|
44 |
|
45 |
def demo():
|
46 |
model, processor = prepare_models()
|
|
|
47 |
|
48 |
@spaces.GPU(duration=5)
|
49 |
@torch.inference_mode()
|
@@ -83,6 +91,7 @@ def demo():
|
|
83 |
top_p=top_p,
|
84 |
eos_token_id=processor.decoder_tokenizer.eos_token_id,
|
85 |
pad_token_id=processor.decoder_tokenizer.pad_token_id,
|
|
|
86 |
)
|
87 |
elapsed = time.time() - start_time
|
88 |
|
|
|
7 |
AutoModelForPreTraining,
|
8 |
AutoProcessor,
|
9 |
AutoConfig,
|
10 |
+
PreTrainedTokenizerFast
|
11 |
)
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
from safetensors.torch import load_file
|
|
|
19 |
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
|
20 |
DEVICE = torch.device("cuda")
|
21 |
|
22 |
+
# Substrings to look for in vocabulary tokens: get_bad_words_ids() collects the
# id of every token containing one of these, and demo() passes that list to
# generation as `bad_words_ids`.
BAD_WORD_KEYWORDS = ["(medium)"]
|
23 |
|
24 |
def fix_compiled_state_dict(state_dict: dict):
    """Return a new dict whose keys have every "._orig_mod." collapsed to ".".

    Values are carried over untouched. If two keys become identical after the
    replacement, the one that appears later in ``state_dict`` wins.
    """
    # NOTE(review): "._orig_mod." is presumably the infix left in parameter
    # names by a compiled-module wrapper -- confirm against the checkpoint.
    renamed = {}
    for key, value in state_dict.items():
        renamed[key.replace("._orig_mod.", ".")] = value
    return renamed
|
26 |
|
27 |
+
def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
    """Collect the ids of vocabulary tokens that contain a banned keyword.

    Args:
        tokenizer: Tokenizer whose ``vocab`` mapping (token string -> token id)
            is scanned.

    Returns:
        A list of single-element lists ``[[token_id], ...]``, one per token
        whose text contains any entry of ``BAD_WORD_KEYWORDS`` as a substring.
        The list is empty when no token matches.
    """
    # The original generator read `for BAD_WORD_KEYWORDS` with no loop target,
    # which is a SyntaxError; it must iterate `word in BAD_WORD_KEYWORDS`.
    # `token_id` also replaces `id`, which shadowed the builtin.
    # NOTE(review): if no token matches, the caller passes an empty list as
    # `bad_words_ids`; confirm the generation API accepts that.
    return [
        [token_id]
        for token, token_id in tokenizer.vocab.items()
        if any(word in token for word in BAD_WORD_KEYWORDS)
    ]
|
32 |
|
33 |
def prepare_models():
|
34 |
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
|
51 |
|
52 |
def demo():
|
53 |
model, processor = prepare_models()
|
54 |
+
ban_ids = get_bad_words_ids(processor.decoder_tokenizer)
|
55 |
|
56 |
@spaces.GPU(duration=5)
|
57 |
@torch.inference_mode()
|
|
|
91 |
top_p=top_p,
|
92 |
eos_token_id=processor.decoder_tokenizer.eos_token_id,
|
93 |
pad_token_id=processor.decoder_tokenizer.pad_token_id,
|
94 |
+
bad_words_ids=ban_ids,
|
95 |
)
|
96 |
elapsed = time.time() - start_time
|
97 |
|