multimodalart (HF Staff) committed on
Commit 825e87d · verified · 1 Parent(s): 168a7c1

Update app.py

Files changed (1)
  1. app.py +413 -607
app.py CHANGED
@@ -4,519 +4,335 @@ import numpy as np
4
  import gradio as gr
5
  import spaces
6
  import torch.nn.functional as F
7
- from transformers import AutoTokenizer, AutoModel
8
- from transformers.generation.configuration_utils import GenerationConfig
9
  import time
10
- import re
11
- import torch.distributions as dists # Import dists for sampling logic
12
 
13
- # --- Model Loading ---
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  print(f"Using device: {device}")
16
 
17
- # Load Dream model and tokenizer
18
  model_path = "Dream-org/Dream-v0-Instruct-7B"
19
- # Load configuration first to get token IDs
20
- config = DreamConfig.from_pretrained(model_path) # Assuming configuration_dream.py is present
 
 
21
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
22
- model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
23
- model = model.to(device).eval()
24
- print("Model and Tokenizer loaded.")
25
 
26
- # --- Constants ---
27
- MASK_TOKEN = tokenizer.mask_token # "<|mask|>"
28
- MASK_ID = config.mask_token_id # Get from config (e.g., 151666)
29
- EOS_ID = config.eos_token_id # Get from config (e.g., 151643)
30
- PAD_ID = config.pad_token_id # Get from config (e.g., 151643)
 
 
31
 
32
  # --- Helper Functions ---
33
 
34
- def parse_constraints(constraints_text):
35
  """Parse constraints in format: 'position:word, position:word, ...'"""
36
  constraints = {}
 
37
  if not constraints_text:
38
- return constraints
39
 
40
  parts = constraints_text.split(',')
41
  for part in parts:
42
  if ':' not in part:
43
  continue
 
44
  try:
45
- pos_str, word = part.split(':', 1)
46
  pos = int(pos_str.strip())
47
- # Use strip() and lower() for robustness if needed, but preserve case for now
48
  word = word.strip()
49
  if word and pos >= 0:
50
- # Tokenize the word - handle potential multi-token words
51
- # Add space prefix typical for non-leading words if pos > 0
 
 
52
  prefix = " " if pos > 0 else ""
53
  tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
54
  for i, token_id in enumerate(tokens):
55
- # Only add if the token is not a special token id already
56
- # (This prevents accidental replacement of things like MASK_ID)
57
- if token_id not in [MASK_ID, EOS_ID, PAD_ID]:
58
- constraints[pos + i] = token_id
59
  except ValueError:
60
  continue
61
  except Exception as e:
62
- print(f"Error parsing constraint part '{part}': {e}")
63
- continue
64
-
65
- return constraints
66
 
 
 
 
 
 
67
 
68
  def format_chat_history(history):
69
  """
70
- Format chat history for the Dream model (using ChatML format potentially)
71
 
72
  Args:
73
  history: List of [user_message, assistant_message] pairs
74
 
75
  Returns:
76
- Formatted conversation for the model (list of message dicts)
77
  """
78
  messages = []
79
- # Check if the first message is a system prompt
80
- if history and history[0][0].lower().startswith("system:"):
81
- # Special handling if needed, or just treat as user
82
- # For now, let's assume standard user/assistant alternation
83
- pass # Or handle system prompt separately if template requires
84
-
85
- for i, (user_msg, assistant_msg) in enumerate(history):
86
- # Basic user/assistant structure
87
- messages.append({"role": "user", "content": user_msg})
88
- if assistant_msg is not None: # Skip if None (for the latest user message)
89
  messages.append({"role": "assistant", "content": assistant_msg})
90
 
91
  return messages
92
 
93
- # --- Core Generation Logic (Adapted from Dream's _sample) ---
94
 
95
- def sample_tokens_for_vis(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
 
 
96
  """
97
- Simplified version of Dream's sample_tokens to get both token and confidence.
98
- Returns confidence and chosen token ID.
99
  """
100
- # Apply temperature
101
- if temperature > 0:
102
- logits = logits / temperature
103
-
104
- # Apply Top-P
105
- if top_p is not None and top_p < 1.0:
106
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
107
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
108
- sorted_indices_to_remove = cumulative_probs > top_p
109
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
110
- sorted_indices_to_remove[..., 0] = 0
111
- indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_remove)
112
- logits = logits.masked_fill(indices_to_remove, float('-inf'))
113
-
114
- # Apply Top-K
115
- if top_k is not None and top_k > 0:
116
- top_k = min(top_k, logits.size(-1))
117
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
118
- logits = logits.masked_fill(indices_to_remove, float('-inf'))
119
-
120
- # Calculate probabilities
121
- probs = torch.softmax(logits, dim=-1)
122
-
123
- # Sample or Argmax
124
- if temperature > 0:
125
- # Use torch distributions for robust sampling
126
- dist = dists.Categorical(probs=probs)
127
- x0 = dist.sample()
128
- # Gather confidence for the sampled token
129
- confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
130
- else:
131
- # Argmax for deterministic generation
132
- confidence, x0 = torch.max(probs, dim=-1)
133
-
134
- # --- Calculate specific confidence metrics if requested ---
135
- # Note: These modify the 'confidence' variable *after* sampling x0
136
- if margin_confidence:
137
- if probs.shape[-1] >= 2:
138
- # Ensure logits weren't completely masked, handle edge cases
139
- if not torch.isinf(logits).all(dim=-1).any():
140
- # Sort probabilities to get top1 and top2
141
- sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
142
- top1_probs = sorted_probs[..., 0]
143
- top2_probs = sorted_probs[..., 1]
144
- confidence = top1_probs - top2_probs
145
- else:
146
- # Fallback if all logits are -inf (shouldn't normally happen)
147
- confidence.fill_(0.0) # Or some other indicator
148
- else:
149
- # Only one possible token, margin is undefined or 1? Set to top1 prob.
150
- confidence, _ = torch.max(probs, dim=-1)
151
-
152
- elif neg_entropy:
153
- epsilon = 1e-9 # Slightly smaller epsilon
154
- log_probs = torch.log(probs + epsilon)
155
- # Negative entropy is sum(p * log(p))
156
- confidence = torch.sum(probs * log_probs, dim=-1) # Lower value (more negative) is higher confidence
157
-
158
- return confidence, x0
159
 
 
 
 
 
160
 
161
- @spaces.GPU
162
- @torch.no_grad()
163
- def generate_response_with_visualization_dream(
164
- messages, gen_length=64, steps=64,
165
- constraints=None, temperature=0.2, top_p=0.95, top_k=None, # Added top_k
166
- alg="entropy", alg_temp=0.1, # Dream specific params
167
- yield_intermediate=True # Control yielding behavior
168
- ):
169
- """
170
- Generate text with Dream model with real-time visualization.
171
- Adapts logic from Dream's _sample method.
 
 
 
 
172
 
173
- Args:
174
- messages: List of message dictionaries with 'role' and 'content'
175
- gen_length: Max new tokens to generate
176
- steps: Number of diffusion steps
177
- constraints: Dictionary mapping positions to *token IDs*
178
- temperature: Sampling temperature
179
- top_p: Nucleus sampling probability
180
- top_k: Top-k sampling
181
- alg: Remasking strategy ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
182
- alg_temp: Temperature for confidence-based remasking randomness
183
- yield_intermediate: Whether to yield intermediate states for visualization
184
 
185
- Returns:
186
- Yields visualization states or returns final state list, and final text.
187
- """
188
- if constraints is None:
189
- constraints = {} # keys are positions relative to start of response
190
-
191
- # --- Prepare Input ---
192
- chat_input_text = tokenizer.apply_chat_template(
193
- messages, add_generation_prompt=True, tokenize=False
194
- )
195
- input_ids = tokenizer(chat_input_text, return_tensors="pt")['input_ids'].to(device)
196
- prompt_length = input_ids.shape[1]
197
- max_length = prompt_length + gen_length
198
-
199
- # Clamp max_length if it exceeds model capacity (use config value if available)
200
- model_max_len = getattr(config, 'max_position_embeddings', 2048) # Default fallback
201
- if max_length > model_max_len:
202
- print(f"Warning: Requested length ({max_length}) exceeds model max ({model_max_len}). Clamping.")
203
- max_length = model_max_len
204
- gen_length = max_length - prompt_length
205
- if gen_length <= 0:
206
- print("Warning: Prompt is already at or exceeding model max length. Cannot generate.")
207
- if yield_intermediate:
208
- yield [], "Error: Prompt too long."
209
- return
210
- else:
211
- return [], "Error: Prompt too long."
212
 
 
 
 
 
 
213
 
214
- # Initialize sequence 'x' with input_ids and padding with MASK_ID
215
- x = torch.full((1, max_length), MASK_ID, dtype=torch.long, device=device)
216
- x[:, :prompt_length] = input_ids.clone()
217
 
218
- # Apply initial constraints to x (relative position -> absolute position)
219
- for rel_pos, token_id in constraints.items():
220
- abs_pos = prompt_length + rel_pos
221
- if abs_pos < max_length:
222
- # Ensure we don't overwrite prompt or special tokens accidentally
223
- if token_id not in [MASK_ID, EOS_ID, PAD_ID]:
224
- x[:, abs_pos] = token_id
225
- else:
226
- print(f"Warning: Skipped constraint for special token ID {token_id} at pos {rel_pos}")
227
 
 
 
 
228
 
229
- # --- Visualization Setup ---
230
- visualization_states = []
231
- revealed_eos_pad = set() # Track positions where EOS/PAD was shown once
232
-
233
- def get_vis_state(current_x, old_x, step_confidences=None):
234
- nonlocal revealed_eos_pad
235
- state = []
236
- newly_revealed_in_step = False # Flag if any token changed from MASK
237
- current_revealed_eos_pad = set() # Track EOS/PAD revealed *in this step*
238
-
239
- for i in range(gen_length):
240
- abs_pos = prompt_length + i
241
- current_token_id = current_x[0, abs_pos].item()
242
- old_token_id = old_x[0, abs_pos].item()
243
-
244
- is_eos_or_pad = (current_token_id == EOS_ID or current_token_id == PAD_ID)
245
-
246
- # Handle EOS/PAD hiding: Show once, then hide
247
- if is_eos_or_pad and abs_pos in revealed_eos_pad:
248
- state.append(("", "#FFFFFF")) # Make it invisible (white on white/transparent)
249
- continue # Skip rest of logic for this pos
250
-
251
- token_str = tokenizer.decode([current_token_id], skip_special_tokens=False) # Decode even specials initially
252
-
253
- if current_token_id == MASK_ID:
254
- color = "#444444" # Dark Gray for Mask
255
- token_str = MASK_TOKEN # Display mask token string
256
- elif old_token_id == MASK_ID: # Newly revealed in this step
257
- newly_revealed_in_step = True
258
- confidence = step_confidences.get(abs_pos, 0.5) # Get confidence if available, default 0.5
259
-
260
- # Color based on confidence (adjust thresholds as needed)
261
- # Note: Entropy confidence is negative, more negative = higher confidence
262
- if alg == 'entropy':
263
- # Example thresholds for negative entropy (adjust based on observation)
264
- if confidence > -1.0: # Low confidence (high entropy)
265
- color = "#FF6666" # Light Red
266
- elif confidence > -3.0: # Medium confidence
267
- color = "#FFAA33" # Orange
268
- else: # High confidence (low entropy)
269
- color = "#66CC66" # Light Green
270
- else: # Standard confidence (probability or margin)
271
- if confidence < 0.3:
272
- color = "#FF6666" # Light Red
273
- elif confidence < 0.7:
274
- color = "#FFAA33" # Orange
275
- else:
276
- color = "#66CC66" # Light Green
277
-
278
- # If it's EOS/PAD revealed now, mark for future hiding
279
- if is_eos_or_pad:
280
- current_revealed_eos_pad.add(abs_pos)
281
-
282
- else: # Previously revealed
283
- color = "#6699CC" # Light Blue
284
-
285
- # Clean up token string for display (optional)
286
- # token_str = token_str.replace(" ", " ") # Keep spaces visible
287
-
288
- state.append((token_str, color))
289
-
290
- # Update the global set of revealed EOS/PAD positions
291
- revealed_eos_pad.update(current_revealed_eos_pad)
292
-
293
- return state, newly_revealed_in_step
294
-
295
- # Add initial state (all masked, constraints applied)
296
- initial_vis_state, _ = get_vis_state(x, x) # Pass x as old_x initially
297
- visualization_states.append(initial_vis_state)
298
- if yield_intermediate:
299
- yield initial_vis_state # Yield the starting state
300
-
301
- # --- Diffusion Loop ---
302
- timesteps = torch.linspace(1.0, 1e-3, steps + 1, device=device) # Use epsilon from Dream's defaults if needed
303
-
304
- # Store the state before the loop starts
305
- old_x = x.clone()
306
-
307
- for i in range(steps):
308
- # --- Core Dream Step ---
309
- mask_index = (x == MASK_ID)
310
- if not mask_index.any(): # Stop if no masks left
311
- print(f"No masks left at step {i}. Stopping generation.")
312
- break
313
-
314
- # Prepare attention mask (full attention for Dream unless specified otherwise)
315
- # Dream's modeling code handles standard causal masking internally based on position_ids
316
- # For diffusion, we typically allow attending to everything (masked or not)
317
- # The `model` forward pass expects a standard causal mask or None
318
- # Let's use None, assuming the model handles positions correctly
319
- attention_mask = None # Or potentially create a full mask: torch.ones_like(x)
320
-
321
- # Create position_ids (simple range for now)
322
- position_ids = torch.arange(0, x.shape[1], device=device).unsqueeze(0)
323
-
324
- # Model forward pass
325
- outputs = model(input_ids=x, attention_mask=attention_mask, position_ids=position_ids)
326
- logits = outputs.logits
327
- # logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) # Dream applies shift in utils, replicate if needed
328
-
329
- # Select logits for masked positions ONLY
330
- # Need to handle batch dimension (which is 1 here)
331
- current_mask_indices_flat = torch.where(mask_index.flatten())[0]
332
- if len(current_mask_indices_flat) == 0:
333
- print(f"No mask indices found flat at step {i}. Stopping generation.")
334
- break
335
-
336
- # Use advanced indexing to get logits for masked positions
337
- # Logits shape: [batch_size, seq_len, vocab_size]
338
- # Mask_index shape: [batch_size, seq_len]
339
- # We need logits corresponding to True values in mask_index
340
- # Example: batch_idx = torch.where(mask_index)[0], seq_idx = torch.where(mask_index)[1]
341
- # mask_logits = logits[batch_idx, seq_idx]
342
- batch_indices, seq_indices = torch.where(mask_index)
343
- mask_logits = logits[batch_indices, seq_indices] # Shape: [num_masked_tokens, vocab_size]
344
-
345
- if mask_logits.numel() == 0: # Double check after indexing
346
- print(f"No mask logits selected at step {i}. Stopping generation.")
347
- break
348
-
349
- t = timesteps[i]
350
- s = timesteps[i + 1]
351
-
352
- # --- Remasking Logic (Simplified from Dream's _sample) ---
353
- step_confidences = {} # Store confidences for revealed tokens in this step {abs_pos: confidence}
354
-
355
- if alg == 'origin':
356
- p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
357
- # Sample for all masked positions
358
- confidence, x0_masked = sample_tokens_for_vis(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
359
- # Decide which ones to transfer based on random probability
360
- transfer_mask = torch.rand(x0_masked.shape, device=device) < p_transfer
361
- # Create a tensor of MASK_IDs, and fill in the transferred tokens
362
- updates_for_masked_pos = torch.full_like(x0_masked, MASK_ID)
363
- updates_for_masked_pos[transfer_mask] = x0_masked[transfer_mask]
364
- # Update x at the masked positions
365
- x[mask_index] = updates_for_masked_pos
366
-
367
- # Store confidences for the *transferred* tokens for visualization
368
- transferred_indices_flat = current_mask_indices_flat[transfer_mask]
369
- transferred_confidences = confidence[transfer_mask]
370
- for flat_idx, conf in zip(transferred_indices_flat, transferred_confidences):
371
- abs_pos = flat_idx.item() # Convert flat index back to seq position (assuming batch=1)
372
- step_confidences[abs_pos] = conf.item()
373
-
374
-
375
- else: # Confidence-based algorithms ('maskgit_plus', 'topk_margin', 'entropy')
376
- use_margin = (alg == 'topk_margin')
377
- use_entropy = (alg == 'entropy')
378
- # Sample potential replacements for ALL masked positions first
379
- confidence, x0_masked = sample_tokens_for_vis(
380
- mask_logits,
381
- temperature=temperature,
382
- top_p=top_p,
383
- top_k=top_k,
384
- margin_confidence=use_margin,
385
- neg_entropy=use_entropy
386
- )
387
 
388
- num_mask_tokens = mask_index.sum().item()
389
- # Calculate how many tokens to unmask/transfer in this step
390
- num_transfer_tokens = int(num_mask_tokens * (1.0 - s / t)) if i < steps - 1 else num_mask_tokens
391
-
392
- if num_transfer_tokens > 0 and confidence.numel() > 0:
393
- transfer_indices_relative = None # Indices relative to the masked tokens
394
- if alg_temp is None or alg_temp <= 0:
395
- # Deterministic: Select top-k confidence scores among masked tokens
396
- # Ensure k is not larger than the number of masked tokens
397
- k = min(num_transfer_tokens, confidence.shape[0])
398
- if k > 0:
399
- _, transfer_indices_relative = torch.topk(confidence, k)
400
- else:
401
- # Stochastic: Sample based on confidence scores
402
- # Ensure probabilities are valid
403
- conf_probs = F.softmax(confidence / alg_temp, dim=-1)
404
- if not torch.isnan(conf_probs).any() and not torch.isinf(conf_probs).any() and conf_probs.sum() > 1e-6:
405
- # Ensure k is not larger than the number of masked tokens
406
- k = min(num_transfer_tokens, confidence.shape[0])
407
- if k > 0:
408
- transfer_indices_relative = torch.multinomial(conf_probs, num_samples=k, replacement=False)
409
- else:
410
- print(f"Warning: Invalid confidence probabilities at step {i}. Falling back to top-k.")
411
- # Fallback to deterministic if sampling fails
412
- k = min(num_transfer_tokens, confidence.shape[0])
413
- if k > 0:
414
- _, transfer_indices_relative = torch.topk(confidence, k)
415
-
416
-
417
- if transfer_indices_relative is not None and transfer_indices_relative.numel() > 0:
418
- # Create updates, initially all MASK_ID
419
- updates_for_masked_pos = torch.full_like(x0_masked, MASK_ID)
420
- # Place the selected sampled tokens into the updates tensor
421
- updates_for_masked_pos[transfer_indices_relative] = x0_masked[transfer_indices_relative]
422
- # Update x at the original masked positions
423
- x[mask_index] = updates_for_masked_pos
424
-
425
- # Store confidences for the *transferred* tokens for visualization
426
- selected_confidences = confidence[transfer_indices_relative]
427
- # Get the absolute positions corresponding to these relative indices
428
- original_indices_flat = current_mask_indices_flat[transfer_indices_relative]
429
- for flat_idx, conf in zip(original_indices_flat, selected_confidences):
430
- abs_pos = flat_idx.item()
431
- step_confidences[abs_pos] = conf.item()
432
 
433
- else:
434
- # No tokens were selected to transfer, x remains unchanged for masked parts
435
- pass # x[mask_index] remains MASK_ID essentially
436
 
437
- else:
438
- # If num_transfer_tokens is 0, x remains unchanged for masked parts
439
- pass
440
-
441
- # --- Apply Constraints and Finalize Step ---
442
- # Ensure constraints are always maintained AFTER updates
443
- for rel_pos, token_id in constraints.items():
444
- abs_pos = prompt_length + rel_pos
445
- if abs_pos < max_length:
446
- # Check if the position was masked before applying constraint
447
- # if mask_index[0, abs_pos]: # Only apply if it *was* a mask, maybe? Optional.
448
- x[:, abs_pos] = token_id
449
-
450
- # --- Visualization Update ---
451
- current_vis_state, newly_revealed = get_vis_state(x, old_x, step_confidences)
452
-
453
- # Only add/yield if something actually changed or if it's the last step
454
- if newly_revealed or i == steps - 1:
455
- visualization_states.append(current_vis_state)
456
- if yield_intermediate:
457
- yield current_vis_state
458
-
459
- # Update old_x for the next iteration
460
- old_x = x.clone()
461
-
462
-
463
- # --- Final Output ---
464
- response_tokens = x[0, prompt_length:]
465
- # Decode, cleaning up potential special tokens unless they are intended
466
- final_text = tokenizer.decode(response_tokens,
467
- skip_special_tokens=True, # Skip things like <|mask|> in final output
468
- clean_up_tokenization_spaces=True)
469
-
470
- # If not yielding intermediates, return the full list now
471
- if not yield_intermediate:
472
- return visualization_states, final_text
473
- else:
474
- # If yielding intermediates, we still need a way to signal completion
475
- # and return the final text. Gradio's yield typically handles this if
476
- # the last yielded value is the final one. We'll return the final text
477
- # separately after the loop finishes in the calling function.
478
- # The loop yields states, the calling function returns the final text.
479
- pass # Final text is handled outside the generator function
480
-
481
-
482
- # --- Gradio UI ---
483
  css = '''
484
  .category-legend{display:none}
485
  button{height: 60px}
486
- .token-revealed { transition: background-color 0.5s ease; } /* Optional: Add transition effect */
487
- .token-masked { background-color: #444444; color: white; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
488
- .token-new-high { background-color: #66CC66; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
489
- .token-new-mid { background-color: #FFAA33; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
490
- .token-new-low { background-color: #FF6666; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
491
- .token-old { background-color: #6699CC; color: white; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
492
- .token-hidden { display: none; } /* Hide EOS/PAD after first reveal */
493
  '''
494
-
495
  def create_chatbot_demo():
496
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
497
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
498
  gr.Markdown(
499
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
500
- "[[Blog](https://hkunlp.github.io/blog/2025/dream/)] "
501
- "[[Original LLaDA Demo Inspiration](https://huggingface.co/spaces/GSAI-ML/LLaDA-demo)]"
502
- )
503
- gr.Markdown(
504
- "**Note:** This demo visualizes the diffusion process in real-time. "
505
- "Tokens start masked (<font color='#444444'>[MASK]</font>) and are revealed step-by-step. "
506
- "Colors indicate confidence: <font color='#66CC66'>High</font>, "
507
- "<font color='#FFAA33'>Medium</font>, <font color='#FF6666'>Low</font>. "
508
- "Previously revealed tokens are <font color='#6699CC'>blue</font>. "
509
- f"EOS/PAD tokens ({tokenizer.decode([EOS_ID])}) are hidden after appearing once."
510
  )
511
 
512
  # STATE MANAGEMENT
513
  chat_history = gr.State([])
514
- current_response_text = gr.State("") # Store the final text separately
 
515
 
516
  # UI COMPONENTS
517
  with gr.Row():
518
  with gr.Column(scale=3):
519
- chatbot_ui = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
 
 
 
 
520
 
521
  # Message input
522
  with gr.Group():
@@ -524,229 +340,219 @@ def create_chatbot_demo():
524
  user_input = gr.Textbox(
525
  label="Your Message",
526
  placeholder="Type your message here...",
527
- scale=7,
528
- show_label=False
529
  )
530
  send_btn = gr.Button("Send", scale=1)
531
 
532
  constraints_input = gr.Textbox(
533
- label="Word Constraints (Relative Position)",
534
- info="Place words at specific 0-indexed positions in the *response*. Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
535
- placeholder="0:Hello, 10:world",
536
  value=""
537
  )
538
  with gr.Column(scale=2):
539
- # Use HighlightedText with specific classes for better styling control
540
  output_vis = gr.HighlightedText(
541
  label="Denoising Process Visualization",
542
- # Show legend mapping colors to confidence might be useful if classes aren't self-explanatory
543
- # For now, using the description markdown above.
544
- show_legend=False,
545
- # Use custom classes defined in CSS
546
- # color_map={ # This might not work directly with dynamic classes, CSS is better
547
- # "MASK": "#444444", "NEW_H": "#66CC66", "NEW_M": "#FFAA33",
548
- # "NEW_L": "#FF6666", "OLD": "#6699CC", "HIDDEN": "#FFFFFF"
549
- # }
550
- combine_adjacent=False, # Keep tokens separate
551
- height=550, # Adjust height as needed
552
  )
553
 
554
-
555
  # Advanced generation settings
556
  with gr.Accordion("Generation Settings", open=False):
557
  with gr.Row():
558
  gen_length = gr.Slider(
559
- minimum=16, maximum=512, value=64, step=8, # Increased max length
560
  label="Max New Tokens"
561
  )
562
  steps = gr.Slider(
563
- minimum=8, maximum=512, value=64, step=4, # Allow more steps
564
- label="Diffusion Steps"
565
  )
566
  with gr.Row():
567
  temperature = gr.Slider(
568
- minimum=0.0, maximum=1.5, value=0.2, step=0.05, # Wider range for temp
569
  label="Temperature"
570
  )
571
  top_p = gr.Slider(
572
- minimum=0.0, maximum=1.0, value=0.95, step=0.05,
573
- label="Top-P (Nucleus Sampling)"
574
- )
575
- # top_k = gr.Slider(
576
- # minimum=0, maximum=200, value=0, step=5, # Allow Top-K=0 (disabled)
577
- # label="Top-K (0 to disable)"
578
- # )
579
- with gr.Row():
580
- # Dream specific algorithm choice
581
- alg_strategy = gr.Radio(
582
- choices=["entropy", "maskgit_plus", "topk_margin", "origin"],
583
- value="entropy",
584
- label="Algorithm (`alg`)"
585
  )
586
- alg_temp = gr.Slider(
587
- minimum=0.0, maximum=1.0, value=0.1, step=0.01,
588
- label="Algorithm Temp (`alg_temp`)"
589
  )
 
 
590
  with gr.Row():
591
  visualization_delay = gr.Slider(
592
- minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Faster default delay
593
  label="Visualization Delay (seconds)"
594
  )
595
 
596
  # Clear button
597
  clear_btn = gr.Button("Clear Conversation")
598
 
599
- # --- Helper Functions for UI ---
600
- def add_message_to_history(history, message, response):
601
- """Add a message pair to the history state"""
 
 
 
602
  history.append([message, response])
603
  return history
604
 
605
- def user_message_action(message, history):
606
- """Handles user sending a message: updates history, clears input."""
607
- if not message or message.strip() == "":
608
- return history, history, "", [], "" # Return empty vis, empty response
609
-
610
- # Add user message with None response placeholder
611
- history = add_message_to_history(history, message, None)
612
- # Return updated history for chatbot display, clear input box
613
- return history, history, "", [], "" # Clear vis and response text state too
614
-
615
- def bot_response_generator(
616
- history, gen_length, steps, constraints_str, delay,
617
- temperature, top_p, # top_k,
618
- alg, alg_temp
619
- ):
620
- """Generates bot response and yields visualization states."""
621
- if not history or history[-1][1] is not None: # Check if last message already has a response
622
- print("History empty or last message already processed.")
623
- yield history, [], "" # Yield empty state if no work to do
624
- return
625
-
626
- last_user_message = history[-1][0]
627
- print(f"Generating response for: {last_user_message}")
628
-
629
- try:
630
- # Format history for the model (excluding the last None response)
631
- messages = format_chat_history(history[:-1])
632
- # Add the current user message
633
- messages.append({"role": "user", "content": last_user_message})
634
-
635
- # Parse constraints into token IDs
636
- parsed_constraints = parse_constraints(constraints_str)
637
- print(f"Parsed constraints: {parsed_constraints}")
638
-
639
-
640
- final_text = "" # Initialize final_text
641
-
642
- # Use the generator function
643
- response_generator = generate_response_with_visualization_dream(
644
- messages,
645
- gen_length=gen_length,
646
- steps=steps,
647
- constraints=parsed_constraints,
648
- temperature=temperature,
649
- top_p=top_p if top_p > 0 else None, # Pass None if 0
650
- top_k=None, # Pass None if 0 top_k if top_k > 0 else None,
651
- alg=alg,
652
- alg_temp=alg_temp if alg_temp > 0 else None, # Pass None if 0
653
- yield_intermediate=True
654
- )
655
 
656
- # Iterate through the yielded visualization states
657
- last_state = None
658
- for vis_state in response_generator:
659
- last_state = vis_state
660
- # Update chatbot with placeholder during generation
661
- history[-1][1] = "..." # Indicate thinking
662
- yield history, vis_state, "..." # Yield history, current vis state, placeholder text
663
- if delay > 0:
664
- time.sleep(delay)
665
-
666
- # --- Generation Finished ---
667
- # Extract final text (needs to be done *after* the generator is exhausted)
668
- # Re-run the generation without yielding intermediates to get the final text reliably
669
- # (Or modify the generator to return it, but this is simpler for now)
670
- # TODO: Optimize this - maybe the generator could return the final text at the end?
671
-
672
- print("Re-generating final text (non-streaming)...")
673
- final_vis_states, final_text = generate_response_with_visualization_dream(
674
- messages, gen_length, steps, parsed_constraints, temperature,
675
- top_p if top_p > 0 else None, None, #top_k if top_k > 0 else None,
676
- alg, alg_temp if alg_temp > 0 else None,
677
- yield_intermediate=False # Get final result only
678
- )
679
- print(f"Final Text: {final_text}")
680
 
 
 
681
 
682
- # Update the history with the actual final response
683
- history[-1][1] = final_text.strip() if final_text else "[No response]"
684
 
685
- # Yield the final state one last time
686
- yield history, final_vis_states[-1] if final_vis_states else [], final_text.strip()
 
 
 
 
687
 
688
- except Exception as e:
689
- import traceback
690
- print(f"Error during generation: {e}")
691
- traceback.print_exc()
692
- error_msg = f"Error: {str(e)}"
693
- history[-1][1] = error_msg # Show error in chat
694
- # Show error in visualization (red text)
695
- error_vis = [(error_msg, "#FF0000")]
696
- yield history, error_vis, error_msg
697
 
 
 
 
 
 
 
698
 
699
- def clear_conversation_action():
700
- """Clears chat history, visualization, and response text."""
701
- return [], [], "", [] # History, Chatbot UI, Response Text, Visualization
702
 
 
 
 
703
 
704
- # --- Event Wiring ---
 
705
 
706
- # 1. User Submits Message (Textbox Enter or Button Click)
707
- submit_triggers = [user_input.submit, send_btn.click]
708
- for trigger in submit_triggers:
709
- trigger.then(
710
- fn=user_message_action,
711
- inputs=[user_input, chat_history],
712
- outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response_text], # Update history state, chatbot UI, clear input, clear vis, clear response state
713
- queue=True # Enable queue for handling multiple users
714
- ).then(
715
- # 2. Trigger Bot Response Generation (Generator Function)
716
- fn=bot_response_generator,
717
- inputs=[
718
- chat_history, gen_length, steps, constraints_input, visualization_delay,
719
- temperature, top_p, # top_k,
720
- alg_strategy, alg_temp
721
- ],
722
- outputs=[chatbot_ui, output_vis, current_response_text] # Stream updates to chatbot, visualization, and store final text
723
- )
724
 
725
- # Clear Button Action
 
 
 
 
 
 
726
  clear_btn.click(
727
- fn=clear_conversation_action,
728
  inputs=[],
729
- outputs=[chat_history, chatbot_ui, current_response_text, output_vis],
730
- queue=False # No need to queue clear action
 
 
 
 
731
  )
732
 
 
 
 
 
733
  return demo
734
 
735
- # --- Launch ---
736
  if __name__ == "__main__":
737
- # Make sure the necessary Dream model files (modeling_dream.py, configuration_dream.py etc.)
738
- # are in the same directory or accessible in the Python path.
739
- # Also ensure 'generation_utils.py' is available if needed by the model loading/config.
740
- # Check if 'modeling_dream.py' exists before launching
741
- import os
742
- if not os.path.exists("modeling_dream.py") or not os.path.exists("configuration_dream.py"):
743
- print("\nERROR: Could not find 'modeling_dream.py' and/or 'configuration_dream.py'.")
744
- print("Please make sure these files (from the 'dream_model.txt' source) are in the same directory as this script.")
745
- print("You might need to extract them from the provided text file.")
746
- # exit() # Optional: stop execution if files are missing
747
-
748
- print("Creating Gradio Demo...")
749
  demo = create_chatbot_demo()
750
- print("Launching Gradio Demo...")
751
- # Use queueing for better user experience with potentially long generation times
752
- demo.queue().launch(share=True, debug=True) # Enable debug for more detailed logs
 
4
  import gradio as gr
5
  import spaces
6
  import torch.nn.functional as F
7
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
 
8
  import time
9
+ import copy
 
10
 
11
+ # Determine device
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
  print(f"Using device: {device}")
14
 
15
+ # --- Model and Tokenizer Loading ---
16
  model_path = "Dream-org/Dream-v0-Instruct-7B"
17
+
18
+ print(f"Loading tokenizer from {model_path}...")
19
+ # Load configuration first to get special token IDs
20
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
21
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 
 
22
 
23
+ print(f"Loading model from {model_path}...")
24
+ model = AutoModel.from_pretrained(
25
+ model_path,
26
+ torch_dtype=torch.bfloat16,
27
+ trust_remote_code=True
28
+ ).to(device).eval()
29
+ print("Model loaded successfully.")
30
+
31
+ # --- Constants from Dream Model ---
32
+ # Get IDs directly from config or tokenizer if available
33
+ MASK_TOKEN = tokenizer.mask_token
34
+ MASK_ID = config.mask_token_id if hasattr(config, 'mask_token_id') else tokenizer.mask_token_id
35
+ EOS_ID = config.eos_token_id if hasattr(config, 'eos_token_id') else tokenizer.eos_token_id
36
+ PAD_ID = config.pad_token_id if hasattr(config, 'pad_token_id') else tokenizer.pad_token_id # Often same as EOS
37
+
38
+ print(f"MASK_TOKEN: '{MASK_TOKEN}', MASK_ID: {MASK_ID}")
39
+ print(f"EOS_ID: {EOS_ID}, PAD_ID: {PAD_ID}")
40
+ if MASK_ID is None:
41
+ raise ValueError("Could not determine MASK_ID from model config or tokenizer.")
42
+ if EOS_ID is None:
43
+ raise ValueError("Could not determine EOS_ID from model config or tokenizer.")
44
+ if PAD_ID is None:
45
+ raise ValueError("Could not determine PAD_ID from model config or tokenizer.")
46
+
47
 
48
  # --- Helper Functions ---
49
 
50
+ def parse_constraints(constraints_text, tokenizer):
51
  """Parse constraints in format: 'position:word, position:word, ...'"""
52
  constraints = {}
53
+ processed_constraints_tokens = {}
54
  if not constraints_text:
55
+ return constraints, processed_constraints_tokens
56
 
57
  parts = constraints_text.split(',')
58
  for part in parts:
59
  if ':' not in part:
60
  continue
61
+ pos_str, word = part.split(':', 1)
62
  try:
 
63
  pos = int(pos_str.strip())
 
64
  word = word.strip()
65
  if word and pos >= 0:
66
+ # Store original word constraint for display/debugging if needed
67
+ constraints[pos] = word
68
+ # Tokenize the word (add space for consistency if not BOS)
69
+ # Note: Dream tokenizer might handle spaces differently, adjust if needed
70
  prefix = " " if pos > 0 else ""
71
  tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
72
  for i, token_id in enumerate(tokens):
73
+ # Prevent overwriting multi-token constraints partially
74
+ if pos + i not in processed_constraints_tokens:
75
+ processed_constraints_tokens[pos + i] = token_id
 
76
  except ValueError:
77
  continue
78
  except Exception as e:
79
+ print(f"Error tokenizing constraint word '{word}': {e}")
80
+ continue
 
 
81
 
82
+ # Sort by position for consistent application
83
+ processed_constraints_tokens = dict(sorted(processed_constraints_tokens.items()))
84
+ print(f"Parsed Constraints (Word): {constraints}")
85
+ print(f"Parsed Constraints (Tokens): {processed_constraints_tokens}")
86
+ return constraints, processed_constraints_tokens
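For illustration, a minimal sketch of what the new parse_constraints is expected to return; the token IDs shown are placeholders (they depend on the Dream tokenizer), not output captured from this commit:

# Hypothetical usage of parse_constraints() defined above
word_map, token_map = parse_constraints("0:Once, 5:upon", tokenizer)
# word_map  -> {0: 'Once', 5: 'upon'}
# token_map -> {0: <id of 'Once'>, 5: <id of ' upon'>, ...}
#              a multi-token word fills consecutive positions with its token IDs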
87
 
88
  def format_chat_history(history):
89
  """
90
+ Format chat history for the Dream model using its chat template convention.
91
 
92
  Args:
93
  history: List of [user_message, assistant_message] pairs
94
 
95
  Returns:
96
+ Formatted list of message dictionaries for the model
97
  """
98
  messages = []
99
+ # Add system prompt if not present (standard practice)
100
+ if not history or history[0][0] is None or history[0][0].lower() != "system":
101
+ messages.append({"role": "system", "content": "You are a helpful assistant."})
102
+
103
+ for user_msg, assistant_msg in history:
104
+ if user_msg is not None: # Handle potential system message case
105
+ messages.append({"role": "user", "content": user_msg})
106
+ if assistant_msg: # Skip if None (for the latest user message)
 
 
107
  messages.append({"role": "assistant", "content": assistant_msg})
108
 
109
  return messages
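A small sketch of the message list this helper produces, assuming the default system prompt added above (the history values are made up for illustration):

history = [["Hello", "Hi there!"], ["Tell me a story", None]]
format_chat_history(history)
# -> [{'role': 'system', 'content': 'You are a helpful assistant.'},
#     {'role': 'user', 'content': 'Hello'},
#     {'role': 'assistant', 'content': 'Hi there!'},
#     {'role': 'user', 'content': 'Tell me a story'}]  # the pending None response is skipped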
110
 
111
+ # --- Core Generation Logic with Visualization Hook ---
112
 
113
+ @spaces.GPU
114
+ def generate_response_with_visualization(
115
+ messages, # List of message dictionaries
116
+ gen_length=64,
117
+ steps=64,
118
+ constraints_text="", # Raw constraint text
119
+ temperature=0.2,
120
+ top_p=0.95,
121
+ top_k=None, # Added for Dream
122
+ alg="entropy", # Changed from remasking
123
+ alg_temp=0.0, # Added for Dream
124
+ visualization_delay=0.05,
125
+ tokenizer=tokenizer,
126
+ model=model,
127
+ device=device,
128
+ MASK_ID=MASK_ID,
129
+ EOS_ID=EOS_ID,
130
+ PAD_ID=PAD_ID
131
+ ):
132
  """
133
+ Generate text with Dream model with real-time visualization using a hook.
 
134
  """
135
+ visualization_states = []
136
+ final_text = ""
137
+ # Use a list to hold previous_x, allowing nonlocal modification
138
+ # Initialize with None, it will be set after the first hook call
139
+ shared_state = {'previous_x': None}
140
+
141
+
142
+ try:
143
+ # --- 1. Prepare Inputs ---
144
+ _, parsed_constraints_tokens = parse_constraints(constraints_text, tokenizer)
145
+
146
+ # Apply chat template
147
+ # Important: Keep tokenize=False initially to get prompt length correctly
148
+ # The template adds roles and special tokens like <|im_start|> etc.
149
+ chat_input_text = tokenizer.apply_chat_template(
150
+ messages,
151
+ add_generation_prompt=True, # Adds the prompt for the assistant's turn
152
+ tokenize=False
153
+ )
 
 
 
 
154
 
155
+ # Tokenize the full templated chat string
156
+ inputs = tokenizer(chat_input_text, return_tensors="pt", return_dict=True)
157
+ input_ids = inputs.input_ids.to(device)
158
+ attention_mask = inputs.attention_mask.to(device) # Use mask from tokenizer
159
+
160
+ prompt_length = input_ids.shape[1]
161
+ total_length = prompt_length + gen_length
162
+
163
+ # --- 2. Initialize Generation Sequence ---
164
+ # Start with the prompt, pad the rest with MASK_ID
165
+ x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
166
+ x[:, :prompt_length] = input_ids.clone()
167
+ attention_mask = F.pad(attention_mask, (0, gen_length), value=1) # Extend attention mask
168
+
169
+ # Apply initial constraints to the masked sequence `x`
170
+ for pos, token_id in parsed_constraints_tokens.items():
171
+ absolute_pos = prompt_length + pos
172
+ if absolute_pos < total_length:
173
+ print(f"Applying initial constraint at pos {absolute_pos}: token {token_id}")
174
+ x[:, absolute_pos] = token_id
175
+
176
+ # Store initial state (prompt + all masked) for visualization
177
+ initial_state_vis = []
178
+ # Add prompt tokens (optional visualization, could be grayed out or skipped)
179
+ # for i in range(prompt_length):
180
+ # token_str = tokenizer.decode([x[0, i].item()], skip_special_tokens=True)
181
+ # initial_state_vis.append((token_str if token_str else " ", "#AAAAAA")) # Gray for prompt
182
+
183
+ # Add masked tokens for the generation part
184
+ for _ in range(gen_length):
185
+ initial_state_vis.append((MASK_TOKEN, "#444444")) # Dark gray for masks
186
+ visualization_states.append(initial_state_vis)
187
+ shared_state['previous_x'] = x.clone() # Initialize previous_x
188
+
189
+
190
+ # --- 3. Define the Visualization Hook ---
191
+ def generation_tokens_hook_func(step, current_x_hook, logits):
192
+ # nonlocal previous_x # Allow modification of the outer scope variable
193
+ current_x_hook = current_x_hook.clone() # Work on a copy
194
+
195
+ # --- Apply constraints within the hook ---
196
+ # This ensures constraints are respected even if the model tries to overwrite them
197
+ for pos, token_id in parsed_constraints_tokens.items():
198
+ absolute_pos = prompt_length + pos
199
+ if absolute_pos < total_length:
200
+ current_x_hook[:, absolute_pos] = token_id
201
+ # --- End Constraint Application ---
202
+
203
+ if shared_state['previous_x'] is None: # First call
204
+ shared_state['previous_x'] = current_x_hook.clone()
205
+ return current_x_hook # Must return the (potentially modified) sequence
206
+
207
+ # Generate visualization state for this step
208
+ current_state_vis = []
209
+ prev_x_step = shared_state['previous_x']
210
+
211
+ for i in range(gen_length):
212
+ pos = prompt_length + i # Absolute position in the sequence
213
+ current_token_id = current_x_hook[0, pos].item()
214
+ prev_token_id = prev_x_step[0, pos].item()
215
+
216
+ # Decode token, handling special tokens we want to hide
217
+ token_str = ""
218
+ color = "#444444" # Default: Dark Gray (Mask)
219
+ token_str_raw = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for ID check
220
+
221
+ if current_token_id == MASK_ID:
222
+ token_str = MASK_TOKEN
223
+ color = "#444444" # Dark gray
224
+ elif current_token_id == EOS_ID or current_token_id == PAD_ID:
225
+ token_str = "" # Hide EOS/PAD visually
226
+ color = "#DDDDDD" # Use a light gray or make transparent if possible
227
+ else:
228
+ # Decode without special tokens for display if it's not MASK/EOS/PAD
229
+ token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
230
+ if not token_str: token_str = token_str_raw # Fallback if strip removes everything (e.g., space)
231
 
232
+ if prev_token_id == MASK_ID:
233
+ # Newly revealed in this step
234
+ color = "#66CC66" # Light green (Simplified from confidence levels)
235
+ else:
236
+ # Previously revealed
237
+ color = "#6699CC" # Light blue
238
+
239
+ current_state_vis.append((token_str if token_str else " ", color)) # Ensure non-empty tuple element
240
+
241
+ visualization_states.append(current_state_vis)
242
+ shared_state['previous_x'] = current_x_hook.clone() # Update previous_x for the next step
243
+
244
+ return current_x_hook # Return the sequence (constraints applied)
245
+
246
+ # --- 4. Run Diffusion Generation ---
247
+ print("Starting diffusion generation...")
248
+ start_time = time.time()
249
+ output = model.diffusion_generate(
250
+ input_ids=x[:, :prompt_length], # Pass only the initial prompt to diffusion_generate
251
+ # as it handles the masking internally based on MASK_ID
252
+ attention_mask=attention_mask, # Provide the full attention mask
253
+ max_new_tokens=gen_length,
254
+ output_history=False, # We capture history via the hook
255
+ return_dict_in_generate=True,
256
+ steps=steps,
257
+ temperature=temperature,
258
+ top_p=top_p,
259
+ top_k=top_k,
260
+ alg=alg,
261
+ alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence-based
262
+ # Pass the hook function
263
+ generation_tokens_hook_func=generation_tokens_hook_func,
264
+ # Ensure the initial masked sequence `x` is used correctly if needed by internal logic
265
+ # Depending on the exact implementation of diffusion_generate, passing x directly might be needed
266
+ # Check Dream's generation_utils if issues arise. For now, assume it uses input_ids + max_new_tokens
267
+ )
268
+ end_time = time.time()
269
+ print(f"Diffusion generation finished in {end_time - start_time:.2f} seconds.")
270
 
271
+ # --- 5. Process Final Output ---
272
+ # The hook has already built visualization_states
273
+ final_sequence = output.sequences[0]
 
 
274
 
275
+ # Decode the generated part, skipping special tokens for the final text output
276
+ response_tokens = final_sequence[prompt_length:]
277
+ # Drop PAD tokens before the final decode; EOS and other special tokens are removed by skip_special_tokens below
278
+ response_tokens_cleaned = [tok for tok in response_tokens if tok != PAD_ID] # Keep EOS initially if needed
 
 
 
 
279
 
280
+ final_text = tokenizer.decode(
281
+ response_tokens_cleaned,
282
+ skip_special_tokens=True, # Skip EOS, BOS, etc.
283
+ clean_up_tokenization_spaces=True # Recommended for cleaner output
284
+ ).strip()
285
 
286
+ # Ensure the last state in visualization matches the final text (debug check)
287
+ # print(f"Last Vis State Tokens: {''.join([t[0] for t in visualization_states[-1]]).strip()}")
288
+ # print(f"Final Decoded Text: {final_text}")
289
 
290
+ except Exception as e:
291
+ print(f"Error during generation: {e}")
292
+ import traceback
293
+ traceback.print_exc()
294
+ # Add error message to visualization
295
+ error_msg = f"Error: {str(e)}"
296
+ visualization_states.append([(error_msg, "red")])
297
+ final_text = error_msg # Display error in the chatbot too
 
298
 
299
+ # Make sure at least the initial state is present
300
+ if not visualization_states:
301
+ visualization_states.append([("Error: No states generated.", "red")])
302
 
 
 
 
 
 
 
303
 
304
+ return visualization_states, final_text
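Because the function returns the collected visualization frames together with the final text, it can also be exercised outside the Gradio callbacks; a minimal sketch, assuming the module-level model, tokenizer and token IDs defined above and a GPU context for the @spaces.GPU decorator:

frames, text = generate_response_with_visualization(
    [{"role": "user", "content": "Hello"}],
    gen_length=32,
    steps=32,
)
print(len(frames), "visualization frames")
print(text)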
 
 
 
 
305
 
306
+ # --- Gradio UI Definition ---
 
 
307
 
 
 
 
 
308
  css = '''
309
  .category-legend{display:none}
310
  button{height: 60px}
311
+ .token-text { white-space: pre; } /* Preserve spaces in tokens */
312
+ footer { display: none !important; visibility: hidden !important; }
 
 
 
 
 
313
  '''
 
314
  def create_chatbot_demo():
315
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
316
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
317
  gr.Markdown(
318
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
319
+ "[[Blog Post](https://hkunlp.github.io/blog/2025/dream/)] "
320
+ "(Note: Visualization shows token reveal steps, colors indicate status: Gray=Masked, Green=Newly Revealed, Blue=Previously Revealed)"
 
 
 
321
  )
322
 
323
  # STATE MANAGEMENT
324
  chat_history = gr.State([])
325
+ # Store constraints parsed into token IDs
326
+ parsed_constraints_state = gr.State({})
327
 
328
  # UI COMPONENTS
329
  with gr.Row():
330
  with gr.Column(scale=3):
331
+ chatbot_ui = gr.Chatbot(
332
+ label="Conversation",
333
+ height=500,
334
+ bubble_full_width=False # Makes bubbles wrap content
335
+ )
336
 
337
  # Message input
338
  with gr.Group():
 
340
  user_input = gr.Textbox(
341
  label="Your Message",
342
  placeholder="Type your message here...",
343
+ show_label=False,
344
+ scale=7
345
  )
346
  send_btn = gr.Button("Send", scale=1)
347
 
348
  constraints_input = gr.Textbox(
349
+ label="Word Constraints (Experimental)",
350
+ info="Place specific words at positions (0-indexed). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'. Multi-token words supported.",
351
+ placeholder="0:The, 10:story",
352
  value=""
353
  )
354
  with gr.Column(scale=2):
 
355
  output_vis = gr.HighlightedText(
356
  label="Denoising Process Visualization",
357
+ combine_adjacent=False,
358
+ show_legend=False, # Legend not very informative here
359
+ height=560, # Match chatbot height + input box approx
360
+ elem_classes=["token-text"] # Apply custom class for styling if needed
 
 
 
 
 
 
361
  )
362
 
 
363
  # Advanced generation settings
364
  with gr.Accordion("Generation Settings", open=False):
365
  with gr.Row():
366
  gen_length = gr.Slider(
367
+ minimum=16, maximum=512, value=128, step=8,
368
  label="Max New Tokens"
369
  )
370
  steps = gr.Slider(
371
+ minimum=8, maximum=512, value=128, step=4,
372
+ label="Denoising Steps"
373
  )
374
  with gr.Row():
375
  temperature = gr.Slider(
376
+ minimum=0.0, maximum=1.0, value=0.2, step=0.05,
377
  label="Temperature"
378
  )
379
  top_p = gr.Slider(
380
+ minimum=0.1, maximum=1.0, value=0.95, step=0.05,
381
+ label="Top-P"
 
 
 
 
382
  )
383
+ top_k = gr.Slider(
384
+ minimum=0, maximum=200, value=0, step=5,
385
+ label="Top-K (0=disabled)"
386
  )
387
+ with gr.Row():
388
+ alg = gr.Radio(
389
+ choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
390
+ value='entropy',
391
+ label="Sampling Algorithm (`alg`)"
392
+ )
393
+ alg_temp = gr.Slider(
394
+ minimum=0.0, maximum=1.0, value=0.0, step=0.05,
395
+ label="Algorithm Temp (`alg_temp`, adds randomness to confidence-based `alg`)"
396
+ )
397
+
398
  with gr.Row():
399
  visualization_delay = gr.Slider(
400
+ minimum=0.0, maximum=0.5, value=0.02, step=0.01,
401
  label="Visualization Delay (seconds)"
402
  )
403
 
404
  # Clear button
405
  clear_btn = gr.Button("Clear Conversation")
406
 
407
+ # --- Event Handlers ---
408
+ def add_message(history, message, response):
409
+ """Add a message pair to the history and return the updated history"""
410
+ # Ensure history is a list
411
+ if not isinstance(history, list):
412
+ history = []
413
  history.append([message, response])
414
  return history
415
 
416
+ def user_message_submitted(message, history):
417
+ """Process a submitted user message"""
418
+ if not message.strip():
419
+ return history, history, "", [] # No change if empty
420
+
421
+ # Add user message (response is None for now)
422
+ history = add_message(history, message, None)
423
+
424
+ # Return updated history for display, clear input box
425
+ return history, history, "", [] # history, chatbot_ui, user_input, output_vis
426
+
427
+
428
+ def bot_response_stream(
429
+ history, # Current chat history (list of lists)
430
+ gen_length, steps, constraints, # Generation settings
431
+ temperature, top_p, top_k, alg, alg_temp, # Sampling settings
432
+ delay # Visualization delay
433
+ ):
434
+ """Generate bot response and stream visualization states"""
435
+ if not history or history[-1][1] is not None: # Check if history is present and last response isn't already set
436
+ print("Skipping bot response generation: No new user message.")
437
+ # Yield empty state if needed to prevent errors downstream
438
+ # Ensure history is returned correctly if nothing happens
439
+ yield history, [], "Internal Error: No user message found."
440
+ return
 
 
 
 
441
 
442
+ # Format messages for the model
443
+ # Exclude the last entry as it only contains the user message
444
+ messages_for_model = format_chat_history(history) # Already includes system prompt
445
+
446
+ print("\n--- Generating Bot Response ---")
447
+ print(f"History: {history}")
448
+ print(f"Messages for model: {messages_for_model}")
449
+ print(f"Constraints text: '{constraints}'")
450
+ print(f"Gen length: {gen_length}, Steps: {steps}, Temp: {temperature}, Top-P: {top_p}, Top-K: {top_k}, Alg: {alg}, Alg Temp: {alg_temp}")
451
+
452
+ # Call the generation function
453
+ vis_states, response_text = generate_response_with_visualization(
454
+ messages_for_model,
455
+ gen_length=gen_length,
456
+ steps=steps,
457
+ constraints_text=constraints,
458
+ temperature=temperature,
459
+ top_p=top_p if top_p < 1.0 else None, # None disables top-p
460
+ top_k=top_k if top_k > 0 else None, # None disables top-k
461
+ alg=alg,
462
+ alg_temp=alg_temp,
463
+ visualization_delay=delay,
464
+ # Pass other necessary args like tokenizer, model if not global
465
+ )
466
 
467
+ print(f"Generated response text: '{response_text}'")
468
+ print(f"Number of visualization states: {len(vis_states)}")
469
 
 
 
470
 
471
+ # Update the history with the final response
472
+ # Make sure history is mutable if needed or reassign
473
+ if history:
474
+ history[-1][1] = response_text
475
+ else:
476
+ print("Warning: History was empty when trying to update response.")
477
 
 
 
 
 
478
 
479
+ # Stream the visualization states
480
+ if not vis_states:
481
+ print("Warning: No visualization states were generated.")
482
+ # Yield something to prevent downstream errors
483
+ yield history, [("Error: No visualization.", "red")], response_text
484
+ return
485
 
486
+ # Yield initial state immediately if desired, or just start loop
487
+ # yield history, vis_states[0], response_text
 
488
 
489
+ for state in vis_states:
490
+ yield history, state, response_text # Yield updated history, current vis state, final text
491
+ time.sleep(delay) # Pause between steps
492
 
493
+ # Final yield to ensure the last state is displayed
494
+ yield history, vis_states[-1], response_text
495
 
 
 
496
 
497
+ def clear_conversation():
498
+ """Clear the conversation history and visualization"""
499
+ return [], [], "", [] # history, chatbot, user_input, output_vis
500
+
501
+ # --- Event Wiring ---
502
+
503
+ # Clear button
504
  clear_btn.click(
505
+ fn=clear_conversation,
506
  inputs=[],
507
+ outputs=[chat_history, chatbot_ui, user_input, output_vis]
508
+ )
509
+
510
+ # User message submission flow (2-step using .then)
511
+ # 1. User submits message -> Update history and chatbot UI immediately
512
+ submit_action = user_input.submit(
513
+ fn=user_message_submitted,
514
+ inputs=[user_input, chat_history],
515
+ outputs=[chat_history, chatbot_ui, user_input, output_vis] # Update chatbot, clear input
516
+ )
517
+
518
+ # Connect send button to the same function
519
+ send_action = send_btn.click(
520
+ fn=user_message_submitted,
521
+ inputs=[user_input, chat_history],
522
+ outputs=[chat_history, chatbot_ui, user_input, output_vis]
523
+ )
524
+
525
+ # 2. After UI update -> Trigger bot response generation and streaming
526
+ # Use the updated chat_history from the first step
527
+ submit_action.then(
528
+ fn=bot_response_stream,
529
+ inputs=[
530
+ chat_history, gen_length, steps, constraints_input,
531
+ temperature, top_p, top_k, alg, alg_temp,
532
+ visualization_delay
533
+ ],
534
+ outputs=[chatbot_ui, output_vis, user_input] # Three outputs to match the three yielded values (history, vis state, final text); user_input is a stand-in target for the final text and could be swapped for a dedicated component
535
  )
536
 
537
+ send_action.then(
538
+ fn=bot_response_stream,
539
+ inputs=[
540
+ chat_history, gen_length, steps, constraints_input,
541
+ temperature, top_p, top_k, alg, alg_temp,
542
+ visualization_delay
543
+ ],
544
+ outputs=[chatbot_ui, output_vis] # Update chatbot and visualization
545
+ )
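Because bot_response_stream is a generator, Gradio streams each yielded tuple positionally into the outputs list wired above; a minimal stand-alone sketch of that pattern (the component names here are illustrative, not part of this app):

import gradio as gr

def stream_fn(history):
    # each yield updates (chatbot, status) in order
    for msg in ["thinking...", "done"]:
        yield history, msg

with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    status = gr.Textbox()
    gr.Button("Run").click(fn=stream_fn, inputs=[chatbot], outputs=[chatbot, status])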
546
+
547
+ # Clear input after send/submit (already handled in user_message_submitted)
548
+ # submit_action.then(lambda: "", outputs=user_input)
549
+ # send_action.then(lambda: "", outputs=user_input)
550
+
551
+
552
  return demo
553
 
554
+ # --- Launch the Gradio App ---
555
  if __name__ == "__main__":
 
 
556
  demo = create_chatbot_demo()
557
+ # Using queue for streaming and handling multiple users
558
+ demo.queue().launch(debug=True, share=True)
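The real-time visualization relies on Dream's diffusion_generate accepting a per-step hook; a minimal sketch of that contract in isolation (the parameter names follow the call in generate_response_with_visualization above, while the hook body and the settings are assumptions for illustration):

def trace_hook(step, x, logits):
    # x is the full token sequence after this denoising step and must be returned,
    # optionally modified (app.py uses this to re-apply word constraints)
    print(f"step {step}: {(x == MASK_ID).sum().item()} masked positions left")
    return x

out = model.diffusion_generate(
    input_ids=tokenizer("Hello", return_tensors="pt").input_ids.to(device),
    max_new_tokens=64,
    steps=64,
    temperature=0.2,
    alg="entropy",
    return_dict_in_generate=True,
    generation_tokens_hook_func=trace_hook,
)
print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))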