TheFrenchDemos committed
Commit 7f97da4 · 1 Parent(s): f747801

implemented detection

wm_interactive/core/detector.py CHANGED
@@ -93,7 +93,8 @@ class WmDetector():
 
             score = float('nan')
             if is_scored:
-                score = self.score_tok(ngram_tokens, tokens_id[cur_pos]).numpy()[0]
+                score = self.score_tok(ngram_tokens, tokens_id[cur_pos])
+                score = float(score)
 
             token_details.append({
                 'token_id': tokens_id[cur_pos],
@@ -168,12 +169,7 @@ class MarylandDetector(WmDetector):
 
     def score_tok(self, ngram_tokens, token_id):
         """
-        score_t = 1 if token_id in greenlist else 0
-        The last line shifts the scores by token_id.
-        ex: scores[0] = 1 if token_id in greenlist else 0
-            scores[1] = 1 if token_id in (greenlist shifted of 1) else 0
-            ...
-        The score for each payload will be given by scores[payload]
+        score_t = 1 if token_id in greenlist else 0
         """
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
@@ -181,7 +177,7 @@ class MarylandDetector(WmDetector):
         vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
         greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n toks in the greenlist
         scores[greenlist] = 1
-        return scores.roll(-token_id)
+        return scores[token_id]
 
     def get_pvalue(self, score: int, ntoks: int, eps: float):
         """ from cdf of a binomial distribution """
@@ -209,7 +205,7 @@ class MarylandDetectorZ(WmDetector):
         vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
         greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
         scores[greenlist] = 1
-        return scores.roll(-token_id)
+        return scores[token_id]
 
     def get_pvalue(self, score: int, ntoks: int, eps: float):
         """ from cdf of a normal distribution """
@@ -229,17 +225,12 @@ class OpenaiDetector(WmDetector):
     def score_tok(self, ngram_tokens, token_id):
         """
         score_t = -log(1 - rt[token_id]])
-        The last line shifts the scores by token_id.
-        ex: scores[0] = r_t[token_id]
-            scores[1] = (r_t shifted of 1)[token_id]
-            ...
-        The score for each payload will be given by scores[payload]
         """
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
-        scores = -(1 - rs).log().roll(-token_id)
-        return scores
+        scores = -(1 - rs).log()
+        return scores[token_id]
 
     def get_pvalue(self, score: float, ntoks: int, eps: float):
         """ from cdf of a gamma distribution """
@@ -260,8 +251,8 @@ class OpenaiDetectorZ(WmDetector):
         seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
-        scores = -(1 - rs).log().roll(-token_id)
-        return scores
+        scores = -(1 - rs).log()
+        return scores[token_id]
 
     def get_pvalue(self, score: float, ntoks: int, eps: float):
         """ from cdf of a normal distribution """
wm_interactive/core/generator.py CHANGED
@@ -11,7 +11,8 @@ class WmGenerator():
             model: AutoModelForCausalLM,
             tokenizer: AutoTokenizer,
             ngram: int = 1,
-            seed: int = 0
+            seed: int = 0,
+            **kwargs
         ):
         # model config
         self.tokenizer = tokenizer
@@ -49,11 +50,15 @@ class WmGenerator():
         start_pos = prompt_size
         prev_pos = 0
         for cur_pos in range(start_pos, total_len):
+            past_key_values = outputs.past_key_values if prev_pos > 0 else None
             outputs = self.model.forward(
-                tokens[:, prev_pos:cur_pos], use_cache=True, past_key_values=outputs.past_key_values if prev_pos > 0 else None
+                tokens[:, prev_pos:cur_pos],
+                use_cache=True,
+                past_key_values=past_key_values
             )
+            ngram_tokens = tokens[0, cur_pos-self.ngram:cur_pos].tolist()
             aux = {
-                'ngram_tokens': tokens[:, cur_pos-self.ngram:cur_pos],
+                'ngram_tokens': ngram_tokens,
                 'cur_pos': cur_pos,
             }
             next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
@@ -135,7 +140,7 @@ class OpenaiGenerator(WmGenerator):
         probs_sort[mask] = 0.0
         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
         # seed with hash of ngram tokens
-        seed = get_seed_rng(self.seed, ngram_tokens[0])
+        seed = get_seed_rng(self.seed, ngram_tokens)
         self.rng.manual_seed(seed)
         # generate rs randomly between [0,1]
         rs = torch.rand(self.vocab_size, generator=self.rng) # n
@@ -164,13 +169,11 @@ class MarylandGenerator(WmGenerator):
             *args,
             gamma: float = 0.5,
             delta: float = 1.0,
-            test_mul: float = 0,
             **kwargs
         ):
         super().__init__(*args, **kwargs)
         self.gamma = gamma
         self.delta = delta
-        self.test_mul = test_mul
 
     def sample_next(
         self,
@@ -198,7 +201,7 @@ class MarylandGenerator(WmGenerator):
     def logits_processor(self, logits, ngram_tokens):
         """Process logits to mask out words in greenlist."""
         logits = logits.clone()
-        seed = get_seed_rng(self.seed, ngram_tokens[0])
+        seed = get_seed_rng(self.seed, ngram_tokens)
        self.rng.manual_seed(seed)
        vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
        greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * n
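Note on the hunk above: main.py documents delta as the "bias to add to greenlist tokens", but the line that actually applies it in logits_processor falls outside the diff context. The following is a hedged sketch of that step, reusing the variable names visible in the hunk; it is an illustration under those assumptions, not the repository's exact code.

# Hedged sketch of the greenlist biasing performed by MarylandGenerator.logits_processor.
# The delta-adding line is not shown in the hunks above, so this is illustrative only.
import torch

def greenlist_bias(logits: torch.Tensor, rng: torch.Generator, seed: int,
                   vocab_size: int, gamma: float = 0.5, delta: float = 1.0) -> torch.Tensor:
    logits = logits.clone()
    rng.manual_seed(seed)                                    # seed derived from the n-gram context
    vocab_permutation = torch.randperm(vocab_size, generator=rng)
    greenlist = vocab_permutation[:int(gamma * vocab_size)]  # gamma * n tokens
    logits[..., greenlist] += delta                          # bias greenlist tokens upward
    return logits

# usage sketch: rng = torch.Generator(); biased = greenlist_bias(logits, rng, seed, vocab_size=50257)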
wm_interactive/core/main.py CHANGED
@@ -1,7 +1,7 @@
 """
 Main script for watermark detection.
 Test with:
-    python -m wm_interactive.core.main --model_name smollm2-135m --prompt_path data/prompts.json
+    python -m wm_interactive.core.main --model_name smollm2-135m --prompt_path data/prompts.json --method maryland --delta 4.0 --ngram 1
 """
 
 import os
@@ -116,9 +116,9 @@ def get_args_parser():
                         help='Statistical test to detect watermark. Choose from: same (same as method), openai, openaiz, maryland, marylandz')
     parser.add_argument('--seed', type=int, default=0,
                         help='Random seed for reproducibility')
-    parser.add_argument('--ngram', type=int, default=4,
+    parser.add_argument('--ngram', type=int, default=1,
                         help='n-gram size for rng key generation')
-    parser.add_argument('--gamma', type=float, default=0.25,
+    parser.add_argument('--gamma', type=float, default=0.5,
                         help='For maryland method: proportion of greenlist tokens')
     parser.add_argument('--delta', type=float, default=2.0,
                         help='For maryland method: bias to add to greenlist tokens')
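With the new default gamma = 0.5, the Maryland detector's score simply counts greenlist hits among the scored tokens, and detector.py's get_pvalue docstring states that the p-value comes from a binomial CDF. As a rough illustration only (the repository's get_pvalue may be implemented differently, e.g. without scipy), the binomial tail can be computed like this:

# Illustrative sketch, not the repository's get_pvalue: under the null hypothesis
# each scored token is green with probability gamma, so the total Maryland score
# follows Binomial(ntoks, gamma) and the p-value is the upper tail at the observed score.
from scipy.stats import binom

def binomial_pvalue(score: int, ntoks: int, gamma: float = 0.5) -> float:
    # P(X >= score) for X ~ Binomial(ntoks, gamma)
    return float(binom.sf(score - 1, ntoks, gamma))

# e.g. 70 green tokens out of 100 scored with gamma = 0.5 gives p ~ 3.9e-5
print(binomial_pvalue(70, 100, 0.5))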
wm_interactive/static/styles.css CHANGED
@@ -117,6 +117,39 @@ h1 {
 .stat-label {
     color: #666;
     font-size: 20px;
+    position: relative;
+    display: inline-flex;
+    align-items: center;
+    gap: 0.5rem;
+}
+
+.help-icon {
+    cursor: help;
+    color: #6c757d;
+    font-size: 0.875rem;
+}
+
+.help-tooltip {
+    visibility: hidden;
+    position: absolute;
+    bottom: 100%;
+    left: 50%;
+    transform: translateX(-50%);
+    background-color: #333;
+    color: white;
+    padding: 0.5rem;
+    border-radius: 4px;
+    font-size: 0.75rem;
+    width: max-content;
+    max-width: 200px;
+    z-index: 1000;
+    opacity: 0;
+    transition: opacity 0.2s;
+}
+
+.help-icon:hover + .help-tooltip {
+    visibility: visible;
+    opacity: 1;
 }
 
 /* Mobile-specific styles */
@@ -153,4 +186,4 @@ h1 {
     .stat-label {
         font-size: 16px;
     }
-}
+}
wm_interactive/templates/index.html CHANGED
@@ -45,6 +45,16 @@
                         <input type="number" class="form-control" id="ngram" value="1">
                         <div class="form-text">Size of the n-gram window used for detection</div>
                     </div>
+                    <div class="mb-3">
+                        <label for="delta" class="form-label">Delta</label>
+                        <input type="number" step="0.1" class="form-control" id="delta" value="2.0">
+                        <div class="form-text">Bias added to greenlist tokens (for Maryland method)</div>
+                    </div>
+                    <div class="mb-3">
+                        <label for="temperature" class="form-label">Temperature</label>
+                        <input type="number" step="0.1" class="form-control" id="temperature" value="0.8">
+                        <div class="form-text">Temperature for sampling (higher = more random)</div>
+                    </div>
                 </div>
                 <div class="modal-footer">
                     <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
@@ -73,19 +83,35 @@
         <div class="stats-container">
             <div>
                 <div class="stat-value" id="tokenCount">0</div>
-                <div class="stat-label">Tokens</div>
+                <div class="stat-label">
+                    Tokens
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Total number of tokens in the text</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="scoredTokens">0</div>
-                <div class="stat-label">Scored Tokens</div>
+                <div class="stat-label">
+                    Scored Tokens
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Number of tokens that were actually scored by the detector (excludes first n-gram tokens and duplicates)</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="finalScore">0.00</div>
-                <div class="stat-label">Final Score</div>
+                <div class="stat-label">
+                    Final Score
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Cumulative score from all scored tokens. Higher values indicate more likely watermarked text</span>
+                </div>
             </div>
             <div>
                 <div class="stat-value" id="pValue">0.500</div>
-                <div class="stat-label">P-value</div>
+                <div class="stat-label">
+                    P-value
+                    <i class="bi bi-question-circle help-icon"></i>
+                    <span class="help-tooltip">Statistical significance of the score. Lower values indicate stronger evidence of watermarking (p < 0.05 is typically considered significant)</span>
+                </div>
             </div>
         </div>
     </div>
@@ -93,7 +119,7 @@
     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
     <script>
         let debounceTimeout = null;
-        let eventSource = null;
+        let abortController = null; // Add this line at the top with other variables
         const textarea = document.getElementById('user_text');
         const promptArea = document.getElementById('prompt_text');
         const generateBtn = document.getElementById('generateBtn');
@@ -107,6 +133,8 @@
         const seedInput = document.getElementById('seed');
         const ngramInput = document.getElementById('ngram');
         const detectorTypeSelect = document.getElementById('detectorType');
+        const deltaInput = document.getElementById('delta');
+        const temperatureInput = document.getElementById('temperature');
 
         function startGeneration() {
             const prompt = promptArea.value.trim();
@@ -119,11 +147,16 @@
             stopBtn.disabled = false;
             textarea.value = '';
 
+            // Create new AbortController for this request
+            abortController = new AbortController();
+
             // Get current parameters
             const params = {
                 detector_type: detectorTypeSelect.value,
                 seed: parseInt(seedInput.value) || 0,
-                ngram: parseInt(ngramInput.value) || 1
+                ngram: parseInt(ngramInput.value) || 1,
+                delta: parseFloat(deltaInput.value) || 2.0,
+                temperature: parseFloat(temperatureInput.value) || 0.8
             };
 
             // Create headers for SSE
@@ -132,14 +165,15 @@
                 'Accept': 'text/event-stream',
             });
 
-            // Start fetch request
+            // Start fetch request with signal
             fetch('/generate', {
                 method: 'POST',
                 headers: headers,
                 body: JSON.stringify({
                     prompt: prompt,
                     params: params
-                })
+                }),
+                signal: abortController.signal // Add the abort signal
             }).then(response => {
                 const reader = response.body.getReader();
                 const decoder = new TextDecoder();
@@ -205,16 +239,25 @@
                 return pump();
             })
             .catch(error => {
-                console.error('Error:', error);
-                alert('Error: Failed to generate text');
+                if (error.name === 'AbortError') {
+                    console.log('Generation stopped by user');
+                } else {
+                    console.error('Error:', error);
+                    alert('Error: Failed to generate text');
+                }
             })
             .finally(() => {
                 generateBtn.disabled = false;
                 stopBtn.disabled = true;
+                abortController = null;
             });
         }
 
         function stopGeneration() {
+            if (abortController) {
+                abortController.abort();
+                abortController = null;
+            }
             generateBtn.disabled = false;
             stopBtn.disabled = true;
         }
@@ -230,6 +273,8 @@
             // Validate parameters before sending
             const seed = parseInt(seedInput.value);
             const ngram = parseInt(ngramInput.value);
+            const delta = parseFloat(deltaInput.value);
+            const temperature = parseFloat(temperatureInput.value);
 
             const response = await fetch('/tokenize', {
                 method: 'POST',
@@ -241,7 +286,9 @@
                     params: {
                         detector_type: detectorTypeSelect.value,
                         seed: isNaN(seed) ? 0 : seed,
-                        ngram: isNaN(ngram) ? 1 : ngram
+                        ngram: isNaN(ngram) ? 1 : ngram,
+                        delta: isNaN(delta) ? 2.0 : delta,
+                        temperature: isNaN(temperature) ? 0.8 : temperature
                     }
                 })
             });
@@ -262,7 +309,7 @@
                 const score = data.scores[i];
                 const pvalue = data.pvalues[i];
                 const scoreDisplay = (score !== null && !isNaN(score)) ? score.toFixed(3) : 'N/A';
-                const pvalueDisplay = (pvalue !== null && !isNaN(pvalue)) ? pvalue.toFixed(3) : 'N/A';
+                const pvalueDisplay = (pvalue !== null && !isNaN(pvalue)) ? formatPValue(pvalue) : 'N/A';
 
                 return `<span class="token" style="background-color: ${data.colors[i]}">
                     ${token}
@@ -279,7 +326,7 @@
             finalScore.textContent = (data.final_score !== null && !isNaN(data.final_score)) ?
                 data.final_score.toFixed(2) : '0.00';
             pValue.textContent = (data.final_pvalue !== null && !isNaN(data.final_pvalue)) ?
-                data.final_pvalue.toFixed(3) : '0.500';
+                formatPValue(data.final_pvalue) : '0.500';
 
             // Clear any previous error
             const existingError = tokenDisplay.querySelector('.alert-danger');
@@ -332,6 +379,28 @@
             debounceTimeout = setTimeout(updateTokenization, 500);
         });
 
+        deltaInput.addEventListener('input', function() {
+            const value = this.value === '' ? '' : parseFloat(this.value);
+            if (isNaN(value) && this.value !== '') {
+                this.value = "2.0";
+            }
+            if (debounceTimeout) {
+                clearTimeout(debounceTimeout);
+            }
+            debounceTimeout = setTimeout(updateTokenization, 500);
+        });
+
+        temperatureInput.addEventListener('input', function() {
+            const value = this.value === '' ? '' : parseFloat(this.value);
+            if (isNaN(value) && this.value !== '') {
+                this.value = "0.8";
+            }
+            if (debounceTimeout) {
+                clearTimeout(debounceTimeout);
+            }
+            debounceTimeout = setTimeout(updateTokenization, 500);
+        });
+
         // Add keyboard shortcut for applying changes
         document.addEventListener('keydown', function(e) {
             if ((e.metaKey || e.ctrlKey) && e.key === 'Enter') {
@@ -369,6 +438,15 @@
                 console.error('Error during initial tokenization:', error);
             });
         });
+
+        // Add this helper function for formatting p-values
+        function formatPValue(value) {
+            if (value >= 0.001) {
+                return value.toFixed(3);
+            } else {
+                return value.toExponential(2);
+            }
+        }
     </script>
 </body>
 </html>
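The new formatPValue helper keeps very small p-values readable by switching from fixed-point to scientific notation below 1e-3. For anyone post-processing the detector output server-side, a Python counterpart of the same rule (an illustration, not code from this repository) would be:

# Same formatting rule as the JS formatPValue above: fixed-point down to 1e-3,
# scientific notation below that.
def format_pvalue(value: float) -> str:
    return f"{value:.3f}" if value >= 0.001 else f"{value:.2e}"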
wm_interactive/web/app.py CHANGED
@@ -146,6 +146,7 @@ def create_app():
 
         prompt = template_prompt(data.get('prompt', ''))
         params = data.get('params', {})
+        temperature = float(params.get('temperature', 0.8))
 
         def generate_stream():
             try:
@@ -155,7 +156,8 @@ def create_app():
                     model=model,
                     tokenizer=tokenizer,
                     ngram=set_to_int(params.get('ngram', 1)),
-                    seed=set_to_int(params.get('seed', 0))
+                    seed=set_to_int(params.get('seed', 0)),
+                    delta=float(params.get('delta', 2.0)),
                 )
 
                 # Get special tokens to filter out
@@ -190,15 +192,16 @@ def create_app():
                     )
 
                     # Sample next token using the generator's sampling method
+                    ngram_tokens = tokens[0, cur_pos-generator.ngram:cur_pos].tolist()
                     aux = {
-                        'ngram_tokens': tokens[:, cur_pos-generator.ngram:cur_pos],
+                        'ngram_tokens': ngram_tokens,
                        'cur_pos': cur_pos,
                    }
                    next_token = generator.sample_next(
                        outputs.logits[:, -1, :],
                        aux,
-                        temperature=0.8,
-                        top_p=0.95
+                        temperature=temperature,
+                        top_p=0.9
                    )
                    # Check for EOS token
                    if next_token == model.config.eos_token_id:
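After this commit, the request body the frontend sends to /generate is a prompt plus a params object carrying detector_type, seed, ngram, delta and temperature, which app.py reads via params.get(...). The endpoint can therefore be exercised directly from Python. A hedged client sketch follows: the payload mirrors the fetch() call in index.html, while the host, port and the "maryland" detector_type value are assumptions.

# Hedged client sketch for the /generate endpoint (streams server-sent events).
import requests

payload = {
    "prompt": "Write a short story about a robot.",
    "params": {
        "detector_type": "maryland",  # value assumed; see the detector type selector in index.html
        "seed": 0,
        "ngram": 1,
        "delta": 2.0,                 # new in this commit
        "temperature": 0.8,           # new in this commit
    },
}

# Host/port are assumptions (Flask default); adjust to your deployment.
with requests.post("http://localhost:5000/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # skip SSE keep-alive blank lines
            print(line)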