multimodalart (HF Staff) committed
Commit 2491cbe · verified · 1 Parent(s): a375a1f

Update app.py

Files changed (1)
  1. app.py +189 -127
app.py CHANGED
@@ -9,9 +9,11 @@ import time
9
  import re
10
  from typing import List, Dict, Tuple, Optional
11
  import torch.distributions as dists # Added import
 
12
 
13
  # --- START: Copied Helper functions from generation_utils.py ---
14
- # [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
 
15
  def top_p_logits(logits, top_p=None):
16
  """ Applies top-p filtering to logits. """
17
  if top_p is None or top_p >= 1.0:
@@ -33,6 +35,8 @@ def top_k_logits(logits, top_k=None):
33
  if top_k is None or top_k <= 0:
34
  return logits
35
  top_k = min(top_k, logits.size(-1)) # Safety check
36
  # Remove all tokens with a probability less than the last token of the top-k
37
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
38
  logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
@@ -44,29 +48,36 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
44
  # Prevent division by zero or negative temperatures
45
  safe_temp = max(temperature, 1e-6)
46
  logits = logits / safe_temp
47
- if top_p is not None and top_p < 1.0: # Apply top_p if valid
48
  logits = top_p_logits(logits, top_p)
49
  if top_k is not None and top_k > 0: # Apply top_k if valid
50
  logits = top_k_logits(logits, top_k)
51
 
52
- # Ensure logits are not all -inf after filtering, if so, sample uniformly? Or handle error.
53
- # Add a check here: if all logits are -inf, assign uniform probability.
54
- is_all_neg_inf = torch.all(logits == torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
55
  if torch.any(is_all_neg_inf):
56
  # print("Warning: All logits became -inf after filtering. Assigning uniform probabilities.")
57
- uniform_logits = torch.zeros_like(logits)
58
  logits = torch.where(is_all_neg_inf, uniform_logits, logits)
59
 
60
  probs = torch.softmax(logits, dim=-1)
61
 
62
  # Clamp probabilities to avoid NaNs in sampling, ensure they sum to 1
63
  probs = torch.clamp(probs, min=0.0) # Ensure non-negative
64
- probs = probs / probs.sum(dim=-1, keepdim=True) # Re-normalize
65
  probs = torch.nan_to_num(probs, nan=0.0) # Handle any remaining NaNs
66
 
67
-
68
  if temperature > 0:
69
  try:
70
  x0 = dists.Categorical(probs=probs).sample()
71
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
72
  except Exception as e: # Catch broader exceptions during sampling
@@ -79,14 +90,14 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
79
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
80
  # Ensure there are at least 2 probabilities to compare
81
  top1_probs = sorted_probs[..., 0]
82
- top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs # Handle case with only 1 possible token
83
  confidence = top1_probs - top2_probs
84
 
85
  if neg_entropy:
86
- epsilon = 1e-10
87
  # Ensure probs are > 0 for log
88
- log_probs = torch.log(probs + epsilon)
89
- confidence = torch.sum(probs * log_probs, dim=-1) # Should be negative entropy
90
 
91
  # Ensure confidence is not NaN
92
  confidence = torch.nan_to_num(confidence, nan=0.0)
@@ -95,7 +106,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
95
  # --- END: Copied Helper functions ---
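For context on the helpers above: nucleus (top-p) filtering keeps only the smallest set of highest-probability tokens whose cumulative mass exceeds top_p and pushes everything else to -inf. A minimal standalone sketch (illustrative names and values, not this app's exact implementation):

import torch

def nucleus_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Sort descending and accumulate probability mass.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Drop tokens after the cumulative mass passes top_p; always keep the top token.
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # Map the mask back to the original token order.
    mask = remove.scatter(-1, sorted_idx, remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

print(nucleus_filter(torch.tensor([2.0, 1.0, 0.5, -1.0]), top_p=0.9))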
96
 
97
 
98
- # [Keep model loading, constants, helper functions: parse_constraints, format_chat_history, apply_constraints_to_state]
99
  # Load model configuration to get special token IDs
100
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
101
  # Use AutoModel for the base model loading, relying on trust_remote_code=True
@@ -139,34 +150,32 @@ SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
139
  try:
140
  IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
141
  IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
142
- SPECIAL_TOKEN_IDS.add(IM_START_ID)
143
- SPECIAL_TOKEN_IDS.add(IM_END_ID)
144
  except KeyError:
145
  print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
146
  IM_START_ID = None
147
  IM_END_ID = None
148
 
149
 
150
- # --- Helper Functions ---
151
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
152
- """
153
- Parse constraints in format: 'position:word, position:word, ...'
154
- Returns a dictionary mapping the starting position (0-indexed from the start
155
- of the *generated* sequence) to a list of token IDs for the constraint word.
156
- """
157
  constraints = {}
158
  if not constraints_text:
159
  return constraints
160
 
 
161
  parts = constraints_text.split(',')
 
162
  for part in parts:
163
- part = part.strip() # Remove leading/trailing whitespace from the part itself
164
  if ':' not in part:
165
  continue
166
  pos_str, word = part.split(':', 1)
167
  try:
168
  pos = int(pos_str.strip())
169
- word = word.strip() # Strip whitespace from word
170
  token_ids = []
171
  if word: # Only encode if word is not empty
172
  # Add space prefix automatically if pos > 0 and word doesn't start with space
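An illustrative example of the 'position:word' constraint format described in the docstring above (the actual token IDs depend on the Dream tokenizer, so they are left symbolic here):

# parse_constraints("0:Hello, 5:world")
# -> {0: token IDs for "Hello", 5: token IDs for " world"}
# Generated position 0 is pinned to "Hello"; position 5 to " world"
# (a leading space is added automatically when pos > 0).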
@@ -192,9 +201,10 @@ def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, st
192
  """ Formats chat history for the template. """
193
  messages = []
194
  for user_msg, assistant_msg in history:
195
- if user_msg:
196
  messages.append({"role": "user", "content": user_msg})
197
- if assistant_msg:
 
198
  messages.append({"role": "assistant", "content": assistant_msg})
199
  return messages
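A quick, made-up illustration of what format_chat_history produces:

# history = [["Hi", "Hello!"], ["Tell me a joke", None]]
# format_chat_history(history)
# -> [{"role": "user", "content": "Hi"},
#     {"role": "assistant", "content": "Hello!"},
#     {"role": "user", "content": "Tell me a joke"}]
# The trailing None is skipped, so the chat template ends on the user turn.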
200
 
@@ -206,15 +216,16 @@ def apply_constraints_to_state(
206
  current_step: Optional[int] = None # For logging/debugging
207
  ) -> torch.Tensor:
208
  """ Applies constraints directly to the state tensor `x`. """
209
- modified_x = x # Modify in place maybe okay? Let's stick with clone for safety.
210
- modified_x = x.clone()
211
  for rel_pos, word_token_ids in parsed_constraints.items():
212
  abs_start_pos = prompt_length + rel_pos
213
  abs_end_pos = abs_start_pos + len(word_token_ids)
214
 
 
215
  if abs_start_pos < total_length and abs_end_pos <= total_length:
216
  try:
217
  constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
 
218
  modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
219
  except IndexError:
220
  print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
@@ -228,7 +239,7 @@ def apply_constraints_to_state(
228
  @spaces.GPU # Decorator for Hugging Face Spaces GPU usage
229
  @torch.no_grad() # Ensure no gradients are computed during generation
230
  def generate_dream_response(
231
- history: List[List[Optional[str]]],
232
  gen_length: int,
233
  steps: int,
234
  constraints_text: str,
@@ -241,13 +252,13 @@ def generate_dream_response(
241
  ) -> List[Tuple[str, str]]:
242
  """ Generates text step-by-step and yields visualization states live. """
243
 
244
- if not history or not history[-1][0]:
245
- yield history, [("No input message found.", "red")], ""
246
  return
247
 
248
  # --- 1. Preparation ---
249
- last_user_message = history[-1][0]
250
- messages_for_template = format_chat_history(history) # Includes the latest user message
251
  parsed_constraints = parse_constraints(constraints_text)
252
 
253
  try:
@@ -255,46 +266,38 @@ def generate_dream_response(
255
  messages_for_template,
256
  return_tensors="pt",
257
  return_dict=True,
258
- add_generation_prompt=True
259
  )
260
  input_ids = inputs.input_ids.to(device)
261
- # Ensure prompt_attention_mask is also on the correct device
262
  prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
263
  prompt_length = input_ids.shape[1]
264
  except Exception as e:
265
  print(f"Error applying chat template: {e}")
 
266
  yield history, [("Error preparing input.", "red")], ""
267
  return
268
 
269
  eps = 1e-3
270
- top_p_val = top_p if top_p is not None and 0.0 < top_p < 1.0 else None # Make sure top_p is > 0
271
  top_k_val = top_k if top_k is not None and top_k > 0 else None
272
- alg_temp_val = alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] and alg_temp is not None and alg_temp > 0 else None # Ensure > 0
273
 
274
  # --- 2. Initialize Generation State ---
275
  total_length = prompt_length + gen_length
276
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
277
  x = torch.cat((input_ids, initial_generation_part), dim=1)
278
 
279
- # --- Prepare Attention Mask for SDPA ---
280
- generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
281
- full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1) # Shape [B, N], dtype torch.long
282
 
283
- # Convert attention mask for SDPA: Needs float matching query dtype.
284
- # Where mask is 1 (attend), value should be 0.0. Where mask is 0 (don't attend), value should be -inf.
285
- attention_mask_for_model = full_attention_mask_long.to(model.dtype) # Convert to model's dtype (e.g., bfloat16)
286
- # Invert the mask logic: (1.0 - mask) gives 0s for attend, 1s for mask
287
- # Multiply by large negative number (min value for dtype) for masked positions
288
  large_neg_val = torch.finfo(model.dtype).min
289
  attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
290
- # Ensure the shape is broadcastable, SDPA usually handles [B, N] -> [B, H, N, N] if needed.
291
- # However, explicitly making it [B, 1, 1, N] or [B, 1, N, N] can be safer.
292
- # Let's try passing [B, N] first, if it fails, reshape.
293
- # Reshape to [B, 1, 1, N] which is commonly expected for additive masks by HF models
294
- attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2)
295
- # Now shape is [B, 1, 1, N]
296
-
297
- # --- Timesteps ---
298
  timesteps = torch.linspace(1, eps, steps + 1, device=device)
299
 
300
  # Apply initial constraints
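The mask handling above converts a 0/1 padding mask into the additive float mask SDPA expects; a small sketch of that conversion with assumed batch size, length, and dtype:

import torch

keep = torch.tensor([[1, 1, 1, 1, 0, 0]])   # 1 = attend, 0 = ignore (illustrative)
dtype = torch.bfloat16                      # assumed model dtype
additive = (1.0 - keep.to(dtype)) * torch.finfo(dtype).min  # 0 where attending, huge negative where not
additive = additive.unsqueeze(1).unsqueeze(2)               # [B, 1, 1, N] for broadcasting over heads/queries
print(additive.shape)                                       # torch.Size([1, 1, 1, 6])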
@@ -303,7 +306,8 @@ def generate_dream_response(
303
  # --- 3. Visualization Setup ---
304
  previous_tokens_vis = None
305
  final_response_text = ""
306
- history_copy = [list(item) for item in history] # Mutable copy
 
307
 
308
  # --- 4. Initial Yield (Masked State) ---
309
  initial_generated_tokens = x[0, prompt_length:].cpu()
@@ -314,6 +318,7 @@ def generate_dream_response(
314
  vis_data_initial.append((display_token, color))
315
 
316
  previous_tokens_vis = initial_generated_tokens
 
317
  yield history_copy, vis_data_initial, ""
318
  time.sleep(visualization_delay)
319
 
@@ -327,18 +332,21 @@ def generate_dream_response(
327
  break
328
 
329
  # --- Model Forward Pass ---
330
- # Pass the correctly formatted float mask
331
  outputs = model(
332
  input_ids=x,
333
  attention_mask=attention_mask_for_model, # Pass the [B, 1, 1, N] float mask
334
- position_ids=None,
335
  use_cache=False,
336
  return_dict=True
337
  )
338
  logits = outputs.logits
339
- logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) # Align logits
340
 
341
- mask_logits = logits[mask_index]
342
  if mask_logits.numel() == 0:
343
  print(f"No masked tokens found for logit selection at step {i}. Stopping.")
344
  break
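The remasking logic below reveals a fraction of the still-masked positions each step, driven by the linspace timestep schedule; a hedged sketch of that arithmetic with made-up sizes:

import torch

steps, eps = 8, 1e-3
timesteps = torch.linspace(1, eps, steps + 1)
num_mask_token = 128                      # masked positions remaining (illustrative)
i = 0
t, s = timesteps[i], timesteps[i + 1]
# Reveal a (1 - s/t) fraction now; the final step reveals everything that is left.
number_transfer_tokens = int(num_mask_token * (1.0 - s / t)) if i < steps - 1 else num_mask_token
print(round(float(t), 4), round(float(s), 4), number_transfer_tokens)   # 1.0 0.8751 15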
@@ -346,6 +354,7 @@ def generate_dream_response(
346
  # --- Sampling / Remasking Logic ---
347
  t = timesteps[i]
348
  s = timesteps[i + 1]
 
349
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)
350
 
351
  if alg == 'origin':
@@ -356,11 +365,13 @@ def generate_dream_response(
356
 
357
  if logits_to_sample.numel() > 0:
358
  _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
 
359
  x_new_masked_part[transfer_indices_relative] = sampled_tokens
360
 
361
- else: # Confidence-based algorithms
362
  use_margin = (alg == 'topk_margin')
363
  use_entropy = (alg == 'entropy')
 
364
  confidence, x0_candidates = sample_tokens(
365
  mask_logits,
366
  temperature=temperature,
@@ -371,102 +382,95 @@ def generate_dream_response(
371
  )
372
 
373
  num_mask_token = mask_logits.shape[0]
 
374
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
375
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
376
 
377
  if number_transfer_tokens > 0:
 
378
  num_samples = min(number_transfer_tokens, num_mask_token) # Ensure k <= num_mask_token
379
- if num_samples > 0: # Proceed only if we need to sample > 0 tokens
380
- if alg_temp_val is None or alg_temp_val <= 0: # Top-k confidence
381
- sort_metric = confidence if alg != 'entropy' else -confidence # Lower entropy = higher confidence
382
  # Ensure k is not greater than the number of elements
383
  k_topk = min(num_samples, sort_metric.numel())
384
  if k_topk > 0:
385
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
386
- else:
387
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
388
 
389
  else: # Sample based on confidence temperature
390
  # Ensure confidence has elements before processing
391
  if confidence.numel() > 0:
392
  conf_probs = confidence / alg_temp_val
393
  # Handle potential inf/-inf before softmax, ensure non-negative probabilities
394
- conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9) # Use large numbers instead of inf
395
- conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30) # Prevent large positive values leading to inf in exp
 
396
  conf_probs = F.softmax(conf_probs, dim=-1)
397
  conf_probs = torch.clamp(conf_probs, min=0.0) # Ensure non-negative
398
  conf_probs = torch.nan_to_num(conf_probs, nan=0.0) # Handle NaNs
399
 
400
- # Normalize probabilities if they don't sum to 1
401
  prob_sum = conf_probs.sum()
402
- # --- START FIX ---
403
- # Ensure the comparison tensor has the same dtype as prob_sum
404
  target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
405
  if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0:
406
- # --- END FIX ---
407
- # print(f"Warning step {i}: Confidence probabilities sum {prob_sum:.4f} != 1. Re-normalizing.")
408
- # Avoid division by zero if prob_sum is extremely small or zero
409
  safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
410
- conf_probs = conf_probs / safe_prob_sum # Use safe_prob_sum
411
 
412
- # Ensure num_samples is valid and probabilities are okay for multinomial
413
- # --- START FIX ---
414
- # Check sum again after potential normalization
415
  final_prob_sum_check = conf_probs.sum()
416
  if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
417
- # --- END FIX ---
418
  try:
419
  transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
420
  except RuntimeError as e:
421
- # [Fallback logic remains the same]
422
  print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
 
423
  sort_metric = confidence if alg != 'entropy' else -confidence
424
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
425
  if k_multinomial_fallback > 0:
426
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
427
- else:
428
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
429
  else: # Handle cases where multinomial is not possible (e.g., bad probabilities)
430
- # [Fallback logic remains the same]
431
  # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
432
  sort_metric = confidence if alg != 'entropy' else -confidence
433
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
434
  if k_multinomial_fallback > 0:
435
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
436
- else:
437
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
438
- else: # No confidence values to sample from
439
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
440
 
441
- # Apply the transfer
442
  if transfer_indices_relative.numel() > 0:
443
- # Ensure indices are within bounds of x0_candidates
444
- valid_indices = transfer_indices_relative < x0_candidates.shape[0]
445
- valid_transfer_indices = transfer_indices_relative[valid_indices]
446
 
447
  if valid_transfer_indices.numel() > 0:
448
- # Ensure indices are also within bounds of x_new_masked_part
449
- if valid_transfer_indices.max() < x_new_masked_part.shape[0]:
450
- x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()
451
- else:
452
- print(f"Warning step {i}: transfer_indices out of bounds for x_new_masked_part.")
453
 
454
  # Update the global state `x` only at the masked positions
455
  x[mask_index] = x_new_masked_part
456
 
457
  # --- Apply Constraints ---
 
458
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)
459
 
460
  # --- Yield Visualization ---
461
- current_generated_tokens = x[0, prompt_length:].cpu()
462
  vis_data = []
463
- # [Keep visualization formatting logic the same]
464
  for j in range(gen_length):
465
  current_tok_id = current_generated_tokens[j].item()
 
466
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
467
 
468
  try:
469
- # Use replace to handle potential bytes rendering issues
470
  decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
471
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
472
  except Exception:
@@ -482,17 +486,25 @@ def generate_dream_response(
482
  else: # Token was already revealed
483
  color = "#6699CC" # Light Blue
484
 
485
- should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
486
- (EOS_ID is not None and current_tok_id == EOS_ID)
487
  if should_hide and previous_tok_id == current_tok_id:
488
  token_to_display = "" # Hide by making empty
489
  color = None # No color for hidden
490
 
491
- if token_to_display:
492
  vis_data.append((token_to_display, color))
493
 
494
- previous_tokens_vis = current_generated_tokens # Update for next step
 
495
 
 
496
  intermediate_response_tokens = x[0, prompt_length:]
497
  intermediate_response_text = tokenizer.decode(
498
  intermediate_response_tokens,
@@ -500,6 +512,11 @@ def generate_dream_response(
500
  clean_up_tokenization_spaces=True
501
  ).strip()
502
503
  yield history_copy, vis_data, intermediate_response_text
504
  time.sleep(visualization_delay)
505
 
@@ -514,11 +531,14 @@ def generate_dream_response(
514
  skip_special_tokens=True,
515
  clean_up_tokenization_spaces=True
516
  ).strip()
517
- history_copy[-1][1] = final_response_text
518
519
  final_generated_tokens = x[0, prompt_length:].cpu()
520
  vis_data_final = []
521
- # [Keep final visualization formatting logic the same]
522
  for j in range(gen_length):
523
  current_tok_id = final_generated_tokens[j].item()
524
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
@@ -532,24 +552,29 @@ def generate_dream_response(
532
  if current_tok_id == MASK_ID: color = "#444444"
533
  elif previous_tok_id == MASK_ID: color = "#66CC66"
534
  else: color = "#6699CC"
535
- should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
536
- (EOS_ID is not None and current_tok_id == EOS_ID)
537
  if should_hide and previous_tok_id == current_tok_id:
538
  token_to_display = ""; color = None
539
  if token_to_display: vis_data_final.append((token_to_display, color))
540
 
 
541
  yield history_copy, vis_data_final, final_response_text
542
  print("Visualization streaming complete.")
543
 
544
  except Exception as e:
545
- print(f"Error during generation or processing: {e}")
546
- import traceback
547
  traceback.print_exc()
 
548
  yield history_copy, [("Error during generation.", "red")], ""
549
  return
550
 
551
 
552
- # --- Gradio UI (No changes needed here) ---
553
  css = '''
554
  .category-legend{display:none}
555
  button{min-height: 60px}
@@ -562,8 +587,10 @@ def create_chatbot_demo():
562
  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]" # Note: Link might be hypothetical
563
  )
564
 
 
565
  _chat_history_store = gr.State([]) # Hidden state to store actual history list
566
 
 
567
  with gr.Row():
568
  with gr.Column(scale=3):
569
  chatbot_ui = gr.Chatbot(
@@ -594,15 +621,15 @@ def create_chatbot_demo():
594
  label="Denoising Process Visualization",
595
  combine_adjacent=False,
596
  show_legend=True,
597
- interactive=False
598
  )
599
  response_text_display = gr.Textbox(
600
  label="Generated Response",
601
  interactive=False,
602
- lines=5,
603
- visible=False
604
  )
605
 
 
606
  with gr.Accordion("Generation Settings", open=False):
607
  with gr.Row():
608
  gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
@@ -611,58 +638,92 @@ def create_chatbot_demo():
611
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Temperature (0 = greedy)")
612
  alg_temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Remasking Temp (Confidence Algs)")
613
  with gr.Row():
614
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (0 disables)")
615
- top_k = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (0 disables)")
 
616
  with gr.Row():
617
  remasking_strategy = gr.Radio(choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Remasking Strategy (Algorithm)")
618
  with gr.Row():
619
- visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.0, step=0.01, label="Visualization Delay (seconds)")
620
 
 
621
  clear_btn = gr.Button("Clear Conversation")
622
 
 
 
623
  def add_user_message_to_history(message: str, history_store: List[List[Optional[str]]]):
 
624
  if not message.strip():
625
  gr.Warning("Please enter a message.")
626
- return history_store, history_store, "", [], ""
627
- history_store.append([message, None])
628
- return history_store, history_store, "", [], ""
 
 
 
 
 
 
629
 
630
  def clear_conversation():
631
- return [], [], "", [], ""
 
 
 
 
 
632
 
 
633
  generation_inputs = [
634
  _chat_history_store, gen_length, steps, constraints_input,
635
  temperature, top_p, top_k, remasking_strategy, alg_temp,
636
  visualization_delay
637
  ]
 
638
  generation_outputs = [chatbot_ui, output_vis, response_text_display]
639
 
 
 
 
 
 
 
 
 
 
640
  submit_listener = user_input.submit(
641
  fn=add_user_message_to_history,
642
  inputs=[user_input, _chat_history_store],
643
- outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
 
644
  ).then(
645
  fn=generate_dream_response,
646
- inputs=generation_inputs,
647
- outputs=generation_outputs,
648
- show_progress="hidden"
 
649
  )
650
 
 
651
  click_listener = send_btn.click(
652
  fn=add_user_message_to_history,
653
  inputs=[user_input, _chat_history_store],
654
- outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
 
655
  ).then(
656
  fn=generate_dream_response,
657
- inputs=generation_inputs,
658
- outputs=generation_outputs,
659
- show_progress="hidden"
 
660
  )
661
 
 
662
  clear_btn.click(
663
  clear_conversation,
664
  inputs=[],
665
- outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
 
666
  )
667
 
668
  return demo
@@ -670,4 +731,5 @@ def create_chatbot_demo():
670
  # --- Launch ---
671
  if __name__ == "__main__":
672
  demo = create_chatbot_demo()
673
- demo.queue().launch(debug=True, share=False)
9
  import re
10
  from typing import List, Dict, Tuple, Optional
11
  import torch.distributions as dists # Added import
12
+ import traceback # For printing exceptions
13
 
14
  # --- START: Copied Helper functions from generation_utils.py ---
15
+ # These are needed because we are reimplementing the sampling loop locally.
16
+
17
  def top_p_logits(logits, top_p=None):
18
  """ Applies top-p filtering to logits. """
19
  if top_p is None or top_p >= 1.0:
 
35
  if top_k is None or top_k <= 0:
36
  return logits
37
  top_k = min(top_k, logits.size(-1)) # Safety check
38
+ if top_k == logits.size(-1): # Avoid unnecessary computation if k is full size
39
+ return logits
40
  # Remove all tokens with a probability less than the last token of the top-k
41
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
42
  logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
 
48
  # Prevent division by zero or negative temperatures
49
  safe_temp = max(temperature, 1e-6)
50
  logits = logits / safe_temp
51
+ if top_p is not None and 0.0 < top_p < 1.0: # Apply top_p if valid (and not disabled)
52
  logits = top_p_logits(logits, top_p)
53
  if top_k is not None and top_k > 0: # Apply top_k if valid
54
  logits = top_k_logits(logits, top_k)
55
 
56
+ # Ensure logits are not all -inf after filtering, if so, assign uniform probability.
57
+ is_all_neg_inf = torch.all(logits <= torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
 
58
  if torch.any(is_all_neg_inf):
59
  # print("Warning: All logits became -inf after filtering. Assigning uniform probabilities.")
60
+ uniform_logits = torch.zeros_like(logits) # Uniform logits (zeros before softmax)
61
  logits = torch.where(is_all_neg_inf, uniform_logits, logits)
62
 
63
  probs = torch.softmax(logits, dim=-1)
64
 
65
  # Clamp probabilities to avoid NaNs in sampling, ensure they sum to 1
66
  probs = torch.clamp(probs, min=0.0) # Ensure non-negative
67
+ prob_sum_for_norm = probs.sum(dim=-1, keepdim=True)
68
+ # Use a tolerance check for division
69
+ safe_prob_sum_for_norm = torch.where(prob_sum_for_norm > 1e-12, prob_sum_for_norm, torch.ones_like(prob_sum_for_norm))
70
+ probs = probs / safe_prob_sum_for_norm # Re-normalize with safe denominator
71
  probs = torch.nan_to_num(probs, nan=0.0) # Handle any remaining NaNs
72
 
 
73
  if temperature > 0:
74
  try:
75
+ # Ensure probs sum to 1 before sampling
76
+ probs_sum_check = probs.sum(dim=-1)
77
+ if not torch.all(torch.isclose(probs_sum_check, torch.ones_like(probs_sum_check))):
78
+ # print(f"Warning: Probs do not sum to 1 before sampling ({probs_sum_check}). Re-normalizing.")
79
+ probs = probs / probs.sum(dim=-1, keepdim=True) # Final normalization attempt
80
+
81
  x0 = dists.Categorical(probs=probs).sample()
82
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
83
  except Exception as e: # Catch broader exceptions during sampling
 
90
  sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
91
  # Ensure there are at least 2 probabilities to compare
92
  top1_probs = sorted_probs[..., 0]
93
+ top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else torch.zeros_like(top1_probs) # Use 0 if only one prob
94
  confidence = top1_probs - top2_probs
95
 
96
  if neg_entropy:
97
+ epsilon = torch.finfo(probs.dtype).eps # Use dtype's epsilon
98
  # Ensure probs are > 0 for log
99
+ log_probs = torch.log(torch.clamp(probs, min=epsilon)) # Clamp before log
100
+ confidence = torch.sum(probs * log_probs, dim=-1) # This is negative entropy
101
 
102
  # Ensure confidence is not NaN
103
  confidence = torch.nan_to_num(confidence, nan=0.0)
 
106
  # --- END: Copied Helper functions ---
107
 
108
 
109
+ # --- Model Loading and Constants ---
110
  # Load model configuration to get special token IDs
111
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
112
  # Use AutoModel for the base model loading, relying on trust_remote_code=True
 
150
  try:
151
  IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
152
  IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
153
+ if IM_START_ID is not None: SPECIAL_TOKEN_IDS.add(IM_START_ID)
154
+ if IM_END_ID is not None: SPECIAL_TOKEN_IDS.add(IM_END_ID)
155
  except KeyError:
156
  print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
157
  IM_START_ID = None
158
  IM_END_ID = None
159
 
160
 
161
+ # --- App Helper Functions ---
162
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
163
+ """ Parses constraints. """
 
 
 
 
164
  constraints = {}
165
  if not constraints_text:
166
  return constraints
167
 
168
+ # Simple split on comma, assumes format 'pos:word, pos:word'
169
  parts = constraints_text.split(',')
170
+
171
  for part in parts:
172
+ part = part.strip()
173
  if ':' not in part:
174
  continue
175
  pos_str, word = part.split(':', 1)
176
  try:
177
  pos = int(pos_str.strip())
178
+ word = word.strip()
179
  token_ids = []
180
  if word: # Only encode if word is not empty
181
  # Add space prefix automatically if pos > 0 and word doesn't start with space
 
201
  """ Formats chat history for the template. """
202
  messages = []
203
  for user_msg, assistant_msg in history:
204
+ if user_msg is not None: # Check for None explicitly
205
  messages.append({"role": "user", "content": user_msg})
206
+ # Add assistant message only if it exists (it won't for the last turn before generation)
207
+ if assistant_msg is not None:
208
  messages.append({"role": "assistant", "content": assistant_msg})
209
  return messages
210
 
 
216
  current_step: Optional[int] = None # For logging/debugging
217
  ) -> torch.Tensor:
218
  """ Applies constraints directly to the state tensor `x`. """
219
+ modified_x = x.clone() # Work on a copy
 
220
  for rel_pos, word_token_ids in parsed_constraints.items():
221
  abs_start_pos = prompt_length + rel_pos
222
  abs_end_pos = abs_start_pos + len(word_token_ids)
223
 
224
+ # Ensure the constraint fits within the generation length
225
  if abs_start_pos < total_length and abs_end_pos <= total_length:
226
  try:
227
  constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
228
+ # Force the constraint tokens onto the sequence
229
  modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
230
  except IndexError:
231
  print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
 
239
  @spaces.GPU # Decorator for Hugging Face Spaces GPU usage
240
  @torch.no_grad() # Ensure no gradients are computed during generation
241
  def generate_dream_response(
242
+ history: List[List[Optional[str]]], # Receives the latest state from _chat_history_store
243
  gen_length: int,
244
  steps: int,
245
  constraints_text: str,
 
252
  ) -> List[Tuple[str, str]]:
253
  """ Generates text step-by-step and yields visualization states live. """
254
 
255
+ if not history or history[-1][0] is None: # Check if last user message is None or missing
256
+ yield history, [("Internal Error: History state invalid.", "red")], ""
257
  return
258
 
259
  # --- 1. Preparation ---
260
+ # History already contains the latest user message and None for the bot response
261
+ messages_for_template = format_chat_history(history)
262
  parsed_constraints = parse_constraints(constraints_text)
263
 
264
  try:
 
266
  messages_for_template,
267
  return_tensors="pt",
268
  return_dict=True,
269
+ add_generation_prompt=True # Creates the '<|im_start|>assistant\n' prompt
270
  )
271
  input_ids = inputs.input_ids.to(device)
272
+ # Ensure prompt_attention_mask is also on the correct device and handle missing mask
273
  prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
274
  prompt_length = input_ids.shape[1]
275
  except Exception as e:
276
  print(f"Error applying chat template: {e}")
277
+ # Yield current history (with None), error message, empty text
278
  yield history, [("Error preparing input.", "red")], ""
279
  return
280
 
281
  eps = 1e-3
282
+ top_p_val = top_p if top_p is not None and 0.0 < top_p < 1.0 else None
283
  top_k_val = top_k if top_k is not None and top_k > 0 else None
284
+ alg_temp_val = alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] and alg_temp is not None and alg_temp > 0 else None
285
 
286
  # --- 2. Initialize Generation State ---
287
  total_length = prompt_length + gen_length
288
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
289
  x = torch.cat((input_ids, initial_generation_part), dim=1)
290
 
291
+ # Prepare attention mask for SDPA (float format)
292
+ generation_attention_mask = torch.ones((1, gen_length), dtype=prompt_attention_mask.dtype, device=device) # Match dtype
293
+ full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1) # Shape [B, N]
294
 
295
+ attention_mask_for_model = full_attention_mask_long.to(model.dtype) # Convert to model's float dtype
296
  large_neg_val = torch.finfo(model.dtype).min
297
  attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
298
+ attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2) # Shape [B, 1, 1, N]
299
+
300
+ # Timesteps
301
  timesteps = torch.linspace(1, eps, steps + 1, device=device)
302
 
303
  # Apply initial constraints
 
306
  # --- 3. Visualization Setup ---
307
  previous_tokens_vis = None
308
  final_response_text = ""
309
+ # Work on a copy of the history list received as input
310
+ history_copy = [list(item) for item in history]
311
 
312
  # --- 4. Initial Yield (Masked State) ---
313
  initial_generated_tokens = x[0, prompt_length:].cpu()
 
318
  vis_data_initial.append((display_token, color))
319
 
320
  previous_tokens_vis = initial_generated_tokens
321
+ # Yield the initial history copy (with None placeholder), initial vis, empty text
322
  yield history_copy, vis_data_initial, ""
323
  time.sleep(visualization_delay)
324
 
 
332
  break
333
 
334
  # --- Model Forward Pass ---
 
335
  outputs = model(
336
  input_ids=x,
337
  attention_mask=attention_mask_for_model, # Pass the [B, 1, 1, N] float mask
338
+ position_ids=None, # Let model compute default positions
339
  use_cache=False,
340
  return_dict=True
341
  )
342
  logits = outputs.logits
 
343
 
344
+ # Align logits with the token positions they predict (logits[t] predicts token[t+1])
345
+ # Shift left, effectively aligning logits[t] with inputs[t]
346
+ logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
347
+
348
+ # Select logits for masked positions
349
+ mask_logits = logits[mask_index] # Shape [num_masked_tokens, V]
350
  if mask_logits.numel() == 0:
351
  print(f"No masked tokens found for logit selection at step {i}. Stopping.")
352
  break
 
354
  # --- Sampling / Remasking Logic ---
355
  t = timesteps[i]
356
  s = timesteps[i + 1]
357
+ # Initialize the update tensor for masked positions with MASK_ID
358
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)
359
 
360
  if alg == 'origin':
 
365
 
366
  if logits_to_sample.numel() > 0:
367
  _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
368
+ # Place sampled tokens into the correct positions within the masked part update
369
  x_new_masked_part[transfer_indices_relative] = sampled_tokens
370
 
371
+ else: # Confidence-based algorithms ('maskgit_plus', 'topk_margin', 'entropy')
372
  use_margin = (alg == 'topk_margin')
373
  use_entropy = (alg == 'entropy')
374
+ # Sample candidates and get confidence for all masked positions
375
  confidence, x0_candidates = sample_tokens(
376
  mask_logits,
377
  temperature=temperature,
 
382
  )
383
 
384
  num_mask_token = mask_logits.shape[0]
385
+ # Calculate target number of tokens to reveal in this step
386
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
387
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
388
 
389
  if number_transfer_tokens > 0:
390
+ # Determine which tokens to reveal based on confidence
391
  num_samples = min(number_transfer_tokens, num_mask_token) # Ensure k <= num_mask_token
392
+ if num_samples > 0:
393
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Initialize empty
394
+ if alg_temp_val is None or alg_temp_val <= 0: # Use top-k confidence sorting
395
+ # Sort by confidence (higher is better, except for entropy where lower is better)
396
+ sort_metric = confidence if alg != 'entropy' else -confidence
397
  # Ensure k is not greater than the number of elements
398
  k_topk = min(num_samples, sort_metric.numel())
399
  if k_topk > 0:
400
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
401
 
402
  else: # Sample based on confidence temperature
403
  # Ensure confidence has elements before processing
404
  if confidence.numel() > 0:
405
  conf_probs = confidence / alg_temp_val
406
  # Handle potential inf/-inf before softmax, ensure non-negative probabilities
407
+ conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9)
408
+ # Clamp to prevent large positive values causing overflow in exp
409
+ conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30) # Softmax is invariant to shift
410
  conf_probs = F.softmax(conf_probs, dim=-1)
411
  conf_probs = torch.clamp(conf_probs, min=0.0) # Ensure non-negative
412
  conf_probs = torch.nan_to_num(conf_probs, nan=0.0) # Handle NaNs
413
 
414
+ # Normalize probabilities if they don't sum to 1 (within tolerance)
415
  prob_sum = conf_probs.sum()
416
  target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
417
  if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0:
418
  safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
419
+ conf_probs = conf_probs / safe_prob_sum
420
 
421
+ # Check if probabilities are valid for multinomial sampling
422
  final_prob_sum_check = conf_probs.sum()
423
  if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
 
424
  try:
425
  transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
426
  except RuntimeError as e:
 
427
  print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
428
+ # Fallback to top-k if multinomial fails
429
  sort_metric = confidence if alg != 'entropy' else -confidence
430
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
431
  if k_multinomial_fallback > 0:
432
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
433
  else: # Handle cases where multinomial is not possible (e.g., bad probabilities)
 
434
  # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
435
  sort_metric = confidence if alg != 'entropy' else -confidence
436
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
437
  if k_multinomial_fallback > 0:
438
  _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
439
 
440
+ # Apply the transfer using the selected indices, with safety checks
441
  if transfer_indices_relative.numel() > 0:
442
+ # Bounds check before indexing
443
+ max_cand_idx = x0_candidates.shape[0] - 1
444
+ max_mask_idx = x_new_masked_part.shape[0] - 1
445
+ valid_indices_mask = (transfer_indices_relative >= 0) & \
446
+ (transfer_indices_relative <= max_cand_idx) & \
447
+ (transfer_indices_relative <= max_mask_idx)
448
+ valid_transfer_indices = transfer_indices_relative[valid_indices_mask]
449
 
450
  if valid_transfer_indices.numel() > 0:
451
+ x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()
452
+ # else:
453
+ # if transfer_indices_relative.numel() > 0: # Only warn if there were indices initially
454
+ # print(f"Warning step {i}: No valid transfer indices after bounds check.")
455
+
456
 
457
  # Update the global state `x` only at the masked positions
458
  x[mask_index] = x_new_masked_part
459
 
460
  # --- Apply Constraints ---
461
+ # Constraints should be applied *after* sampling/revealing tokens for the step
462
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)
463
 
464
  # --- Yield Visualization ---
465
+ current_generated_tokens = x[0, prompt_length:].cpu() # Get generated part, move to CPU
466
  vis_data = []
 
467
  for j in range(gen_length):
468
  current_tok_id = current_generated_tokens[j].item()
469
+ # Ensure previous_tokens_vis exists and index is valid
470
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
471
 
472
  try:
473
+ # Use replace='�' to handle potential bytes rendering issues in Gradio HighlightedText
474
  decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
475
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
476
  except Exception:
 
486
  else: # Token was already revealed
487
  color = "#6699CC" # Light Blue
488
 
489
+ # Hide special tokens (PAD/EOS) if they were already revealed (LLaDA effect)
490
+ # Ensure PAD_ID and EOS_ID are not None before checking
491
+ should_hide = False
492
+ if PAD_ID is not None and current_tok_id == PAD_ID: should_hide = True
493
+ if EOS_ID is not None and current_tok_id == EOS_ID: should_hide = True
494
+ # Special check: If PAD and EOS are the same, only hide if it's that ID
495
+ if PAD_ID == EOS_ID and PAD_ID is not None and current_tok_id == PAD_ID: should_hide = True
496
+
497
  if should_hide and previous_tok_id == current_tok_id:
498
  token_to_display = "" # Hide by making empty
499
  color = None # No color for hidden
500
 
501
+ if token_to_display: # Avoid adding empty strings if hiding
502
  vis_data.append((token_to_display, color))
503
 
504
+ # Update previous state for the next iteration's color logic
505
+ previous_tokens_vis = current_generated_tokens
506
 
507
+ # Decode intermediate response text using the *current* state x
508
  intermediate_response_tokens = x[0, prompt_length:]
509
  intermediate_response_text = tokenizer.decode(
510
  intermediate_response_tokens,
 
512
  clean_up_tokenization_spaces=True
513
  ).strip()
514
 
515
+ # Update the *copy* of the history with the intermediate text for display purposes
516
+ if history_copy: # Ensure history_copy is not empty
517
+ history_copy[-1][1] = intermediate_response_text # Update the None placeholder
518
+
519
+ # Yield the updated history copy, current vis, and intermediate text
520
  yield history_copy, vis_data, intermediate_response_text
521
  time.sleep(visualization_delay)
522
 
 
531
  skip_special_tokens=True,
532
  clean_up_tokenization_spaces=True
533
  ).strip()
 
534
 
535
+ # Update the final history copy *definitively*
536
+ if history_copy:
537
+ history_copy[-1][1] = final_response_text
538
+
539
+ # Format the final visualization state
540
  final_generated_tokens = x[0, prompt_length:].cpu()
541
  vis_data_final = []
 
542
  for j in range(gen_length):
543
  current_tok_id = final_generated_tokens[j].item()
544
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
 
552
  if current_tok_id == MASK_ID: color = "#444444"
553
  elif previous_tok_id == MASK_ID: color = "#66CC66"
554
  else: color = "#6699CC"
555
+
556
+ should_hide = False
557
+ if PAD_ID is not None and current_tok_id == PAD_ID: should_hide = True
558
+ if EOS_ID is not None and current_tok_id == EOS_ID: should_hide = True
559
+ if PAD_ID == EOS_ID and PAD_ID is not None and current_tok_id == PAD_ID: should_hide = True
560
+
561
  if should_hide and previous_tok_id == current_tok_id:
562
  token_to_display = ""; color = None
563
  if token_to_display: vis_data_final.append((token_to_display, color))
564
 
565
+ # Yield the final history, final visualization, and final text
566
  yield history_copy, vis_data_final, final_response_text
567
  print("Visualization streaming complete.")
568
 
569
  except Exception as e:
570
+ print(f"Error during generation or processing loop: {e}")
 
571
  traceback.print_exc()
572
+ # Yield the history as it was before the error, error vis, empty text
573
  yield history_copy, [("Error during generation.", "red")], ""
574
  return
575
 
576
 
577
+ # --- Gradio UI ---
578
  css = '''
579
  .category-legend{display:none}
580
  button{min-height: 60px}
 
587
  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]" # Note: Link might be hypothetical
588
  )
589
 
590
+ # STATE MANAGEMENT
591
  _chat_history_store = gr.State([]) # Hidden state to store actual history list
592
 
593
+ # UI COMPONENTS
594
  with gr.Row():
595
  with gr.Column(scale=3):
596
  chatbot_ui = gr.Chatbot(
 
621
  label="Denoising Process Visualization",
622
  combine_adjacent=False,
623
  show_legend=True,
624
+ interactive=False,
625
  )
626
  response_text_display = gr.Textbox(
627
  label="Generated Response",
628
  interactive=False,
629
+ lines=5
 
630
  )
631
 
632
+ # Advanced generation settings
633
  with gr.Accordion("Generation Settings", open=False):
634
  with gr.Row():
635
  gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
 
638
  temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Temperature (0 = greedy)")
639
  alg_temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Remasking Temp (Confidence Algs)")
640
  with gr.Row():
641
+ # Adjusted label for clarity on disabling top_p
642
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (>0 & <1 to enable)")
643
+ top_k = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (>0 to enable)")
644
  with gr.Row():
645
  remasking_strategy = gr.Radio(choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Remasking Strategy (Algorithm)")
646
  with gr.Row():
647
+ visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (seconds)")
648
 
649
+ # Clear button
650
  clear_btn = gr.Button("Clear Conversation")
651
 
652
+ # --- Event Handlers ---
653
+
654
  def add_user_message_to_history(message: str, history_store: List[List[Optional[str]]]):
655
+ """Adds user message TO STATE, clears input, prepares for bot response."""
656
  if not message.strip():
657
  gr.Warning("Please enter a message.")
658
+ # Return unchanged state, but clear inputs/outputs for next step
659
+ # Outputs: _chat_history_store, user_input, output_vis, response_text_display
660
+ return history_store, message, [], "" # Return original message to keep it in input if invalid
661
+
662
+ # Add user message with placeholder for bot response TO THE STATE
663
+ history_store.append([message.strip(), None]) # Ensure message is stripped
664
+ # Return updated history store, clear input box, clear vis, clear response text
665
+ # Outputs: _chat_history_store, user_input, output_vis, response_text_display
666
+ return history_store, "", [], "" # Clear user_input only on success
667
 
668
  def clear_conversation():
669
+ """Clears the chat history state and UI elements."""
670
+ # Outputs: _chat_history_store, chatbot_ui, user_input, output_vis, response_text_display
671
+ return [], [], "", [], "" # Clear everything
672
+
673
+
674
+ # --- Connect UI elements ---
675
 
676
+ # Inputs for the generation function
677
  generation_inputs = [
678
  _chat_history_store, gen_length, steps, constraints_input,
679
  temperature, top_p, top_k, remasking_strategy, alg_temp,
680
  visualization_delay
681
  ]
682
+ # Outputs for the generation function (yields history, vis_data, text)
683
  generation_outputs = [chatbot_ui, output_vis, response_text_display]
684
 
685
+ # Outputs for add_user_message_to_history
686
+ add_message_outputs = [
687
+ _chat_history_store, # Update state
688
+ user_input, # Clear input (or return original if invalid)
689
+ output_vis, # Clear visualization
690
+ response_text_display # Clear response text
691
+ ]
692
+
693
+ # Handle Textbox Submission (Enter key)
694
  submit_listener = user_input.submit(
695
  fn=add_user_message_to_history,
696
  inputs=[user_input, _chat_history_store],
697
+ outputs=add_message_outputs, # Step 1: Update state, clear inputs/vis/response
698
+ queue=True # Ensure intermediate steps are processed
699
  ).then(
700
  fn=generate_dream_response,
701
+ inputs=generation_inputs, # Takes the updated state
702
+ outputs=generation_outputs, # Step 2: Generate response and stream history/vis/text to UI
703
+ show_progress="hidden", # Hide default progress as we have live vis
704
+ queue=True # Ensure generation runs in the queue
705
  )
706
 
707
+ # Handle Send Button Click
708
  click_listener = send_btn.click(
709
  fn=add_user_message_to_history,
710
  inputs=[user_input, _chat_history_store],
711
+ outputs=add_message_outputs, # Step 1: Update state, clear inputs/vis/response
712
+ queue=True # Ensure intermediate steps are processed
713
  ).then(
714
  fn=generate_dream_response,
715
+ inputs=generation_inputs, # Takes the updated state
716
+ outputs=generation_outputs, # Step 2: Generate response and stream history/vis/text to UI
717
+ show_progress="hidden", # Hide default progress as we have live vis
718
+ queue=True # Ensure generation runs in the queue
719
  )
720
 
721
+ # Clear Button Action
722
  clear_btn.click(
723
  clear_conversation,
724
  inputs=[],
725
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display],
726
+ queue=False # Clearing can be immediate
727
  )
728
 
729
  return demo
 
731
  # --- Launch ---
732
  if __name__ == "__main__":
733
  demo = create_chatbot_demo()
734
+ # Use queue for handling multiple users and streaming
735
+ demo.queue().launch(debug=True, share=False) # Set share=True for public link
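
For reference, a stripped-down sketch of the two-step event chain wired up above (append the user turn to state, then stream the generator's yields into the Chatbot); the component names and the fake generator are stand-ins, not the app's exact code:

import gradio as gr

def add_user_message(message, history):
    history.append([message, None])       # placeholder for the assistant reply
    return history, ""                    # updated state, cleared textbox

def stream_reply(history):
    for chunk in ["Hel", "Hello", "Hello!"]:
        history[-1][1] = chunk
        yield history                     # each yield updates the Chatbot live

with gr.Blocks() as demo:
    state = gr.State([])
    chat = gr.Chatbot()
    box = gr.Textbox()
    box.submit(add_user_message, [box, state], [state, box]).then(
        stream_reply, [state], [chat]
    )

if __name__ == "__main__":
    demo.queue().launch()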