Commit 7f97da4
Parent(s): f747801
implemented detection

Files changed:
- wm_interactive/core/detector.py +9 -18
- wm_interactive/core/generator.py +10 -7
- wm_interactive/core/main.py +3 -3
- wm_interactive/static/styles.css +34 -1
- wm_interactive/templates/index.html +91 -13
- wm_interactive/web/app.py +7 -4
wm_interactive/core/detector.py CHANGED

@@ -93,7 +93,8 @@ class WmDetector():
 
             score = float('nan')
             if is_scored:
-                score = self.score_tok(ngram_tokens, tokens_id[cur_pos])
+                score = self.score_tok(ngram_tokens, tokens_id[cur_pos])
+                score = float(score)
 
             token_details.append({
                 'token_id': tokens_id[cur_pos],
@@ -168,12 +169,7 @@ class MarylandDetector(WmDetector):
 
     def score_tok(self, ngram_tokens, token_id):
         """
-        score_t = 1 if token_id in greenlist else 0
-        The last line shifts the scores by token_id.
-        ex: scores[0] = 1 if token_id in greenlist else 0
-            scores[1] = 1 if token_id in (greenlist shifted of 1) else 0
-            ...
-        The score for each payload will be given by scores[payload]
+        score_t = 1 if token_id in greenlist else 0
         """
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
@@ -181,7 +177,7 @@ class MarylandDetector(WmDetector):
         vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
         greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n toks in the greenlist
         scores[greenlist] = 1
-        return scores
+        return scores[token_id]
 
     def get_pvalue(self, score: int, ntoks: int, eps: float):
         """ from cdf of a binomial distribution """
@@ -209,7 +205,7 @@ class MarylandDetectorZ(WmDetector):
         vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
         greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
         scores[greenlist] = 1
-        return scores
+        return scores[token_id]
 
     def get_pvalue(self, score: int, ntoks: int, eps: float):
         """ from cdf of a normal distribution """
@@ -229,17 +225,12 @@ class OpenaiDetector(WmDetector):
     def score_tok(self, ngram_tokens, token_id):
         """
         score_t = -log(1 - rt[token_id]])
-        The last line shifts the scores by token_id.
-        ex: scores[0] = r_t[token_id]
-            scores[1] = (r_t shifted of 1)[token_id]
-            ...
-        The score for each payload will be given by scores[payload]
         """
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
-        scores = -(1 - rs).log()
-        return scores
+        scores = -(1 - rs).log()
+        return scores[token_id]
 
     def get_pvalue(self, score: float, ntoks: int, eps: float):
         """ from cdf of a gamma distribution """
@@ -260,8 +251,8 @@ class OpenaiDetectorZ(WmDetector):
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
-        scores = -(1 - rs).log()
-        return scores
+        scores = -(1 - rs).log()
+        return scores[token_id]
 
     def get_pvalue(self, score: float, ntoks: int, eps: float):
         """ from cdf of a normal distribution """
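The net effect of the detector change: score_tok now returns the score of the observed token only (scores[token_id]) rather than the full per-vocabulary score vector, the multi-payload docstring text is dropped, and the caller casts the result to a plain float. A minimal sketch of the resulting per-token logic for the Maryland scheme, assuming a stand-in hash in place of get_seed_rng and scipy for the binomial tail (illustrative, not the repo's exact code):

```python
# Minimal sketch of the per-token Maryland scoring after this commit:
# score_tok returns a single 0/1 score for the observed token, and the final
# p-value is a binomial tail over the scored tokens.
from scipy.stats import binom
import torch

def maryland_score_tok(seed: int, ngram_tokens: list[int], token_id: int,
                       vocab_size: int = 32000, gamma: float = 0.5) -> int:
    """Return 1 if token_id falls in the seeded greenlist, else 0 (assumed helper)."""
    rng = torch.Generator()
    # stand-in for get_seed_rng: any deterministic hash of (seed, ngram_tokens) works here
    rng.manual_seed(hash((seed, tuple(ngram_tokens))) % (2**63))
    vocab_permutation = torch.randperm(vocab_size, generator=rng)
    greenlist = vocab_permutation[: int(gamma * vocab_size)]
    return int(token_id in set(greenlist.tolist()))

def maryland_pvalue(total_score: int, ntoks: int, gamma: float = 0.5) -> float:
    """P(Binomial(ntoks, gamma) >= total_score), i.e. the binomial survival function."""
    return float(binom.sf(total_score - 1, ntoks, gamma))

# toy usage: 60 greenlist hits out of 80 scored tokens is very unlikely under H0
print(maryland_pvalue(60, 80, gamma=0.5))
```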
wm_interactive/core/generator.py CHANGED

@@ -11,7 +11,8 @@ class WmGenerator():
         model: AutoModelForCausalLM,
         tokenizer: AutoTokenizer,
         ngram: int = 1,
-        seed: int = 0
+        seed: int = 0,
+        **kwargs
     ):
         # model config
         self.tokenizer = tokenizer
@@ -49,11 +50,15 @@ class WmGenerator():
         start_pos = prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
+            past_key_values = outputs.past_key_values if prev_pos > 0 else None
             outputs = self.model.forward(
-                tokens[:, prev_pos:cur_pos],
+                tokens[:, prev_pos:cur_pos],
+                use_cache=True,
+                past_key_values=past_key_values
             )
+            ngram_tokens = tokens[0, cur_pos-self.ngram:cur_pos].tolist()
             aux = {
-                'ngram_tokens':
+                'ngram_tokens': ngram_tokens,
                 'cur_pos': cur_pos,
             }
             next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
@@ -135,7 +140,7 @@ class OpenaiGenerator(WmGenerator):
         probs_sort[mask] = 0.0
         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
         # seed with hash of ngram tokens
-        seed = get_seed_rng(self.seed, ngram_tokens
+        seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         # generate rs randomly between [0,1]
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
@@ -164,13 +169,11 @@ class MarylandGenerator(WmGenerator):
         *args,
         gamma: float = 0.5,
         delta: float = 1.0,
-        test_mul: float = 0,
         **kwargs
     ):
         super().__init__(*args, **kwargs)
         self.gamma = gamma
         self.delta = delta
-        self.test_mul = test_mul
 
     def sample_next(
         self,
@@ -198,7 +201,7 @@ class MarylandGenerator(WmGenerator):
     def logits_processor(self, logits, ngram_tokens):
         """Process logits to mask out words in greenlist."""
         logits = logits.clone()
-        seed = get_seed_rng(self.seed, ngram_tokens
+        seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
         greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
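The generation loop now reuses the key/value cache between steps (feeding only the new tokens to model.forward) and derives ngram_tokens from the last ngram tokens of the running sequence before sampling. A rough sketch of that pattern with Hugging Face transformers, with greedy argmax standing in for the watermark-aware sample_next (checkpoint name and variable names are illustrative):

```python
# Sketch of the incremental decoding pattern introduced here: reuse past_key_values
# so each step only feeds the new token, and collect the last `ngram` tokens of the
# context to seed the watermark RNG.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-135M"  # assumed checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

ngram = 1
tokens = tokenizer("The watermark", return_tensors="pt").input_ids
outputs, prev_pos = None, 0

for cur_pos in range(tokens.shape[1], tokens.shape[1] + 20):
    past = outputs.past_key_values if prev_pos > 0 else None
    outputs = model(tokens[:, prev_pos:cur_pos], use_cache=True, past_key_values=past)
    ngram_tokens = tokens[0, cur_pos - ngram:cur_pos].tolist()  # context for the RNG seed
    # a real WmGenerator would bias/sample outputs.logits[:, -1, :] here;
    # greedy argmax keeps the sketch short
    next_tok = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=-1)
    prev_pos = cur_pos

print(tokenizer.decode(tokens[0]))
```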
wm_interactive/core/main.py CHANGED

@@ -1,7 +1,7 @@
 """
 Main script for watermark detection.
 Test with:
-    python -m wm_interactive.core.main --model_name smollm2-135m --prompt_path data/prompts.json
+    python -m wm_interactive.core.main --model_name smollm2-135m --prompt_path data/prompts.json --method maryland --delta 4.0 --ngram 1
 """
 
 import os
@@ -116,9 +116,9 @@ def get_args_parser():
         help='Statistical test to detect watermark. Choose from: same (same as method), openai, openaiz, maryland, marylandz')
     parser.add_argument('--seed', type=int, default=0,
         help='Random seed for reproducibility')
-    parser.add_argument('--ngram', type=int, default=
+    parser.add_argument('--ngram', type=int, default=1,
         help='n-gram size for rng key generation')
-    parser.add_argument('--gamma', type=float, default=0.
+    parser.add_argument('--gamma', type=float, default=0.5,
         help='For maryland method: proportion of greenlist tokens')
     parser.add_argument('--delta', type=float, default=2.0,
         help='For maryland method: bias to add to greenlist tokens')
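For reference, the updated CLI defaults can be reproduced with a plain argparse parser; this is a reconstruction of the affected flags only, not the full parser in main.py:

```python
# Sketch of the updated argument defaults: --ngram now defaults to 1 and --gamma
# to 0.5, matching the defaults used by the web UI.
import argparse

parser = argparse.ArgumentParser("watermark detection (subset of flags)")
parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
parser.add_argument('--ngram', type=int, default=1, help='n-gram size for rng key generation')
parser.add_argument('--gamma', type=float, default=0.5, help='For maryland method: proportion of greenlist tokens')
parser.add_argument('--delta', type=float, default=2.0, help='For maryland method: bias to add to greenlist tokens')

args = parser.parse_args([])  # no CLI args -> the defaults above
print(args)  # Namespace(seed=0, ngram=1, gamma=0.5, delta=2.0)
```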
wm_interactive/static/styles.css CHANGED

@@ -117,6 +117,39 @@ h1 {
 .stat-label {
     color: #666;
     font-size: 20px;
+    position: relative;
+    display: inline-flex;
+    align-items: center;
+    gap: 0.5rem;
+}
+
+.help-icon {
+    cursor: help;
+    color: #6c757d;
+    font-size: 0.875rem;
+}
+
+.help-tooltip {
+    visibility: hidden;
+    position: absolute;
+    bottom: 100%;
+    left: 50%;
+    transform: translateX(-50%);
+    background-color: #333;
+    color: white;
+    padding: 0.5rem;
+    border-radius: 4px;
+    font-size: 0.75rem;
+    width: max-content;
+    max-width: 200px;
+    z-index: 1000;
+    opacity: 0;
+    transition: opacity 0.2s;
+}
+
+.help-icon:hover + .help-tooltip {
+    visibility: visible;
+    opacity: 1;
 }
 
 /* Mobile-specific styles */
@@ -153,4 +186,4 @@ h1 {
     .stat-label {
         font-size: 16px;
     }
-}
+}
wm_interactive/templates/index.html CHANGED

@@ -45,6 +45,16 @@
                 <input type="number" class="form-control" id="ngram" value="1">
                 <div class="form-text">Size of the n-gram window used for detection</div>
             </div>
+            <div class="mb-3">
+                <label for="delta" class="form-label">Delta</label>
+                <input type="number" step="0.1" class="form-control" id="delta" value="2.0">
+                <div class="form-text">Bias added to greenlist tokens (for Maryland method)</div>
+            </div>
+            <div class="mb-3">
+                <label for="temperature" class="form-label">Temperature</label>
+                <input type="number" step="0.1" class="form-control" id="temperature" value="0.8">
+                <div class="form-text">Temperature for sampling (higher = more random)</div>
+            </div>
         </div>
         <div class="modal-footer">
             <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
@@ -73,19 +83,35 @@
         <div class="stats-container">
             <div>
                 <div class="stat-value" id="tokenCount">0</div>
-                <div class="stat-label">
+                <div class="stat-label">
+                    Tokens
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Total number of tokens in the text</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="scoredTokens">0</div>
-                <div class="stat-label">
+                <div class="stat-label">
+                    Scored Tokens
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Number of tokens that were actually scored by the detector (excludes first n-gram tokens and duplicates)</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="finalScore">0.00</div>
-                <div class="stat-label">
+                <div class="stat-label">
+                    Final Score
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Cumulative score from all scored tokens. Higher values indicate more likely watermarked text</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="pValue">0.500</div>
-                <div class="stat-label">
+                <div class="stat-label">
+                    P-value
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Statistical significance of the score. Lower values indicate stronger evidence of watermarking (p < 0.05 is typically considered significant)</span>
+                </div>
             </div>
         </div>
     </div>
@@ -93,7 +119,7 @@
     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
     <script>
         let debounceTimeout = null;
-        let
+        let abortController = null; // Add this line at the top with other variables
         const textarea = document.getElementById('user_text');
         const promptArea = document.getElementById('prompt_text');
         const generateBtn = document.getElementById('generateBtn');
@@ -107,6 +133,8 @@
         const seedInput = document.getElementById('seed');
         const ngramInput = document.getElementById('ngram');
         const detectorTypeSelect = document.getElementById('detectorType');
+        const deltaInput = document.getElementById('delta');
+        const temperatureInput = document.getElementById('temperature');
 
         function startGeneration() {
             const prompt = promptArea.value.trim();
@@ -119,11 +147,16 @@
             stopBtn.disabled = false;
             textarea.value = '';
 
+            // Create new AbortController for this request
+            abortController = new AbortController();
+
             // Get current parameters
             const params = {
                 detector_type: detectorTypeSelect.value,
                 seed: parseInt(seedInput.value) || 0,
-                ngram: parseInt(ngramInput.value) || 1
+                ngram: parseInt(ngramInput.value) || 1,
+                delta: parseFloat(deltaInput.value) || 2.0,
+                temperature: parseFloat(temperatureInput.value) || 0.8
             };
 
             // Create headers for SSE
@@ -132,14 +165,15 @@
                 'Accept': 'text/event-stream',
             });
 
-            // Start fetch request
+            // Start fetch request with signal
             fetch('/generate', {
                 method: 'POST',
                 headers: headers,
                 body: JSON.stringify({
                     prompt: prompt,
                     params: params
-                })
+                }),
+                signal: abortController.signal // Add the abort signal
             }).then(response => {
                 const reader = response.body.getReader();
                 const decoder = new TextDecoder();
@@ -205,16 +239,25 @@
                 return pump();
             })
             .catch(error => {
-
-
+                if (error.name === 'AbortError') {
+                    console.log('Generation stopped by user');
+                } else {
+                    console.error('Error:', error);
+                    alert('Error: Failed to generate text');
+                }
             })
             .finally(() => {
                 generateBtn.disabled = false;
                 stopBtn.disabled = true;
+                abortController = null;
             });
         }
 
         function stopGeneration() {
+            if (abortController) {
+                abortController.abort();
+                abortController = null;
+            }
             generateBtn.disabled = false;
             stopBtn.disabled = true;
         }
@@ -230,6 +273,8 @@
             // Validate parameters before sending
             const seed = parseInt(seedInput.value);
             const ngram = parseInt(ngramInput.value);
+            const delta = parseFloat(deltaInput.value);
+            const temperature = parseFloat(temperatureInput.value);
 
             const response = await fetch('/tokenize', {
                 method: 'POST',
@@ -241,7 +286,9 @@
                     params: {
                         detector_type: detectorTypeSelect.value,
                         seed: isNaN(seed) ? 0 : seed,
-                        ngram: isNaN(ngram) ? 1 : ngram
+                        ngram: isNaN(ngram) ? 1 : ngram,
+                        delta: isNaN(delta) ? 2.0 : delta,
+                        temperature: isNaN(temperature) ? 0.8 : temperature
                     }
                 })
             });
@@ -262,7 +309,7 @@
                 const score = data.scores[i];
                 const pvalue = data.pvalues[i];
                 const scoreDisplay = (score !== null && !isNaN(score)) ? score.toFixed(3) : 'N/A';
-                const pvalueDisplay = (pvalue !== null && !isNaN(pvalue)) ? pvalue
+                const pvalueDisplay = (pvalue !== null && !isNaN(pvalue)) ? formatPValue(pvalue) : 'N/A';
 
                 return `<span class="token" style="background-color: ${data.colors[i]}">
                     ${token}
@@ -279,7 +326,7 @@
             finalScore.textContent = (data.final_score !== null && !isNaN(data.final_score)) ?
                 data.final_score.toFixed(2) : '0.00';
             pValue.textContent = (data.final_pvalue !== null && !isNaN(data.final_pvalue)) ?
-                data.final_pvalue
+                formatPValue(data.final_pvalue) : '0.500';
 
             // Clear any previous error
             const existingError = tokenDisplay.querySelector('.alert-danger');
@@ -332,6 +379,28 @@
             debounceTimeout = setTimeout(updateTokenization, 500);
         });
 
+        deltaInput.addEventListener('input', function() {
+            const value = this.value === '' ? '' : parseFloat(this.value);
+            if (isNaN(value) && this.value !== '') {
+                this.value = "2.0";
+            }
+            if (debounceTimeout) {
+                clearTimeout(debounceTimeout);
+            }
+            debounceTimeout = setTimeout(updateTokenization, 500);
+        });
+
+        temperatureInput.addEventListener('input', function() {
+            const value = this.value === '' ? '' : parseFloat(this.value);
+            if (isNaN(value) && this.value !== '') {
+                this.value = "0.8";
+            }
+            if (debounceTimeout) {
+                clearTimeout(debounceTimeout);
+            }
+            debounceTimeout = setTimeout(updateTokenization, 500);
+        });
+
         // Add keyboard shortcut for applying changes
         document.addEventListener('keydown', function(e) {
             if ((e.metaKey || e.ctrlKey) && e.key === 'Enter') {
@@ -369,6 +438,15 @@
             console.error('Error during initial tokenization:', error);
         });
     });
+
+    // Add this helper function for formatting p-values
+    function formatPValue(value) {
+        if (value >= 0.001) {
+            return value.toFixed(3);
+        } else {
+            return value.toExponential(2);
+        }
+    }
     </script>
 </body>
 </html>
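The new formatPValue helper switches from fixed-point to scientific notation below 0.001 so very small p-values stay readable. The same rule, mirrored in Python for reference (the repo implements it in JavaScript):

```python
# Display rule for p-values introduced above, written in Python for illustration only.
def format_pvalue(value: float) -> str:
    """Fixed-point with 3 decimals for p >= 0.001, scientific notation below that."""
    return f"{value:.3f}" if value >= 0.001 else f"{value:.2e}"

print(format_pvalue(0.04217))  # 0.042
print(format_pvalue(3.2e-07))  # 3.20e-07
```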
wm_interactive/web/app.py CHANGED

@@ -146,6 +146,7 @@ def create_app():
 
         prompt = template_prompt(data.get('prompt', ''))
         params = data.get('params', {})
+        temperature = float(params.get('temperature', 0.8))
 
         def generate_stream():
             try:
@@ -155,7 +156,8 @@ def create_app():
                     model=model,
                     tokenizer=tokenizer,
                     ngram=set_to_int(params.get('ngram', 1)),
-                    seed=set_to_int(params.get('seed', 0))
+                    seed=set_to_int(params.get('seed', 0)),
+                    delta=float(params.get('delta', 2.0)),
                 )
 
                 # Get special tokens to filter out
@@ -190,15 +192,16 @@ def create_app():
                 )
 
                 # Sample next token using the generator's sampling method
+                ngram_tokens = tokens[0, cur_pos-generator.ngram:cur_pos].tolist()
                 aux = {
-                    'ngram_tokens':
+                    'ngram_tokens': ngram_tokens,
                     'cur_pos': cur_pos,
                 }
                 next_token = generator.sample_next(
                     outputs.logits[:, -1, :],
                     aux,
-                    temperature=
-                    top_p=0.
+                    temperature=temperature,
+                    top_p=0.9
                 )
                 # Check for EOS token
                 if next_token == model.config.eos_token_id:
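On the server side, the new delta and temperature fields are read from the request params with the same defaults as the UI (2.0 and 0.8). A minimal Flask sketch of that parameter handling, using a hypothetical route name for illustration; the real endpoint builds the watermark generator and streams tokens over SSE:

```python
# Sketch of how the new request parameters are resolved (simplified; illustrative only).
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/generate_params_demo", methods=["POST"])  # hypothetical route for illustration
def generate_params_demo():
    data = request.get_json(force=True)
    params = data.get("params", {})
    # same defaults as the UI: seed 0, ngram 1, delta 2.0, temperature 0.8
    resolved = {
        "seed": int(params.get("seed", 0)),
        "ngram": int(params.get("ngram", 1)),
        "delta": float(params.get("delta", 2.0)),
        "temperature": float(params.get("temperature", 0.8)),
    }
    return jsonify(resolved)
```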