multimodalart (HF Staff) committed · verified
Commit badae07 · Parent: 4474e7a

Update app.py

Files changed (1): app.py (+186, -369)

app.py CHANGED
@@ -11,8 +11,7 @@ from typing import List, Dict, Tuple, Optional
11
  import torch.distributions as dists # Added import
12
 
13
  # --- START: Copied Helper functions from generation_utils.py ---
14
- # These are needed because we are reimplementing the sampling loop locally.
15
-
16
  def top_p_logits(logits, top_p=None):
17
  """ Applies top-p filtering to logits. """
18
  if top_p is None or top_p >= 1.0:
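The bodies of top_p_logits and top_k_logits are truncated by the hunk above. For orientation only, a minimal sketch of the standard sort-and-mask filtering such helpers typically implement (an assumption, not the committed code):

    import torch

    def top_p_logits_sketch(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Nucleus filtering: keep the smallest prefix of sorted tokens whose cumulative mass reaches top_p."""
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right so the threshold-crossing token stays
        remove[..., 0] = False                      # always keep the most probable token
        mask = remove.scatter(-1, sorted_idx, remove)
        return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

    def top_k_logits_sketch(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        """Keep only the top_k highest logits per row; everything else gets dtype-min."""
        kth = torch.topk(logits, k=min(top_k, logits.shape[-1]), dim=-1).values[..., -1, None]
        return logits.masked_fill(logits < kth, torch.finfo(logits.dtype).min)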
@@ -42,39 +41,38 @@ def top_k_logits(logits, top_k=None):
42
  def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
43
  """ Samples tokens based on logits and calculates confidence. """
44
  if temperature > 0:
45
- logits = logits / temperature
 
 
46
  if top_p is not None and top_p < 1.0: # Apply top_p if valid
47
  logits = top_p_logits(logits, top_p)
48
  if top_k is not None and top_k > 0: # Apply top_k if valid
49
  logits = top_k_logits(logits, top_k)
50
 
51
  # Ensure logits are not all -inf after filtering, if so, sample uniformly? Or handle error.
52
- # For simplicity, assume valid logits after filtering. If not, sampling might fail.
53
- # Add a small epsilon to prevent log(0) or issues with all -inf logits
54
- logits = torch.where(logits == torch.finfo(logits.dtype).min, torch.full_like(logits, -1e9), logits)
55
-
 
 
56
 
57
  probs = torch.softmax(logits, dim=-1)
58
59
  if temperature > 0:
60
  try:
61
- # Check for NaNs or Infs in probs before sampling
62
- if torch.isnan(probs).any() or torch.isinf(probs).any():
63
- print("Warning: NaN or Inf detected in probabilities before sampling. Attempting to recover.")
64
- # Simple recovery: Sample from uniform distribution or highest prob token
65
- probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
66
- if probs.sum() == 0: # If all probabilities became zero
67
- print("Warning: All probabilities became zero. Sampling uniformly.")
68
- probs = torch.ones_like(probs) / probs.shape[-1]
69
- else:
70
- probs = probs / probs.sum(dim=-1, keepdim=True) # Re-normalize
71
-
72
  x0 = dists.Categorical(probs=probs).sample()
73
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
74
  except Exception as e: # Catch broader exceptions during sampling
75
  print(f"Warning: Error during Categorical sampling: {e}. Falling back to argmax.")
76
  confidence, x0 = probs.max(dim=-1)
77
- else:
78
  confidence, x0 = probs.max(dim=-1)
79
 
80
  if margin_confidence:
@@ -86,14 +84,18 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confid
86
 
87
  if neg_entropy:
88
  epsilon = 1e-10
 
89
  log_probs = torch.log(probs + epsilon)
90
  confidence = torch.sum(probs * log_probs, dim=-1) # Should be negative entropy
91
 
92
- return confidence, x0
 
93
 
 
94
  # --- END: Copied Helper functions ---
95
 
96
 
 
97
  # Load model configuration to get special token IDs
98
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
99
  # Use AutoModel for the base model loading, relying on trust_remote_code=True
@@ -113,7 +115,7 @@ model = AutoModel.from_pretrained(
113
  model_path,
114
  torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32, # Use bfloat16 only on CUDA
115
  trust_remote_code=True,
116
- # attn_implementation="flash_attention_2" # Optional: Speed up if FA2 is available
117
  )
118
  model = model.to(device).eval()
119
  print("Model loaded.")
@@ -123,27 +125,17 @@ MASK_TOKEN = tokenizer.mask_token
123
  MASK_ID = tokenizer.mask_token_id # Use tokenizer's mask_token_id directly
124
  PAD_ID = tokenizer.pad_token_id # Use tokenizer's pad_token_id
125
  EOS_ID = tokenizer.eos_token_id # Use tokenizer's eos_token_id
126
- # Use attributes from loaded config/tokenizer objects
127
- # MASK_ID = config.mask_token_id # Can use this too, should be consistent
128
- # PAD_ID = config.pad_token_id
129
- # EOS_ID = config.eos_token_id
130
 
131
- # Ensure mask_token_id is correctly identified
132
  if MASK_ID is None:
133
  print("Warning: Mask token ID not found in config/tokenizer. Trying to fetch from tokenizer...")
134
- # Try getting from tokenizer directly if config doesn't have it or it's None
135
  mask_token_special = tokenizer.mask_token
136
  if mask_token_special:
137
  MASK_ID = tokenizer.convert_tokens_to_ids(mask_token_special)
138
  print(f"Found MASK_ID from tokenizer: {MASK_ID}")
139
  else:
140
- # Fallback or raise error if still not found
141
  raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")
142
 
143
- # Make sure EOS_ID and PAD_ID are handled correctly; Dream uses the same ID for both
144
  SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
145
- # Add other special tokens defined in tokenizer_config.json if needed for hiding
146
- # Get IDs for im_start, im_end etc. if they should also be hidden/handled specially
147
  try:
148
  IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
149
  IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -156,7 +148,6 @@ except KeyError:
156
 
157
 
158
  # --- Helper Functions ---
159
-
160
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
161
  """
162
  Parse constraints in format: 'position:word, position:word, ...'
@@ -174,43 +165,17 @@ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
174
  continue
175
  pos_str, word = part.split(':', 1)
176
  try:
177
- # Position relative to the start of the *generation*
178
  pos = int(pos_str.strip())
179
  word = word.strip() # Strip whitespace from word
180
- # Tokenize the word - Dream tokenizer handles spaces well typically.
181
- # Let's check if the word starts with a space implicitly or needs one.
182
- # Standard tokenizers often need a space prefix if not at the start.
183
- # Test: tokenizer.encode(" world") vs tokenizer.encode("world")
184
- # Dream often encodes ' world' differently from 'world'.
185
- # Assume we want the word as it would appear mid-sentence unless pos is 0.
186
- token_ids = tokenizer.encode(word, add_special_tokens=False)
187
- # Add space prefix if needed based on position? This is tricky.
188
- # Let's assume the user provides the word how they want it tokenized,
189
- # potentially including a leading space if necessary.
190
- # Example: " 5: word" might be tokenized differently than "5:word".
191
- # Simplest approach: Tokenize exactly what the user provided.
192
- # Let's refine: add space prefix automatically if pos > 0, unless word already starts with space?
193
- # This seems more robust for typical usage.
194
- if pos > 0 and not word.startswith(" "):
195
- token_ids_with_space = tokenizer.encode(" " + word, add_special_tokens=False)
196
- # Check if adding space actually changes tokenization significantly
197
- # Heuristic: if the first token ID changes, use the space-prefixed version.
198
- first_token_no_space = tokenizer.encode(word, add_special_tokens=False)[0] if token_ids else None
199
- first_token_with_space = tokenizer.encode(" " + word, add_special_tokens=False)[0] if token_ids_with_space else None
200
-
201
- if first_token_no_space != first_token_with_space:
202
- token_ids = token_ids_with_space
203
- # If tokenization doesn't change much, maybe stick to original? Less surprising.
204
- # Let's stick to adding the space if pos > 0 for consistency, like original code.
205
- token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
206
-
207
- elif pos == 0:
208
- token_ids = tokenizer.encode(word, add_special_tokens=False)
209
-
210
 
211
  if token_ids and pos >= 0:
212
  constraints[pos] = token_ids
213
- elif not token_ids:
214
  print(f"Warning: Could not tokenize constraint word '{word}'")
215
  except ValueError:
216
  print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
@@ -219,26 +184,16 @@ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
219
  print(f"Warning: Error processing constraint '{part}': {e}")
220
  continue
221
 
222
- print(f"Parsed constraints: {constraints}") # Debugging
223
  return constraints
224
 
225
 
226
  def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
227
- """
228
- Format chat history for the Dream model's chat template.
229
-
230
- Args:
231
- history: List of [user_message, assistant_message] pairs.
232
- The last assistant_message might be None.
233
-
234
- Returns:
235
- Formatted list of message dictionaries for tokenizer.apply_chat_template.
236
- """
237
  messages = []
238
  for user_msg, assistant_msg in history:
239
- if user_msg: # Defensive check
240
  messages.append({"role": "user", "content": user_msg})
241
- # Add assistant message only if it exists (it won't for the last turn before generation)
242
  if assistant_msg:
243
  messages.append({"role": "assistant", "content": assistant_msg})
244
  return messages
@@ -250,19 +205,17 @@ def apply_constraints_to_state(
250
  parsed_constraints: Dict[int, List[int]],
251
  current_step: Optional[int] = None # For logging/debugging
252
  ) -> torch.Tensor:
253
- """Applies constraints directly to the state tensor `x`."""
254
- modified_x = x.clone() # Work on a copy to avoid modifying original if needed elsewhere
 
255
  for rel_pos, word_token_ids in parsed_constraints.items():
256
  abs_start_pos = prompt_length + rel_pos
257
  abs_end_pos = abs_start_pos + len(word_token_ids)
258
 
259
- # Ensure the constraint fits within the generation length
260
  if abs_start_pos < total_length and abs_end_pos <= total_length:
261
  try:
262
  constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
263
- # Force the constraint tokens onto the sequence
264
  modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
265
- # print(f"Debug (Step {current_step}): Applied constraint {tokenizer.decode(word_token_ids)} at pos {rel_pos}") # Debug
266
  except IndexError:
267
  print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
268
  except Exception as e:
@@ -286,27 +239,7 @@ def generate_dream_response(
286
  alg_temp: Optional[float],
287
  visualization_delay: float
288
  ) -> List[Tuple[str, str]]:
289
- """
290
- Generates text using the Dream model step-by-step and yields visualization states live.
291
-
292
- Args:
293
- history: Chat history.
294
- gen_length: Max new tokens to generate.
295
- steps: Number of diffusion steps.
296
- constraints_text: User-provided constraints string.
297
- temperature: Sampling temperature.
298
- top_p: Top-p sampling nucleus. Clamp to < 1.0 or None.
299
- top_k: Top-k sampling. Clamp to > 0 or None.
300
- alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
301
- alg_temp: Temperature for confidence-based algorithms.
302
- visualization_delay: Delay between visualization steps.
303
-
304
- Yields:
305
- Tuple[List[List[Optional[str]]], List[Tuple[str, Optional[str]]], str]:
306
- - Updated history (may be intermediate until final response)
307
- - Visualization data for HighlightedText for the current step
308
- - Intermediate or Final response text (yielded repeatedly)
309
- """
310
 
311
  if not history or not history[-1][0]:
312
  yield history, [("No input message found.", "red")], ""
@@ -315,74 +248,62 @@ def generate_dream_response(
315
  # --- 1. Preparation ---
316
  last_user_message = history[-1][0]
317
  messages_for_template = format_chat_history(history) # Includes the latest user message
 
318
 
319
- # Parse constraints relative to the *generated* sequence
320
- parsed_constraints = parse_constraints(constraints_text) # Dict[rel_pos, List[token_id]]
321
-
322
- # Prepare inputs using the chat template
323
  try:
324
  inputs = tokenizer.apply_chat_template(
325
  messages_for_template,
326
  return_tensors="pt",
327
  return_dict=True,
328
- add_generation_prompt=True # Important for instruct models
329
  )
330
  input_ids = inputs.input_ids.to(device)
331
- prompt_attention_mask = inputs.attention_mask.to(device) # Mask for the prompt part
 
332
  prompt_length = input_ids.shape[1]
333
  except Exception as e:
334
  print(f"Error applying chat template: {e}")
335
  yield history, [("Error preparing input.", "red")], ""
336
  return
337
 
338
- # --- Config parameters for the loop ---
339
- eps = 1e-3 # Default from DreamGenerationConfig, make configurable if needed
340
- # Ensure top_p and top_k have valid values for filtering functions
341
- top_p_val = top_p if top_p is not None and top_p < 1.0 else None
342
  top_k_val = top_k if top_k is not None and top_k > 0 else None
343
- alg_temp_val = alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] else None
344
 
345
  # --- 2. Initialize Generation State ---
346
  total_length = prompt_length + gen_length
347
-
348
- # Initial state: prompt + MASK tokens
349
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
350
  x = torch.cat((input_ids, initial_generation_part), dim=1)
351
 
352
- # Prepare full attention mask (assuming full attention over generated part initially)
353
  generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
354
- full_attention_mask = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)
355
-
356
- # Check if model needs specific attention mask format (e.g., causal for prompt?)
357
- # The original `diffusion_generate` handles this internally. Replicating requires care.
358
- # Based on `_sample`, it prepares a broadcastable mask if padding exists, else uses "full".
359
- # Let's assume "full" attention is okay for Dream's purpose here, as mask tokens don't depend on future masks.
360
- # If the base model *requires* causal masking internally even with diffusion, this might need adjustment.
361
- # For simplicity, using a full mask (ones) over the whole sequence.
362
- # The model's internal attention should handle causality if needed.
363
- # Let's stick to the simpler full mask preparation from the original code when no padding.
364
- if torch.any(full_attention_mask == 0): # Handle padding if present (shouldn't be with template?)
365
- tok_idx = full_attention_mask.long().cumsum(-1) - 1
366
- tok_idx.masked_fill_(full_attention_mask == 0, 0) # Use 0 for padding index? Or 1? Check original. Original used 1.
367
- tok_idx.masked_fill_(full_attention_mask == 0, 1)
368
- attention_mask_for_model = torch.logical_and(
369
- full_attention_mask.unsqueeze(1).unsqueeze(-2),
370
- full_attention_mask.unsqueeze(1).unsqueeze(-1),
371
- ) # Shape [B, 1, N, N]
372
- else:
373
- tok_idx = None
374
- attention_mask_for_model = None # Let the model handle full attention if mask is None
375
-
376
- # Timesteps for diffusion
377
  timesteps = torch.linspace(1, eps, steps + 1, device=device)
378
 
379
- # Apply initial constraints (before first step)
380
- x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1) # Step -1 for initial
381
 
382
  # --- 3. Visualization Setup ---
383
- previous_tokens_vis = None # Keep track of the previous step's state for coloring
384
- final_response_text = "" # Store the final decoded text
385
- history_copy = [list(item) for item in history] # Make a mutable copy
386
 
387
  # --- 4. Initial Yield (Masked State) ---
388
  initial_generated_tokens = x[0, prompt_length:].cpu()
@@ -393,84 +314,53 @@ def generate_dream_response(
393
  vis_data_initial.append((display_token, color))
394
 
395
  previous_tokens_vis = initial_generated_tokens
396
- yield history_copy, vis_data_initial, "" # Yield initial state
397
  time.sleep(visualization_delay)
398
 
399
  # --- 5. Step-by-Step Diffusion Loop ---
400
  try:
401
  start_time = time.time()
402
  for i in range(steps):
403
- # --- Model Forward Pass ---
404
- mask_index = (x == MASK_ID) # Find masks in the *current* state x
405
- if not mask_index.any(): # Stop if no masks left
406
  print(f"No mask tokens left at step {i}. Stopping early.")
407
  break
408
 
409
- # print(f"Step {i}: Input shape {x.shape}, Mask sum {mask_index.sum()}") # Debug
410
- # print(f"Step {i}: Input tokens (first/last 10): {x[0, :10].tolist()} ... {x[0, -10:].tolist()}") # Debug
411
-
412
- # Call the model - ensure attention mask format is correct
413
- # The model forward expects `attention_mask` usually of shape [B, N] or broadcastable.
414
- # If we use `attention_mask_for_model = None`, it implies full attention.
415
- # If we computed `attention_mask_for_model` as [B, 1, N, N], pass that.
416
- # Let's try passing the [B, N] mask and let the model handle broadcasting/causality internally.
417
  outputs = model(
418
  input_ids=x,
419
- attention_mask=full_attention_mask, # Pass the [B, N] mask
420
- position_ids=None, # Let model compute default positions
421
- use_cache=False, # No cache needed for diffusion steps
422
  return_dict=True
423
  )
424
  logits = outputs.logits
 
425
 
426
- # Shift logits like in original code? Check `generation_utils.py`.
427
- # Yes, `logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)`
428
- # This seems to align logits with the *previous* token's prediction. Is this correct for diffusion?
429
- # Let's assume the original code did this for a reason, perhaps related to how the model was trained or expects inputs.
430
- # Update: Looking at standard LM forward pass, logits[t] predicts token[t+1].
431
- # The shift aligns logits[t] with token[t]. Let's keep it.
432
- logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
433
-
434
-
435
- # Select logits for masked positions
436
- # Ensure mask_index has the same batch dimension size as logits
437
- # mask_index shape is [B, N], logits shape is [B, N, V]
438
- # We need to select elements from the last dim of logits where mask is True
439
- mask_logits = logits[mask_index] # This correctly selects [num_masked_tokens, V]
440
-
441
- if mask_logits.numel() == 0: # If no masks, logits selection is empty
442
  print(f"No masked tokens found for logit selection at step {i}. Stopping.")
443
  break
444
 
445
- # print(f"Step {i}: mask_logits shape: {mask_logits.shape}") # Debug
446
-
447
  # --- Sampling / Remasking Logic ---
448
  t = timesteps[i]
449
  s = timesteps[i + 1]
450
-
451
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)
452
 
453
  if alg == 'origin':
454
- # Original diffusion logic
455
- p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0 # Ensure float division
456
- # Sample only for the tokens to be revealed in this step
457
  num_masked = mask_logits.shape[0]
458
  transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
459
  logits_to_sample = mask_logits[transfer_indices_relative]
460
 
461
  if logits_to_sample.numel() > 0:
462
- # print(f"Step {i} (origin): Sampling {logits_to_sample.shape[0]} tokens.") # Debug
463
  _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
464
- # Place sampled tokens into the correct positions within the masked part
465
  x_new_masked_part[transfer_indices_relative] = sampled_tokens
466
- # else:
467
- # print(f"Step {i} (origin): No tokens to sample (p_transfer={p_transfer}).") # Debug
468
 
469
- else:
470
- # Confidence-based algorithms (maskgit_plus, topk_margin, entropy)
471
  use_margin = (alg == 'topk_margin')
472
  use_entropy = (alg == 'entropy')
473
- # print(f"Step {i} ({alg}): Sampling all {mask_logits.shape[0]} masked tokens for confidence.") # Debug
474
  confidence, x0_candidates = sample_tokens(
475
  mask_logits,
476
  temperature=temperature,
@@ -479,67 +369,92 @@ def generate_dream_response(
479
  margin_confidence=use_margin,
480
  neg_entropy=use_entropy
481
  )
482
- # print(f"Step {i} ({alg}): Confidence range: [{confidence.min():.2f}, {confidence.max():.2f}]") # Debug
483
-
484
 
485
  num_mask_token = mask_logits.shape[0]
486
- # Calculate number to reveal based on time steps, ensure it's an int
487
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
488
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
489
 
490
-
491
  if number_transfer_tokens > 0:
492
- # print(f"Step {i} ({alg}): Need to reveal {number_transfer_tokens} tokens.") # Debug
493
- if alg_temp_val is None or alg_temp_val <= 0: # Use top-k confidence
494
- # Sort by confidence (use negative entropy directly if alg='entropy')
495
- # For entropy, lower (more negative) is higher confidence (less uncertainty)
496
- sort_metric = confidence if alg != 'entropy' else -confidence
497
- _, transfer_indices_relative = torch.topk(sort_metric, k=min(number_transfer_tokens, num_mask_token)) # Ensure k is not > num_mask_token
498
- else: # Use sampling based on confidence temperature
499
- conf_probs = confidence / alg_temp_val
500
- # Check for inf/-inf before softmax
501
- conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9)
502
- conf_probs = F.softmax(conf_probs, dim=-1)
503
- # Check probs sum to 1
504
- if not torch.allclose(conf_probs.sum(), torch.tensor(1.0, device=device), atol=1e-4):
505
- print(f"Warning step {i}: Confidence probabilities do not sum to 1 after softmax ({conf_probs.sum()}). Re-normalizing.")
506
- conf_probs = conf_probs / conf_probs.sum(dim=-1, keepdim=True) # Normalize
507
-
508
- # Ensure num_samples is valid
509
- num_samples = min(number_transfer_tokens, num_mask_token)
510
- if conf_probs.numel() > 0 and num_samples > 0:
511
- try:
512
- transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
513
- except RuntimeError as e:
514
- print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
515
- # Fallback to top-k if multinomial fails (e.g., due to prob issues)
516
- sort_metric = confidence if alg != 'entropy' else -confidence
517
- _, transfer_indices_relative = torch.topk(sort_metric, k=num_samples)
518
- else:
519
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # No indices if no probs or num_samples=0
520
-
521
- # Place the selected candidate tokens into the masked part update
522
- if transfer_indices_relative.numel() > 0:
523
- x_new_masked_part[transfer_indices_relative] = x0_candidates[transfer_indices_relative].clone()
524
- # else:
525
- # print(f"Step {i} ({alg}): No tokens revealed via confidence ({number_transfer_tokens} target).") # Debug
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
  # Update the global state `x` only at the masked positions
528
  x[mask_index] = x_new_masked_part
529
 
530
  # --- Apply Constraints ---
531
- # Constraints should be applied *after* sampling/revealing tokens for the step
532
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)
533
 
534
  # --- Yield Visualization ---
535
- current_generated_tokens = x[0, prompt_length:].cpu() # Get generated part, move to CPU
536
  vis_data = []
 
537
  for j in range(gen_length):
538
  current_tok_id = current_generated_tokens[j].item()
539
- previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
540
 
541
  try:
542
- decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
 
543
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
544
  except Exception:
545
  display_token = f"[ID:{current_tok_id}]" # Fallback
@@ -554,31 +469,17 @@ def generate_dream_response(
554
  else: # Token was already revealed
555
  color = "#6699CC" # Light Blue
556
 
557
- # Hide special tokens (PAD/EOS) if they were already revealed (LLaDA effect)
558
- # Ensure PAD_ID and EOS_ID are not None before checking
559
  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
560
  (EOS_ID is not None and current_tok_id == EOS_ID)
561
  if should_hide and previous_tok_id == current_tok_id:
562
  token_to_display = "" # Hide by making empty
563
  color = None # No color for hidden
564
 
565
-
566
  if token_to_display:
567
  vis_data.append((token_to_display, color))
568
- elif len(vis_data) > 0 and isinstance(vis_data[-1], tuple) and vis_data[-1][0] == " ":
569
- # Avoid adding multiple spaces if tokens are hidden consecutively
570
- pass
571
- elif len(vis_data) > 0 and not isinstance(vis_data[-1], tuple) and vis_data[-1] == " ":
572
- pass # Already added a space
573
- elif len(vis_data) > 0 :
574
- # Add a single space if hiding follows a visible token, improves readability slightly
575
- # Let's simplify: just omit hidden tokens. Adding spaces might be complex.
576
- pass
577
-
578
- # Update previous state for the next iteration
579
- previous_tokens_vis = current_generated_tokens
580
-
581
- # Decode intermediate response (might be partial) - skip specials for readability
582
  intermediate_response_tokens = x[0, prompt_length:]
583
  intermediate_response_text = tokenizer.decode(
584
  intermediate_response_tokens,
@@ -586,9 +487,6 @@ def generate_dream_response(
586
  clean_up_tokenization_spaces=True
587
  ).strip()
588
 
589
- # Yield current state
590
- # We yield the *current* history, the vis data for this step, and intermediate text
591
- # The final text will overwrite the intermediate text in the UI eventually
592
  yield history_copy, vis_data, intermediate_response_text
593
  time.sleep(visualization_delay)
594
 
@@ -598,65 +496,47 @@ def generate_dream_response(
598
  # --- 6. Final Processing & Yield ---
599
  final_sequence = x[0]
600
  response_tokens = final_sequence[prompt_length:]
601
-
602
- # Decode the final response text
603
  final_response_text = tokenizer.decode(
604
  response_tokens,
605
- skip_special_tokens=True, # Skip EOS, PAD, MASK etc. in the final text
606
  clean_up_tokenization_spaces=True
607
  ).strip()
608
-
609
- # Update history with the final response *before* the last yield
610
  history_copy[-1][1] = final_response_text
611
 
612
- # Yield the final state (which might be the same as the last yielded state if loop finished)
613
- # Need to format vis_data one last time based on the final `x`
614
  final_generated_tokens = x[0, prompt_length:].cpu()
615
  vis_data_final = []
 
616
  for j in range(gen_length):
617
  current_tok_id = final_generated_tokens[j].item()
618
- previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
619
-
620
  try:
621
- decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
622
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
623
  except Exception:
624
  display_token = f"[ID:{current_tok_id}]" # Fallback
625
-
626
  color = None
627
  token_to_display = display_token
628
-
629
- if current_tok_id == MASK_ID:
630
- color = "#444444"
631
- elif previous_tok_id == MASK_ID:
632
- color = "#66CC66"
633
- else:
634
- color = "#6699CC"
635
-
636
  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
637
  (EOS_ID is not None and current_tok_id == EOS_ID)
638
  if should_hide and previous_tok_id == current_tok_id:
639
- token_to_display = ""
640
- color = None
641
-
642
- if token_to_display:
643
- vis_data_final.append((token_to_display, color))
644
 
645
- # Yield the final history, final visualization, and final text
646
  yield history_copy, vis_data_final, final_response_text
647
  print("Visualization streaming complete.")
648
 
649
-
650
  except Exception as e:
651
  print(f"Error during generation or processing: {e}")
652
  import traceback
653
  traceback.print_exc()
654
- # Update history with error message? Or leave as None? Let's leave as None.
655
  yield history_copy, [("Error during generation.", "red")], ""
656
  return
657
 
658
 
659
- # --- Gradio UI (Remains largely the same, ensures outputs match yield structure) ---
660
  css = '''
661
  .category-legend{display:none}
662
  button{min-height: 60px}
@@ -669,12 +549,8 @@ def create_chatbot_demo():
669
  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]" # Note: Link might be hypothetical
670
  )
671
 
672
- # STATE MANAGEMENT
673
- # chat_history = gr.State([]) # Use gr.Chatbot's internal state implicitly if possible, or manage manually
674
- # Let's manage manually with a list for clarity with yielding updates
675
  _chat_history_store = gr.State([]) # Hidden state to store actual history list
676
 
677
- # UI COMPONENTS
678
  with gr.Row():
679
  with gr.Column(scale=3):
680
  chatbot_ui = gr.Chatbot(
@@ -682,10 +558,7 @@ def create_chatbot_demo():
682
  height=500,
683
  show_copy_button=True,
684
  bubble_full_width=False,
685
- # value=[] # Initialize chatbot UI empty
686
  )
687
-
688
- # Message input
689
  with gr.Group():
690
  with gr.Row():
691
  user_input = gr.Textbox(
@@ -694,11 +567,9 @@ def create_chatbot_demo():
694
  scale=7,
695
  autofocus=True,
696
  show_label=False,
697
- container=False # Remove container for tighter packing
698
  )
699
  send_btn = gr.Button("Send", scale=1, variant="primary")
700
-
701
-
702
  constraints_input = gr.Textbox(
703
  label="Word Constraints (Optional)",
704
  info="Place words at specific positions (0-indexed from start of generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
@@ -708,125 +579,72 @@ def create_chatbot_demo():
708
  with gr.Column(scale=2):
709
  output_vis = gr.HighlightedText(
710
  label="Denoising Process Visualization",
711
- combine_adjacent=False,
712
- show_legend=True, # Legend isn't very informative here
713
- interactive=False # Not interactive
714
  )
715
- # Add a text box to display the final/intermediate response clearly
716
  response_text_display = gr.Textbox(
717
  label="Generated Response",
718
  interactive=False,
719
- lines=5 # Show a few lines
720
  )
721
 
722
-
723
- # Advanced generation settings
724
  with gr.Accordion("Generation Settings", open=False):
725
- with gr.Row():
726
- gen_length = gr.Slider(
727
- minimum=16, maximum=512, value=128, step=8, # Increased max length
728
- label="Max New Tokens"
729
- )
730
- steps = gr.Slider(
731
- minimum=8, maximum=512, value=128, step=8, # Increased max steps
732
- label="Diffusion Steps"
733
- )
734
- with gr.Row():
735
- temperature = gr.Slider(
736
- minimum=0.0, maximum=1.0, value=0.4, step=0.05,
737
- label="Temperature (0 = greedy)"
738
- )
739
- alg_temp = gr.Slider(
740
- minimum=0.0, maximum=1.0, value=0.1, step=0.05,
741
- label="Remasking Temp (Confidence Algs)"
742
- )
743
-
744
- with gr.Row():
745
- top_p = gr.Slider(
746
- minimum=0.0, maximum=1.0, value=0.95, step=0.05,
747
- label="Top-P (<=0 or >=1 disables)" # Clarify disabling condition
748
- )
749
- top_k = gr.Slider(
750
- minimum=0, maximum=200, value=0, step=5,
751
- label="Top-K (0 disables)"
752
- )
753
-
754
- with gr.Row():
755
- remasking_strategy = gr.Radio(
756
- choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
757
- value='entropy', # Default to entropy as in example
758
- label="Remasking Strategy (Algorithm)"
759
- )
760
-
761
- with gr.Row():
762
- visualization_delay = gr.Slider(
763
- minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Slightly faster default
764
- label="Visualization Delay (seconds)"
765
- )
766
 
767
- # Clear button
768
  clear_btn = gr.Button("Clear Conversation")
769
 
770
- # --- Event Handlers ---
771
-
772
  def add_user_message_to_history(message: str, history_store: List[List[Optional[str]]]):
773
- """Adds user message, clears input, prepares for bot response."""
774
  if not message.strip():
775
  gr.Warning("Please enter a message.")
776
- # Return unchanged history, empty vis, empty response text
777
  return history_store, history_store, "", [], ""
778
-
779
- # Add user message with placeholder for bot response
780
  history_store.append([message, None])
781
- # Return updated history store, history for chatbot UI, empty input, empty vis, empty response
782
  return history_store, history_store, "", [], ""
783
 
784
  def clear_conversation():
785
- """Clears the chat history, visualization, and response text."""
786
- return [], [], "", [], "" # History store, chatbot UI, input, vis, response text
787
 
788
- # --- Connect UI elements ---
789
-
790
- # Define the inputs for the generation function once
791
  generation_inputs = [
792
  _chat_history_store, gen_length, steps, constraints_input,
793
  temperature, top_p, top_k, remasking_strategy, alg_temp,
794
  visualization_delay
795
  ]
796
- # Define the outputs for the generation function
797
- # Now yields: history_copy, vis_data, intermediate_response_text
798
- # Map these to: chatbot_ui, output_vis, response_text_display
799
  generation_outputs = [chatbot_ui, output_vis, response_text_display]
800
 
801
- # Handle Textbox Submission (Enter key)
802
  submit_listener = user_input.submit(
803
  fn=add_user_message_to_history,
804
  inputs=[user_input, _chat_history_store],
805
- outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display] # Step 1: Add user msg & clear outputs
806
- )
807
- # Chain the bot response generation after the user message is added
808
- submit_listener.then(
809
  fn=generate_dream_response,
810
  inputs=generation_inputs,
811
- outputs=generation_outputs, # Step 2: Generate response and stream vis/text
812
- show_progress="hidden" # Hide default progress bar as we have live vis
813
  )
814
 
815
- # Handle Send Button Click
816
  click_listener = send_btn.click(
817
  fn=add_user_message_to_history,
818
  inputs=[user_input, _chat_history_store],
819
- outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display] # Step 1: Add user msg & clear outputs
820
- )
821
- # Chain the bot response generation after the user message is added
822
- click_listener.then(
823
  fn=generate_dream_response,
824
  inputs=generation_inputs,
825
- outputs=generation_outputs, # Step 2: Generate response and stream vis/text
826
  show_progress="hidden"
827
  )
828
 
829
- # Clear Button Action
830
  clear_btn.click(
831
  clear_conversation,
832
  inputs=[],
@@ -838,5 +656,4 @@ def create_chatbot_demo():
838
  # --- Launch ---
839
  if __name__ == "__main__":
840
  demo = create_chatbot_demo()
841
- # Use queue for handling multiple users and streaming
842
 - demo.queue().launch(debug=True, share=False) # Set share=True for public link

app.py (after change)
11
  import torch.distributions as dists # Added import
12
 
13
  # --- START: Copied Helper functions from generation_utils.py ---
14
+ # [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
 
15
  def top_p_logits(logits, top_p=None):
16
  """ Applies top-p filtering to logits. """
17
  if top_p is None or top_p >= 1.0:
 
41
  def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
42
  """ Samples tokens based on logits and calculates confidence. """
43
  if temperature > 0:
44
+ # Prevent division by zero or negative temperatures
45
+ safe_temp = max(temperature, 1e-6)
46
+ logits = logits / safe_temp
47
  if top_p is not None and top_p < 1.0: # Apply top_p if valid
48
  logits = top_p_logits(logits, top_p)
49
  if top_k is not None and top_k > 0: # Apply top_k if valid
50
  logits = top_k_logits(logits, top_k)
51
 
52
  # Ensure logits are not all -inf after filtering, if so, sample uniformly? Or handle error.
53
+ # Add a check here: if all logits are -inf, assign uniform probability.
54
+ is_all_neg_inf = torch.all(logits == torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
55
+ if torch.any(is_all_neg_inf):
56
+ # print("Warning: All logits became -inf after filtering. Assigning uniform probabilities.")
57
+ uniform_logits = torch.zeros_like(logits)
58
+ logits = torch.where(is_all_neg_inf, uniform_logits, logits)
59
 
60
  probs = torch.softmax(logits, dim=-1)
61
 
62
+ # Clamp probabilities to avoid NaNs in sampling, ensure they sum to 1
63
+ probs = torch.clamp(probs, min=0.0) # Ensure non-negative
64
+ probs = probs / probs.sum(dim=-1, keepdim=True) # Re-normalize
65
+ probs = torch.nan_to_num(probs, nan=0.0) # Handle any remaining NaNs
66
+
67
+
68
  if temperature > 0:
69
  try:
70
  x0 = dists.Categorical(probs=probs).sample()
71
  confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
72
  except Exception as e: # Catch broader exceptions during sampling
73
  print(f"Warning: Error during Categorical sampling: {e}. Falling back to argmax.")
74
  confidence, x0 = probs.max(dim=-1)
75
+ else: # Greedy decoding (temperature == 0)
76
  confidence, x0 = probs.max(dim=-1)
77
 
78
  if margin_confidence:
 
84
 
85
  if neg_entropy:
86
  epsilon = 1e-10
87
+ # Ensure probs are > 0 for log
88
  log_probs = torch.log(probs + epsilon)
89
  confidence = torch.sum(probs * log_probs, dim=-1) # Should be negative entropy
90
 
91
+ # Ensure confidence is not NaN
92
+ confidence = torch.nan_to_num(confidence, nan=0.0)
93
 
94
+ return confidence, x0
95
  # --- END: Copied Helper functions ---
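As a usage sketch (not part of the commit), this is roughly how the copied sample_tokens helper is driven later in the denoising loop; the shapes and hyperparameters below are toy assumptions:

    import torch

    # Toy shapes; in the loop further down, mask_logits holds the logits of the still-masked positions.
    num_masked, vocab_size = 6, 32
    mask_logits = torch.randn(num_masked, vocab_size)
    confidence, x0 = sample_tokens(mask_logits, temperature=0.4, top_p=0.95)
    k = max(1, num_masked // 4)                  # e.g. reveal a quarter of the masks this step
    _, reveal_idx = torch.topk(confidence, k=k)  # indices of the most confident masked positions
    new_tokens = x0[reveal_idx]                  # candidate token ids to write back at those positions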
96
 
97
 
98
+ # [Keep model loading, constants, helper functions: parse_constraints, format_chat_history, apply_constraints_to_state]
99
  # Load model configuration to get special token IDs
100
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
101
  # Use AutoModel for the base model loading, relying on trust_remote_code=True
 
115
  model_path,
116
  torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32, # Use bfloat16 only on CUDA
117
  trust_remote_code=True,
118
+ attn_implementation="sdpa" # Explicitly request SDPA if available/desired
119
  )
120
  model = model.to(device).eval()
121
  print("Model loaded.")
 
125
  MASK_ID = tokenizer.mask_token_id # Use tokenizer's mask_token_id directly
126
  PAD_ID = tokenizer.pad_token_id # Use tokenizer's pad_token_id
127
  EOS_ID = tokenizer.eos_token_id # Use tokenizer's eos_token_id
 
 
 
 
128
 
 
129
  if MASK_ID is None:
130
  print("Warning: Mask token ID not found in config/tokenizer. Trying to fetch from tokenizer...")
 
131
  mask_token_special = tokenizer.mask_token
132
  if mask_token_special:
133
  MASK_ID = tokenizer.convert_tokens_to_ids(mask_token_special)
134
  print(f"Found MASK_ID from tokenizer: {MASK_ID}")
135
  else:
 
136
  raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")
137
 
 
138
  SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
 
 
139
  try:
140
  IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
141
  IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
148
 
149
 
150
  # --- Helper Functions ---
 
151
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
152
  """
153
  Parse constraints in format: 'position:word, position:word, ...'
 
165
  continue
166
  pos_str, word = part.split(':', 1)
167
  try:
 
168
  pos = int(pos_str.strip())
169
  word = word.strip() # Strip whitespace from word
170
+ token_ids = []
171
+ if word: # Only encode if word is not empty
172
+ # Add space prefix automatically if pos > 0 and word doesn't start with space
173
+ text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
174
+ token_ids = tokenizer.encode(text_to_encode, add_special_tokens=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  if token_ids and pos >= 0:
177
  constraints[pos] = token_ids
178
+ elif not token_ids and word: # Don't warn for empty words after split
179
  print(f"Warning: Could not tokenize constraint word '{word}'")
180
  except ValueError:
181
  print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
 
184
  print(f"Warning: Error processing constraint '{part}': {e}")
185
  continue
186
 
187
+ # print(f"Parsed constraints: {constraints}") # Debugging
188
  return constraints
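To make the space-prefix rule above concrete, a hypothetical parse (actual token ids depend on the Dream tokenizer, so they are shown symbolically):

    # parse_constraints("0:Once, 5:upon") would return, symbolically:
    #   {0: tokenizer.encode("Once", add_special_tokens=False),
    #    5: tokenizer.encode(" upon", add_special_tokens=False)}
    # i.e. position 0 is encoded as-is; later positions get a leading space before encoding.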
189
 
190
 
191
  def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
192
+ """ Formats chat history for the template. """
 
 
 
 
 
 
 
 
 
193
  messages = []
194
  for user_msg, assistant_msg in history:
195
+ if user_msg:
196
  messages.append({"role": "user", "content": user_msg})
 
197
  if assistant_msg:
198
  messages.append({"role": "assistant", "content": assistant_msg})
199
  return messages
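For reference, the expected input/output shape of this helper (illustrative values):

    # format_chat_history([["Hi", "Hello!"], ["Tell me a joke", None]])
    # -> [{"role": "user", "content": "Hi"},
    #     {"role": "assistant", "content": "Hello!"},
    #     {"role": "user", "content": "Tell me a joke"}]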
 
205
  parsed_constraints: Dict[int, List[int]],
206
  current_step: Optional[int] = None # For logging/debugging
207
  ) -> torch.Tensor:
208
+ """ Applies constraints directly to the state tensor `x`. """
209
 + # Work on a copy so the caller's tensor is not modified in place.
210
+ modified_x = x.clone()
211
  for rel_pos, word_token_ids in parsed_constraints.items():
212
  abs_start_pos = prompt_length + rel_pos
213
  abs_end_pos = abs_start_pos + len(word_token_ids)
214
 
 
215
  if abs_start_pos < total_length and abs_end_pos <= total_length:
216
  try:
217
  constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
 
218
  modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
 
219
  except IndexError:
220
  print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
221
  except Exception as e:
 
239
  alg_temp: Optional[float],
240
  visualization_delay: float
241
  ) -> List[Tuple[str, str]]:
242
+ """ Generates text step-by-step and yields visualization states live. """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  if not history or not history[-1][0]:
245
  yield history, [("No input message found.", "red")], ""
 
248
  # --- 1. Preparation ---
249
  last_user_message = history[-1][0]
250
  messages_for_template = format_chat_history(history) # Includes the latest user message
251
+ parsed_constraints = parse_constraints(constraints_text)
252
 
 
 
 
 
253
  try:
254
  inputs = tokenizer.apply_chat_template(
255
  messages_for_template,
256
  return_tensors="pt",
257
  return_dict=True,
258
+ add_generation_prompt=True
259
  )
260
  input_ids = inputs.input_ids.to(device)
261
+ # Ensure prompt_attention_mask is also on the correct device
262
+ prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
263
  prompt_length = input_ids.shape[1]
264
  except Exception as e:
265
  print(f"Error applying chat template: {e}")
266
  yield history, [("Error preparing input.", "red")], ""
267
  return
268
 
269
+ eps = 1e-3
270
+ top_p_val = top_p if top_p is not None and 0.0 < top_p < 1.0 else None # Make sure top_p is > 0
 
 
271
  top_k_val = top_k if top_k is not None and top_k > 0 else None
272
+ alg_temp_val = alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] and alg_temp is not None and alg_temp > 0 else None # Ensure > 0
273
 
274
  # --- 2. Initialize Generation State ---
275
  total_length = prompt_length + gen_length
 
 
276
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
277
  x = torch.cat((input_ids, initial_generation_part), dim=1)
278
 
279
+ # --- Prepare Attention Mask for SDPA ---
280
  generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
281
+ full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1) # Shape [B, N], dtype torch.long
282
+
283
+ # Convert attention mask for SDPA: Needs float matching query dtype.
284
+ # Where mask is 1 (attend), value should be 0.0. Where mask is 0 (don't attend), value should be -inf.
285
+ attention_mask_for_model = full_attention_mask_long.to(model.dtype) # Convert to model's dtype (e.g., bfloat16)
286
+ # Invert the mask logic: (1.0 - mask) gives 0s for attend, 1s for mask
287
+ # Multiply by large negative number (min value for dtype) for masked positions
288
+ large_neg_val = torch.finfo(model.dtype).min
289
+ attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
290
+ # Ensure the shape is broadcastable, SDPA usually handles [B, N] -> [B, H, N, N] if needed.
291
+ # However, explicitly making it [B, 1, 1, N] or [B, 1, N, N] can be safer.
292
+ # Let's try passing [B, N] first, if it fails, reshape.
293
+ # Reshape to [B, 1, 1, N] which is commonly expected for additive masks by HF models
294
+ attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2)
295
+ # Now shape is [B, 1, 1, N]
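The same conversion, restated as a standalone helper for clarity (assumed shapes; an illustrative sketch, not part of the commit):

    import torch

    def to_additive_mask(mask_long: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """[B, N] 0/1 attention mask -> [B, 1, 1, N] additive mask: 0.0 where attending, dtype-min where masked."""
        mask = mask_long.to(dtype)
        return ((1.0 - mask) * torch.finfo(dtype).min)[:, None, None, :]

    # to_additive_mask(torch.tensor([[1, 1, 0]]), torch.float32) yields 0.0 at the first two
    # positions and roughly -3.4e38 at the padded one.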
296
+
297
 + # --- Timesteps ---
298
  timesteps = torch.linspace(1, eps, steps + 1, device=device)
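A quick illustrative check of this schedule (toy values, not part of the commit): with a linear schedule the per-step reveal fraction 1 - s/t grows as the mask count shrinks:

    import torch

    steps, eps = 4, 1e-3
    ts = torch.linspace(1, eps, steps + 1)
    fractions = [round(float(1 - ts[i + 1] / ts[i]), 2) for i in range(steps - 1)]
    # fractions ≈ [0.25, 0.33, 0.5]; the final step always reveals every remaining masked token.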
299
 
300
+ # Apply initial constraints
301
+ x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1)
302
 
303
  # --- 3. Visualization Setup ---
304
+ previous_tokens_vis = None
305
+ final_response_text = ""
306
+ history_copy = [list(item) for item in history] # Mutable copy
307
 
308
  # --- 4. Initial Yield (Masked State) ---
309
  initial_generated_tokens = x[0, prompt_length:].cpu()
 
314
  vis_data_initial.append((display_token, color))
315
 
316
  previous_tokens_vis = initial_generated_tokens
317
+ yield history_copy, vis_data_initial, ""
318
  time.sleep(visualization_delay)
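For reference, each vis_data entry yielded to gr.HighlightedText is a (token, color) pair; a hypothetical single frame using the colors defined in the loop that follows might look like:

    # Hypothetical single frame (MASK_TOKEN is the tokenizer's mask string defined above):
    vis_data = [
        ("Once", "#66CC66"),      # newly revealed at this step
        (" upon", "#6699CC"),     # revealed at an earlier step
        (MASK_TOKEN, "#444444"),  # still masked
    ]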
319
 
320
  # --- 5. Step-by-Step Diffusion Loop ---
321
  try:
322
  start_time = time.time()
323
  for i in range(steps):
324
+ mask_index = (x == MASK_ID)
325
+ if not mask_index.any():
 
326
  print(f"No mask tokens left at step {i}. Stopping early.")
327
  break
328
 
329
+ # --- Model Forward Pass ---
330
 + # Pass the correctly formatted float mask
331
  outputs = model(
332
  input_ids=x,
333
+ attention_mask=attention_mask_for_model, # Pass the [B, 1, 1, N] float mask
334
+ position_ids=None,
335
+ use_cache=False,
336
  return_dict=True
337
  )
338
  logits = outputs.logits
339
+ logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) # Align logits
340
 
341
+ mask_logits = logits[mask_index]
342
 + if mask_logits.numel() == 0:
343
  print(f"No masked tokens found for logit selection at step {i}. Stopping.")
344
  break
345
 
 
 
346
  # --- Sampling / Remasking Logic ---
347
  t = timesteps[i]
348
  s = timesteps[i + 1]
 
349
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)
350
 
351
  if alg == 'origin':
352
+ p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
 
 
353
  num_masked = mask_logits.shape[0]
354
  transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
355
  logits_to_sample = mask_logits[transfer_indices_relative]
356
 
357
  if logits_to_sample.numel() > 0:
 
358
  _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
 
359
  x_new_masked_part[transfer_indices_relative] = sampled_tokens
 
 
360
 
361
+ else: # Confidence-based algorithms
 
362
  use_margin = (alg == 'topk_margin')
363
  use_entropy = (alg == 'entropy')
 
364
  confidence, x0_candidates = sample_tokens(
365
  mask_logits,
366
  temperature=temperature,
 
369
  margin_confidence=use_margin,
370
  neg_entropy=use_entropy
371
  )
 
 
372
 
373
  num_mask_token = mask_logits.shape[0]
 
374
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
375
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
376
 
 
377
  if number_transfer_tokens > 0:
378
+ num_samples = min(number_transfer_tokens, num_mask_token) # Ensure k <= num_mask_token
379
+ if num_samples > 0: # Proceed only if we need to sample > 0 tokens
380
+ if alg_temp_val is None or alg_temp_val <= 0: # Top-k confidence
381
+ sort_metric = confidence if alg != 'entropy' else -confidence # Lower entropy = higher confidence
382
+ # Ensure k is not greater than the number of elements
383
+ k_topk = min(num_samples, sort_metric.numel())
384
+ if k_topk > 0:
385
+ _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
386
+ else:
387
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
388
+
389
+ else: # Sample based on confidence temperature
390
+ # Ensure confidence has elements before processing
391
+ if confidence.numel() > 0:
392
+ conf_probs = confidence / alg_temp_val
393
+ # Handle potential inf/-inf before softmax, ensure non-negative probabilities
394
+ conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9) # Use large numbers instead of inf
395
+ conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30) # Prevent large positive values leading to inf in exp
396
+ conf_probs = F.softmax(conf_probs, dim=-1)
397
+ conf_probs = torch.clamp(conf_probs, min=0.0) # Ensure non-negative
398
+ conf_probs = torch.nan_to_num(conf_probs, nan=0.0) # Handle NaNs
399
+
400
+ # Normalize probabilities if they don't sum to 1
401
+ prob_sum = conf_probs.sum()
402
+ if not torch.isclose(prob_sum, torch.tensor(1.0, device=device), atol=1e-4) and prob_sum > 0:
403
+ # print(f"Warning step {i}: Confidence probabilities sum {prob_sum:.4f} != 1. Re-normalizing.")
404
+ conf_probs = conf_probs / prob_sum
405
+
406
+ if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(conf_probs.sum(), torch.tensor(1.0, device=device)):
407
+ try:
408
+ transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
409
+ except RuntimeError as e:
410
+ print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
411
+ sort_metric = confidence if alg != 'entropy' else -confidence
412
+ k_multinomial_fallback = min(num_samples, sort_metric.numel())
413
+ if k_multinomial_fallback > 0:
414
+ _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
415
+ else:
416
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
417
+ else: # Handle cases where multinomial is not possible
418
+ # print(f"Warning step {i}: Invalid probabilities for multinomial sampling. Falling back to top-k.")
419
+ sort_metric = confidence if alg != 'entropy' else -confidence
420
+ k_multinomial_fallback = min(num_samples, sort_metric.numel())
421
+ if k_multinomial_fallback > 0:
422
+ _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
423
+ else:
424
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
425
+ else: # No confidence values to sample from
426
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
427
+
428
+ # Apply the transfer
429
+ if transfer_indices_relative.numel() > 0:
430
+ # Ensure indices are within bounds of x0_candidates
431
+ valid_indices = transfer_indices_relative < x0_candidates.shape[0]
432
+ valid_transfer_indices = transfer_indices_relative[valid_indices]
433
+
434
+ if valid_transfer_indices.numel() > 0:
435
+ # Ensure indices are also within bounds of x_new_masked_part
436
+ if valid_transfer_indices.max() < x_new_masked_part.shape[0]:
437
+ x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()
438
+ else:
439
+ print(f"Warning step {i}: transfer_indices out of bounds for x_new_masked_part.")
440
 
441
  # Update the global state `x` only at the masked positions
442
  x[mask_index] = x_new_masked_part
443
 
444
  # --- Apply Constraints ---
 
445
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)
446
 
447
  # --- Yield Visualization ---
448
+ current_generated_tokens = x[0, prompt_length:].cpu()
449
  vis_data = []
450
+ # [Keep visualization formatting logic the same]
451
  for j in range(gen_length):
452
  current_tok_id = current_generated_tokens[j].item()
453
+ previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
454
 
455
  try:
456
+ # Use replace to handle potential bytes rendering issues
457
+ decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
458
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
459
  except Exception:
460
  display_token = f"[ID:{current_tok_id}]" # Fallback
 
469
  else: # Token was already revealed
470
  color = "#6699CC" # Light Blue
471
 
 
 
472
  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
473
  (EOS_ID is not None and current_tok_id == EOS_ID)
474
  if should_hide and previous_tok_id == current_tok_id:
475
  token_to_display = "" # Hide by making empty
476
  color = None # No color for hidden
477
 
 
478
  if token_to_display:
479
  vis_data.append((token_to_display, color))
480
+
481
+ previous_tokens_vis = current_generated_tokens # Update for next step
482
 +
483
  intermediate_response_tokens = x[0, prompt_length:]
484
  intermediate_response_text = tokenizer.decode(
485
  intermediate_response_tokens,
 
487
  clean_up_tokenization_spaces=True
488
  ).strip()
489
 
 
 
 
490
  yield history_copy, vis_data, intermediate_response_text
491
  time.sleep(visualization_delay)
492
 
 
496
  # --- 6. Final Processing & Yield ---
497
  final_sequence = x[0]
498
  response_tokens = final_sequence[prompt_length:]
 
 
499
  final_response_text = tokenizer.decode(
500
  response_tokens,
501
+ skip_special_tokens=True,
502
  clean_up_tokenization_spaces=True
503
  ).strip()
 
 
504
  history_copy[-1][1] = final_response_text
505
 
 
 
506
  final_generated_tokens = x[0, prompt_length:].cpu()
507
  vis_data_final = []
508
+ # [Keep final visualization formatting logic the same]
509
  for j in range(gen_length):
510
  current_tok_id = final_generated_tokens[j].item()
511
+ previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
 
512
  try:
513
+ decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
514
  display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
515
  except Exception:
516
  display_token = f"[ID:{current_tok_id}]" # Fallback
 
517
  color = None
518
  token_to_display = display_token
519
+ if current_tok_id == MASK_ID: color = "#444444"
520
+ elif previous_tok_id == MASK_ID: color = "#66CC66"
521
+ else: color = "#6699CC"
 
 
 
 
 
522
  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
523
  (EOS_ID is not None and current_tok_id == EOS_ID)
524
  if should_hide and previous_tok_id == current_tok_id:
525
+ token_to_display = ""; color = None
526
+ if token_to_display: vis_data_final.append((token_to_display, color))
 
 
 
527
 
 
528
  yield history_copy, vis_data_final, final_response_text
529
  print("Visualization streaming complete.")
530
 
 
531
  except Exception as e:
532
  print(f"Error during generation or processing: {e}")
533
  import traceback
534
  traceback.print_exc()
 
535
  yield history_copy, [("Error during generation.", "red")], ""
536
  return
537
 
538
 
539
 + # --- Gradio UI ---
540
  css = '''
541
  .category-legend{display:none}
542
  button{min-height: 60px}
 
549
  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]" # Note: Link might be hypothetical
550
  )
551
 
 
 
 
552
  _chat_history_store = gr.State([]) # Hidden state to store actual history list
553
 
 
554
  with gr.Row():
555
  with gr.Column(scale=3):
556
  chatbot_ui = gr.Chatbot(
 
558
  height=500,
559
  show_copy_button=True,
560
  bubble_full_width=False,
 
561
  )
 
 
562
  with gr.Group():
563
  with gr.Row():
564
  user_input = gr.Textbox(
 
567
  scale=7,
568
  autofocus=True,
569
  show_label=False,
570
+ container=False
571
  )
572
  send_btn = gr.Button("Send", scale=1, variant="primary")
 
 
573
  constraints_input = gr.Textbox(
574
  label="Word Constraints (Optional)",
575
  info="Place words at specific positions (0-indexed from start of generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
 
579
  with gr.Column(scale=2):
580
  output_vis = gr.HighlightedText(
581
  label="Denoising Process Visualization",
582
+ combine_adjacent=True,
583
+ show_legend=False,
584
+ interactive=False
585
  )
 
586
  response_text_display = gr.Textbox(
587
  label="Generated Response",
588
  interactive=False,
589
+ lines=5
590
  )
591
 
 
 
592
  with gr.Accordion("Generation Settings", open=False):
593
+ with gr.Row():
594
+ gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
595
+ steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
596
+ with gr.Row():
597
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Temperature (0 = greedy)")
598
+ alg_temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Remasking Temp (Confidence Algs)")
599
+ with gr.Row():
600
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (0 disables)")
601
+ top_k = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (0 disables)")
602
+ with gr.Row():
603
+ remasking_strategy = gr.Radio(choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Remasking Strategy (Algorithm)")
604
+ with gr.Row():
605
+ visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (seconds)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
 
 
607
  clear_btn = gr.Button("Clear Conversation")
608
 
 
 
609
  def add_user_message_to_history(message: str, history_store: List[List[Optional[str]]]):
 
610
  if not message.strip():
611
  gr.Warning("Please enter a message.")
 
612
  return history_store, history_store, "", [], ""
 
 
613
  history_store.append([message, None])
 
614
  return history_store, history_store, "", [], ""
615
 
616
  def clear_conversation():
617
+ return [], [], "", [], ""
 
618
 
 
 
 
619
  generation_inputs = [
620
  _chat_history_store, gen_length, steps, constraints_input,
621
  temperature, top_p, top_k, remasking_strategy, alg_temp,
622
  visualization_delay
623
  ]
 
 
 
624
  generation_outputs = [chatbot_ui, output_vis, response_text_display]
625
 
 
626
  submit_listener = user_input.submit(
627
  fn=add_user_message_to_history,
628
  inputs=[user_input, _chat_history_store],
629
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
630
+ ).then(
 
 
631
  fn=generate_dream_response,
632
  inputs=generation_inputs,
633
+ outputs=generation_outputs,
634
+ show_progress="hidden"
635
  )
636
 
 
637
  click_listener = send_btn.click(
638
  fn=add_user_message_to_history,
639
  inputs=[user_input, _chat_history_store],
640
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
641
+ ).then(
 
 
642
  fn=generate_dream_response,
643
  inputs=generation_inputs,
644
+ outputs=generation_outputs,
645
  show_progress="hidden"
646
  )
647
 
 
648
  clear_btn.click(
649
  clear_conversation,
650
  inputs=[],
 
656
  # --- Launch ---
657
  if __name__ == "__main__":
658
  demo = create_chatbot_demo()
659
+ demo.queue().launch(debug=True, share=False)