Commit f747801
Parent(s): 97bee26
implemented core generation + detection
Files changed:
- .gitattributes +2 -0
- Dockerfile +2 -2
- run.py +2 -2
- wm_interactive/core/detector.py +2 -2
- wm_interactive/core/generator.py +6 -2
- wm_interactive/core/hashing.py +1 -1
- wm_interactive/core/main.py +11 -14
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/0ad5ecc2035b7031b88afb544ee95e2d49baa484.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/36293b6099200eb8aeb55ae2c01bca2ba46d80d0.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/44719d2e365acac0637fd25a3acf46494ca45940.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/69503b13f727ba3812b6803e97442a6de05ef5eb.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/8c7b22013909450429303ed10be4398bd63f5457.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa.lock +0 -0
- wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/f922b1797f0c88e71addc8393787831f2477a4bd.lock +0 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/.no_exist/e2c3f7557efbdec707ae3a336371d169783f1da1/added_tokens.json +0 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/0ad5ecc2035b7031b88afb544ee95e2d49baa484 +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/36293b6099200eb8aeb55ae2c01bca2ba46d80d0 +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/44719d2e365acac0637fd25a3acf46494ca45940 +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/69503b13f727ba3812b6803e97442a6de05ef5eb +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/8c7b22013909450429303ed10be4398bd63f5457 +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/f922b1797f0c88e71addc8393787831f2477a4bd +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/refs/main +3 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/config.json +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/generation_config.json +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/merges.txt +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/model.safetensors +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/special_tokens_map.json +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer.json +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer_config.json +1 -0
- wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/vocab.json +1 -0
- wm_interactive/static/styles.css +14 -0
- wm_interactive/templates/index.html +132 -2
- wm_interactive/web/app.py +110 -12
- wm_interactive/web/utils.py +19 -1
.gitattributes
CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 static/ia_gen_droits_auteur.pdf filter=lfs diff=lfs merge=lfs -text
+wm_interactive/static/hf_cache/** filter=lfs diff=lfs merge=lfs -text
+wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c filter=lfs diff=lfs merge=lfs -text
Dockerfile
CHANGED
@@ -7,11 +7,11 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy the rest of the application
-COPY
+COPY wm_interactive/ ./wm_interactive/
 COPY run.py .
 
 # Create necessary directories
-RUN mkdir -p
+RUN mkdir -p wm_interactive/static/hf_cache
 
 # Set environment variables
 ENV PYTHONPATH=/app
run.py
CHANGED
@@ -2,8 +2,8 @@
 Main entry point for the watermark detection application.
 Run with: python run.py
 
-docker build -t wm-
-docker run -p 7860:7860 wm-
+docker build -t wm-interactive .
+docker run -p 7860:7860 wm-interactive
 """
 
 from wm_interactive.web.app import app
wm_interactive/core/detector.py
CHANGED
@@ -159,7 +159,7 @@ class MarylandDetector(WmDetector):
             tokenizer: AutoTokenizer,
             ngram: int = 1,
             seed: int = 0,
-            gamma: float = 0.
+            gamma: float = 0.5,
             delta: float = 1.0,
             **kwargs):
         super().__init__(tokenizer, ngram, seed, **kwargs)
@@ -194,7 +194,7 @@ class MarylandDetectorZ(WmDetector):
             tokenizer: AutoTokenizer,
             ngram: int = 1,
             seed: int = 0,
-            gamma: float = 0.
+            gamma: float = 0.5,
             delta: float = 1.0,
             **kwargs):
         super().__init__(tokenizer, ngram, seed, **kwargs)
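Note on the fix above: in a Maryland-style (greenlist) watermark, gamma is the fraction of the vocabulary marked as "green" for each position, and the detector scores how often generated tokens land in that fraction, so the previously truncated default of 0. becomes a proper 0.5. The snippet below is a standalone sketch of that idea only; it is not the repository's detector code, and the vocabulary size and seed are made up.

import torch

def greenlist_ids(seed: int, vocab_size: int, gamma: float = 0.5) -> torch.Tensor:
    # Deterministically permute the vocabulary from this position's seed and
    # keep the first gamma fraction as the greenlist.
    rng = torch.Generator().manual_seed(seed)
    perm = torch.randperm(vocab_size, generator=rng)
    return perm[: int(gamma * vocab_size)]

def score_token(token_id: int, seed: int, vocab_size: int, gamma: float = 0.5) -> int:
    # 1 if the token falls in this position's greenlist, else 0; a detector sums
    # these scores over the text and compares against the gamma * T expectation.
    return int(token_id in set(greenlist_ids(seed, vocab_size, gamma).tolist()))

print(score_token(token_id=42, seed=123, vocab_size=1000))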
wm_interactive/core/generator.py
CHANGED
@@ -59,13 +59,17 @@ class WmGenerator():
             next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
             tokens[0, cur_pos] = torch.where(input_text_mask[0, cur_pos], tokens[0, cur_pos], next_tok)
             prev_pos = cur_pos
+            if next_tok == self.eos_id:
+                break
 
         # cut to max gen len
         t = tokens[0, :prompt_size + max_gen_len].tolist()
         # cut to eos tok if any
         finish_reason = 'length'
         try:
-
+            find_eos = t[prompt_size:].index(self.eos_id)
+            if find_eos:
+                t = t[: prompt_size+find_eos]
             finish_reason = 'eos'
         except ValueError:
             pass
@@ -158,7 +162,7 @@ class MarylandGenerator(WmGenerator):
     """
     def __init__(self,
             *args,
-            gamma: float = 0.
+            gamma: float = 0.5,
             delta: float = 1.0,
             test_mul: float = 0,
             **kwargs
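The generator hunk does two things: it stops the sampling loop as soon as eos_id is produced, and it truncates the decoded token list at the first eos that appears after the prompt, leaving finish_reason as 'length' when no eos is found. A minimal standalone illustration of that truncation pattern, with made-up token ids:

def cut_at_eos(tokens: list[int], prompt_size: int, eos_id: int) -> tuple[list[int], str]:
    # Look for the first eos after the prompt; cut there and report 'eos',
    # otherwise fall through and report 'length'.
    finish_reason = 'length'
    try:
        find_eos = tokens[prompt_size:].index(eos_id)  # raises ValueError if absent
        if find_eos:
            tokens = tokens[: prompt_size + find_eos]
        finish_reason = 'eos'
    except ValueError:
        pass
    return tokens, finish_reason

# Example: prompt of 3 tokens, eos (id 2) generated two tokens later
print(cut_at_eos([5, 6, 7, 11, 12, 2, 0, 0], prompt_size=3, eos_id=2))
# -> ([5, 6, 7, 11, 12], 'eos')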
wm_interactive/core/hashing.py
CHANGED
@@ -10,4 +10,4 @@ def get_seed_rng(
     """
     for ii in input_ids:
         start = (start * salt + ii) % (2 ** 64 - 1)
-    return start
+    return int(start)
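The hashing change only casts the rolling hash to a built-in int before returning it. A standalone sketch of the scheme is below; the salt default shown is illustrative rather than the repository's value, and the cast matters when the result is handed to an RNG seeding API such as torch.Generator.manual_seed, which takes a plain integer.

def get_seed_rng(input_ids: list[int], start: int = 0, salt: int = 35317) -> int:
    # Rolling hash of the context tokens: the same ngram always yields the same seed.
    # (salt and start defaults here are illustrative, not necessarily the repo's.)
    for ii in input_ids:
        start = (start * salt + ii) % (2 ** 64 - 1)
    return int(start)  # plain Python int, safe to pass to an RNG seed

print(get_seed_rng([101, 2054, 2003]))  # same context -> same seed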
wm_interactive/core/main.py
CHANGED
@@ -28,12 +28,12 @@ model_names = {
 CACHE_DIR = "wm_interactive/static/hf_cache"
 
 
-def load_prompts(json_path: str, prompt_type: str = "
+def load_prompts(json_path: str, prompt_type: str = "smollm", nsamples: int = None) -> list[dict]:
     """Load prompts from a JSON file.
 
     Args:
         json_path: Path to the JSON file
-        prompt_type: Type of prompt dataset (alpaca)
+        prompt_type: Type of prompt dataset (alpaca, smollm)
         nsamples: Number of samples to load (if None, load all)
 
     Returns:
@@ -46,10 +46,13 @@ def load_prompts(json_path: str, prompt_type: str = "alpaca", nsamples: int = None
         data = json.load(f)
 
     if prompt_type == "alpaca":
-        prompts = [{"instruction":
-
-
-
+        prompts = [{"instruction": item["instruction"]} for item in data]
+    elif prompt_type == "smollm":
+        prompts = []
+        for item in data:
+            prompt = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
+            prompt += f"<|im_start|>user\n{item['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            prompts.append({"instruction": prompt})
     else:
         raise ValueError(f"Prompt type {prompt_type} not supported")
 
@@ -93,7 +96,7 @@ def get_args_parser():
     # prompts parameters
     parser.add_argument('--prompt_path', type=str, default=None,
                         help='Path to the prompt dataset. Required if --prompt is not provided')
-    parser.add_argument('--prompt_type', type=str, default="
+    parser.add_argument('--prompt_type', type=str, default="smollm",
                         help='Type of prompt dataset. Only used if --prompt_path is provided')
     parser.add_argument('--prompt', type=str, nargs='+', default=None,
                         help='List of prompts to use. If not provided, prompts will be loaded from --prompt_path')
@@ -148,17 +151,11 @@ def main(args):
     # Load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Using device: {device}")
 
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        device_map=device,
-        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
         cache_dir=CACHE_DIR
-    )
-    model = model.eval()
-    for param in model.parameters():
-        param.requires_grad = False
+    ).to(device)
 
     # build watermark generator
     if args.method == "none":
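The new smollm prompt type wraps each instruction in SmolLM2's chat markup before generation. Below is a hedged example of the JSON shape load_prompts appears to expect (a list of objects with an "instruction" field; the file name and contents are illustrative) together with the string the smollm branch builds from it:

import json

# Illustrative prompt file in the shape the alpaca/smollm branches read.
data = [{"instruction": "Write a haiku about rivers."}]
with open("prompts.json", "w") as f:
    json.dump(data, f)

# Equivalent of the smollm branch shown in the diff above.
prompts = []
for item in data:
    prompt = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
    prompt += f"<|im_start|>user\n{item['instruction']}<|im_end|>\n<|im_start|>assistant\n"
    prompts.append({"instruction": prompt})

print(prompts[0]["instruction"])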
wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/0ad5ecc2035b7031b88afb544ee95e2d49baa484.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/36293b6099200eb8aeb55ae2c01bca2ba46d80d0.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/44719d2e365acac0637fd25a3acf46494ca45940.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/69503b13f727ba3812b6803e97442a6de05ef5eb.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/8c7b22013909450429303ed10be4398bd63f5457.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa.lock
ADDED
File without changes

wm_interactive/static/hf_cache/.locks/models--HuggingFaceTB--SmolLM2-135M-Instruct/f922b1797f0c88e71addc8393787831f2477a4bd.lock
ADDED
File without changes

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/.no_exist/e2c3f7557efbdec707ae3a336371d169783f1da1/added_tokens.json
ADDED
File without changes
wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/0ad5ecc2035b7031b88afb544ee95e2d49baa484
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82b84012e3add4d01d12ba14442026e49b8cbbaead1f79ecf3d919784f82dc79
+size 800662

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/36293b6099200eb8aeb55ae2c01bca2ba46d80d0
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8eb740e8bbe4cff95ea7b4588d17a2432deb16e8075bc5828ff7ba9be94d982a
+size 861

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/44719d2e365acac0637fd25a3acf46494ca45940
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b7379f3ae813529281a5c602bc5a11c1d4e0a99107aaa597fe936c1e813ca52
+size 655

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c
+size 269060552

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/69503b13f727ba3812b6803e97442a6de05ef5eb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b54e8aa4e53d5383e2e4bc635a56b43f9647f7b13832d5d9ecd8f82dac4f510
+size 466391

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/8c7b22013909450429303ed10be4398bd63f5457
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ec77d44f62efeb38d7e044a1db318f6a939438425312dfa333b8382dbad98df
+size 3764

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87b916edaaab66b3899b9d0dd0752727dff6666686da0504d89ae0a6e055a013
+size 132

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/blobs/f922b1797f0c88e71addc8393787831f2477a4bd
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ca9acddb6525a194ec8ac7a87f24fbba7232a9a15ffa1af0c1224fcd888e47c
+size 2104556

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/refs/main
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71a184f20b0fe5c1a9407ed75fa9633b681779c7f1a5ca478f22fdff69a6c7ab
+size 40

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/config.json
ADDED
@@ -0,0 +1 @@
+../../blobs/36293b6099200eb8aeb55ae2c01bca2ba46d80d0

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/generation_config.json
ADDED
@@ -0,0 +1 @@
+../../blobs/da6c4d71a43aa7e6f785bdbb28ea5025438a73fa

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/merges.txt
ADDED
@@ -0,0 +1 @@
+../../blobs/69503b13f727ba3812b6803e97442a6de05ef5eb

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/model.safetensors
ADDED
@@ -0,0 +1 @@
+../../blobs/5af571cbf074e6d21a03528d2330792e532ca608f24ac70a143f6b369968ab8c

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+../../blobs/44719d2e365acac0637fd25a3acf46494ca45940

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer.json
ADDED
@@ -0,0 +1 @@
+../../blobs/f922b1797f0c88e71addc8393787831f2477a4bd

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+../../blobs/8c7b22013909450429303ed10be4398bd63f5457

wm_interactive/static/hf_cache/models--HuggingFaceTB--SmolLM2-135M-Instruct/snapshots/e2c3f7557efbdec707ae3a336371d169783f1da1/vocab.json
ADDED
@@ -0,0 +1 @@
+../../blobs/0ad5ecc2035b7031b88afb544ee95e2d49baa484
wm_interactive/static/styles.css
CHANGED
@@ -29,9 +29,23 @@ h1 {
     resize: none;
     font-size: 14px;
     line-height: 1.5;
+    margin-bottom: 10px;
+}
+
+.input-section #prompt_text {
+    height: 100px;
+}
+
+.input-section #user_text {
     height: 200px;
 }
 
+.button-container {
+    display: flex;
+    gap: 10px;
+    margin-bottom: 10px;
+}
+
 .token-display {
     margin: 20px 0;
     padding: 10px;
wm_interactive/templates/index.html
CHANGED
@@ -56,8 +56,14 @@
 
         <!-- Input Form -->
         <div class="input-section">
+            <textarea id="prompt_text"
+                placeholder="Enter your prompt here to generate text with the model..."></textarea>
+            <div class="button-container">
+                <button class="btn btn-primary" id="generateBtn">Generate</button>
+                <button class="btn btn-secondary" id="stopBtn" disabled>Stop</button>
+            </div>
             <textarea id="user_text"
-                placeholder="
+                placeholder="Generated text will appear here. Replace or edit this text to see how watermark detection works."></textarea>
         </div>
 
         <!-- Token Display -->
@@ -87,7 +93,11 @@
     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
     <script>
         let debounceTimeout = null;
+        let eventSource = null;
         const textarea = document.getElementById('user_text');
+        const promptArea = document.getElementById('prompt_text');
+        const generateBtn = document.getElementById('generateBtn');
+        const stopBtn = document.getElementById('stopBtn');
         const tokenDisplay = document.getElementById('tokenDisplay');
         const tokenCount = document.getElementById('tokenCount');
         const scoredTokens = document.getElementById('scoredTokens');
@@ -98,6 +108,122 @@
         const ngramInput = document.getElementById('ngram');
         const detectorTypeSelect = document.getElementById('detectorType');
 
+        function startGeneration() {
+            const prompt = promptArea.value.trim();
+            if (!prompt) {
+                alert('Please enter a prompt first.');
+                return;
+            }
+
+            generateBtn.disabled = true;
+            stopBtn.disabled = false;
+            textarea.value = '';
+
+            // Get current parameters
+            const params = {
+                detector_type: detectorTypeSelect.value,
+                seed: parseInt(seedInput.value) || 0,
+                ngram: parseInt(ngramInput.value) || 1
+            };
+
+            // Create headers for SSE
+            const headers = new Headers({
+                'Content-Type': 'application/json',
+                'Accept': 'text/event-stream',
+            });
+
+            // Start fetch request
+            fetch('/generate', {
+                method: 'POST',
+                headers: headers,
+                body: JSON.stringify({
+                    prompt: prompt,
+                    params: params
+                })
+            }).then(response => {
+                const reader = response.body.getReader();
+                const decoder = new TextDecoder();
+                let buffer = '';
+
+                function processText(text) {
+                    const lines = text.split('\n');
+
+                    for (const line of lines) {
+                        if (line.startsWith('data: ')) {
+                            try {
+                                const data = JSON.parse(line.slice(6));
+
+                                if (data.error) {
+                                    alert('Error: ' + data.error);
+                                    stopGeneration();
+                                    return;
+                                }
+
+                                if (data.token) {
+                                    // Append new token to existing text
+                                    textarea.value += data.token;
+                                    updateTokenization();
+                                }
+
+                                if (data.text) {
+                                    // Final text (only used if something went wrong with streaming)
+                                    textarea.value = data.text;
+                                    updateTokenization();
+                                }
+
+                                if (data.done) {
+                                    stopGeneration();
+                                }
+                            } catch (e) {
+                                console.error('Error parsing SSE data:', e);
+                            }
+                        }
+                    }
+                }
+
+                function pump() {
+                    return reader.read().then(({value, done}) => {
+                        if (done) {
+                            if (buffer.length > 0) {
+                                processText(buffer);
+                            }
+                            return;
+                        }
+
+                        buffer += decoder.decode(value, {stream: true});
+                        const lines = buffer.split('\n\n');
+                        buffer = lines.pop();
+
+                        for (const line of lines) {
+                            processText(line);
+                        }
+
+                        return pump();
+                    });
+                }
+
+                return pump();
+            })
+            .catch(error => {
+                console.error('Error:', error);
+                alert('Error: Failed to generate text');
+            })
+            .finally(() => {
+                generateBtn.disabled = false;
+                stopBtn.disabled = true;
+            });
+        }
+
+        function stopGeneration() {
+            generateBtn.disabled = false;
+            stopBtn.disabled = true;
+        }
+
+        // Add event listeners for generation buttons
+        generateBtn.addEventListener('click', startGeneration);
+        stopBtn.addEventListener('click', stopGeneration);
+
+        // Rest of the existing JavaScript code...
         async function updateTokenization() {
             const text = textarea.value;
             try {
@@ -210,7 +336,11 @@
         document.addEventListener('keydown', function(e) {
             if ((e.metaKey || e.ctrlKey) && e.key === 'Enter') {
                 e.preventDefault();
-
+                if (document.activeElement === promptArea) {
+                    generateBtn.click();
+                } else {
+                    applyParamsBtn.click();
+                }
             }
         });
 
wm_interactive/web/app.py
CHANGED
@@ -2,11 +2,14 @@
 Main Flask application for the watermark detection web interface.
 """
 
-from flask import Flask, render_template, request, jsonify
+from flask import Flask, render_template, request, jsonify, Response, stream_with_context
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+import json
 
 from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ
-from .
+from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator
+from .utils import get_token_details, template_prompt
 
 CACHE_DIR = "wm_interactive/static/hf_cache"
 
@@ -21,6 +24,12 @@ def convert_nan_to_null(obj):
         return [convert_nan_to_null(item) for item in obj]
     return obj
 
+def set_to_int(value, default_value = None):
+    try:
+        return int(value)
+    except (ValueError, TypeError):
+        return default_value
+
 def create_detector(detector_type, tokenizer, **kwargs):
     """Create a detector instance based on the specified type."""
     detector_map = {
@@ -32,16 +41,9 @@ def create_detector(detector_type, tokenizer, **kwargs):
 
     # Validate and set default values for parameters
     if 'seed' in kwargs:
-
-        kwargs['seed'] = int(kwargs['seed'])
-        except (ValueError, TypeError):
-            kwargs['seed'] = 0
-
+        kwargs['seed'] = set_to_int(kwargs['seed'], default_value = 0)
     if 'ngram' in kwargs:
-
-        kwargs['ngram'] = int(kwargs['ngram'])
-        except (ValueError, TypeError):
-            kwargs['ngram'] = 1
+        kwargs['ngram'] = set_to_int(kwargs['ngram'], default_value = 1)
 
     detector_class = detector_map.get(detector_type, MarylandDetector)
     return detector_class(tokenizer=tokenizer, **kwargs)
@@ -58,7 +60,10 @@ def create_app():
     # model_id = "meta-llama/Llama-3.2-1B-Instruct"
     model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
-    model = AutoModelForCausalLM.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Create default generator
+    generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0)
 
     @app.route("/", methods=["GET"])
     def index():
@@ -132,6 +137,99 @@
         app.logger.error(f'Server error: {str(e)}')
         return jsonify({'error': f'Server error: {str(e)}'}), 500
 
+    @app.route("/generate", methods=["POST"])
+    def generate():
+        try:
+            data = request.get_json()
+            if not data:
+                return jsonify({'error': 'No JSON data received'}), 400
+
+            prompt = template_prompt(data.get('prompt', ''))
+            params = data.get('params', {})
+
+            def generate_stream():
+                try:
+                    # Create generator with correct parameters
+                    generator_class = OpenaiGenerator if params.get('detector_type') == 'openai' else MarylandGenerator
+                    generator = generator_class(
+                        model=model,
+                        tokenizer=tokenizer,
+                        ngram=set_to_int(params.get('ngram', 1)),
+                        seed=set_to_int(params.get('seed', 0))
+                    )
+
+                    # Get special tokens to filter out
+                    special_tokens = {
+                        '<|im_start|>', '<|im_end|>',
+                        tokenizer.pad_token, tokenizer.eos_token,
+                        tokenizer.bos_token if hasattr(tokenizer, 'bos_token') else None,
+                        tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else None
+                    }
+                    special_tokens = {t for t in special_tokens if t is not None}
+
+                    # Encode prompt
+                    prompt_tokens = tokenizer.encode(prompt)
+                    prompt_size = len(prompt_tokens)
+                    max_gen_len = 100
+                    total_len = min(getattr(model.config, 'max_position_embeddings', 2048), max_gen_len + prompt_size)
+
+                    # Initialize generation
+                    tokens = torch.full((1, total_len), model.config.pad_token_id).to(model.device).long()
+                    tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()
+                    input_text_mask = tokens != model.config.pad_token_id
+
+                    # Generate token by token
+                    prev_pos = 0
+                    outputs = None  # Initialize outputs to None
+                    for cur_pos in range(prompt_size, total_len):
+                        # Get model outputs
+                        outputs = model.forward(
+                            tokens[:, prev_pos:cur_pos],
+                            use_cache=True,
+                            past_key_values=outputs.past_key_values if prev_pos > 0 else None
+                        )
+
+                        # Sample next token using the generator's sampling method
+                        aux = {
+                            'ngram_tokens': tokens[:, cur_pos-generator.ngram:cur_pos],
+                            'cur_pos': cur_pos,
+                        }
+                        next_token = generator.sample_next(
+                            outputs.logits[:, -1, :],
+                            aux,
+                            temperature=0.8,
+                            top_p=0.95
+                        )
+                        # Check for EOS token
+                        if next_token == model.config.eos_token_id:
+                            break
+
+                        # Decode and check if it's a special token
+                        new_text = tokenizer.decode([next_token])
+                        if new_text not in special_tokens and not any(st in new_text for st in special_tokens):
+                            yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n"
+
+                        # Update token and position
+                        tokens[0, cur_pos] = next_token
+                        prev_pos = cur_pos
+
+                    # Send final complete text, filtering out special tokens
+                    final_tokens = tokens[0, prompt_size:cur_pos+1].tolist()
+                    final_text = tokenizer.decode(final_tokens)
+                    for st in special_tokens:
+                        final_text = final_text.replace(st, '')
+                    yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n"
+
+                except Exception as e:
+                    app.logger.error(f'Error generating text: {str(e)}')
+                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
+
+            return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
+
+        except Exception as e:
+            app.logger.error(f'Server error: {str(e)}')
+            return jsonify({'error': f'Server error: {str(e)}'}), 500
+
     return app
 
 app = create_app()
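The new /generate route streams server-sent events: one "data: {...}" JSON payload per generated token, an "error" field on failure, and a final payload carrying "done": true plus the full "text" as a fallback. A minimal Python client sketch, assuming the app is running locally on port 7860 as in run.py's docker example; the detector_type value shown is illustrative:

import json
import requests

# Stream tokens from the /generate endpoint and print them as they arrive.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"prompt": "Tell me a short story.",
          "params": {"detector_type": "maryland", "seed": 0, "ngram": 1}},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        event = json.loads(line[len("data: "):])
        if event.get("error"):
            raise RuntimeError(event["error"])
        if event.get("token"):
            print(event["token"], end="", flush=True)
        if event.get("done"):
            break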
wm_interactive/web/utils.py
CHANGED
@@ -1,4 +1,3 @@
-
 import random
 import numpy as np
 
@@ -63,3 +62,22 @@ def get_token_details(
         })
 
     return display_info
+
+def template_prompt(instruction: str, prompt_type: str = "smollm") -> str:
+    """Template a prompt according to the model's format.
+
+    Args:
+        instruction: The raw prompt/instruction to template
+        prompt_type: Type of prompt format (smollm, alpaca)
+
+    Returns:
+        The formatted prompt ready for the model
+    """
+    if prompt_type == "alpaca":
+        return instruction
+    elif prompt_type == "smollm":
+        prompt = "<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n"
+        prompt += f"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+        return prompt
+    else:
+        raise ValueError(f"Prompt type {prompt_type} not supported")
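A short usage sketch for the new template_prompt helper, showing the default smollm formatting:

# Wrap a raw instruction in the SmolLM2 chat markup.
from wm_interactive.web.utils import template_prompt

print(template_prompt("What is a watermark?"))
# <|im_start|>system
# You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
# <|im_start|>user
# What is a watermark?<|im_end|>
# <|im_start|>assistant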