multimodalart (HF Staff) committed
Commit 4474e7a · verified · 1 Parent(s): 11d48f6

Update app.py

Files changed (1)
  1. app.py +515 -183
app.py CHANGED
@@ -8,6 +8,91 @@ from transformers import AutoTokenizer, AutoModel, AutoConfig
8
  import time
9
  import re
10
  from typing import List, Dict, Tuple, Optional
11
 
12
  # Load model configuration to get special token IDs
13
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
@@ -27,25 +112,48 @@ print("Loading model...")
27
  model = AutoModel.from_pretrained(
28
  model_path,
29
  torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32, # Use bfloat16 only on CUDA
30
- trust_remote_code=True
 
31
  )
32
  model = model.to(device).eval()
33
  print("Model loaded.")
34
 
35
  # Constants from Dream's config/tokenizer
36
- # Use attributes from loaded config/tokenizer objects
37
  MASK_TOKEN = tokenizer.mask_token
38
- MASK_ID = config.mask_token_id
39
- PAD_ID = config.pad_token_id
40
- EOS_ID = config.eos_token_id
41
  # Make sure EOS_ID and PAD_ID are handled correctly; Dream uses the same ID for both
42
  SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
43
  # Add other special tokens defined in tokenizer_config.json if needed for hiding
44
  # Get IDs for im_start, im_end etc. if they should also be hidden/handled specially
45
- IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
46
- IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
47
- SPECIAL_TOKEN_IDS.add(IM_START_ID)
48
- SPECIAL_TOKEN_IDS.add(IM_END_ID)
49
 
50
  # --- Helper Functions ---
51
 
@@ -61,25 +169,57 @@ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
61
 
62
  parts = constraints_text.split(',')
63
  for part in parts:
 
64
  if ':' not in part:
65
  continue
66
  pos_str, word = part.split(':', 1)
67
  try:
68
  # Position relative to the start of the *generation*
69
  pos = int(pos_str.strip())
70
- word = word.strip()
71
- # Tokenize the word - add leading space if not BOS? Dream handles spaces.
72
- # Check Dream tokenizer behavior for spaces. Assuming standard behavior:
73
- token_ids = tokenizer.encode(" " + word if pos > 0 else word, add_special_tokens=False)
74
 
75
  if token_ids and pos >= 0:
76
  constraints[pos] = token_ids
 
 
77
  except ValueError:
 
78
  continue # Ignore malformed constraint parts
79
  except Exception as e:
80
  print(f"Warning: Error processing constraint '{part}': {e}")
81
  continue
82
 
 
83
  return constraints
84
 
85
 
@@ -95,23 +235,45 @@ def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, st
95
  Formatted list of message dictionaries for tokenizer.apply_chat_template.
96
  """
97
  messages = []
98
- # Check if the first message is a system prompt, handle accordingly if needed
99
- # Based on Dream's examples, the template adds a default system prompt if none exists.
100
- # If history starts with System, it should be handled by the template.
101
- # Let's assume the template handles the system prompt correctly.
102
-
103
  for user_msg, assistant_msg in history:
104
  if user_msg: # Defensive check
105
  messages.append({"role": "user", "content": user_msg})
106
  # Add assistant message only if it exists (it won't for the last turn before generation)
107
  if assistant_msg:
108
  messages.append({"role": "assistant", "content": assistant_msg})
109
-
110
  return messages
111
112
  # --- Core Generation Logic with Live Visualization ---
113
 
114
  @spaces.GPU # Decorator for Hugging Face Spaces GPU usage
 
115
  def generate_dream_response(
116
  history: List[List[Optional[str]]],
117
  gen_length: int,
@@ -125,7 +287,7 @@ def generate_dream_response(
125
  visualization_delay: float
126
  ) -> List[Tuple[str, str]]:
127
  """
128
- Generates text using the Dream model and yields visualization states live.
129
 
130
  Args:
131
  history: Chat history.
@@ -133,21 +295,20 @@ def generate_dream_response(
133
  steps: Number of diffusion steps.
134
  constraints_text: User-provided constraints string.
135
  temperature: Sampling temperature.
136
- top_p: Top-p sampling nucleus.
137
- top_k: Top-k sampling.
138
  alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
139
  alg_temp: Temperature for confidence-based algorithms.
140
  visualization_delay: Delay between visualization steps.
141
 
142
  Yields:
143
  Tuple[List[List[Optional[str]]], List[Tuple[str, Optional[str]]], str]:
144
- - Updated history
145
- - Visualization data for HighlightedText
146
- - Final response text (repeated in each yield)
147
  """
148
 
149
  if not history or not history[-1][0]:
150
- # No user message to respond to
151
  yield history, [("No input message found.", "red")], ""
152
  return
153
 
@@ -167,90 +328,275 @@ def generate_dream_response(
167
  add_generation_prompt=True # Important for instruct models
168
  )
169
  input_ids = inputs.input_ids.to(device)
170
- attention_mask = inputs.attention_mask.to(device)
171
  prompt_length = input_ids.shape[1]
172
  except Exception as e:
173
  print(f"Error applying chat template: {e}")
174
  yield history, [("Error preparing input.", "red")], ""
175
  return
176
 
177
- # Calculate total sequence length for the model
178
- # Max length constraint from model config (e.g., 2048 for original Dream?)
179
- # Let's use a reasonable default or allow configuration if needed.
180
- # The provided code uses max_position_embeddings=131072, let's stick to user input + gen_length.
181
- total_length = prompt_length + gen_length
 
182
 
183
- # --- 2. Visualization Setup ---
184
- # This list will store the token sequence (just the generated part) at each step
185
- step_sequence_history: List[torch.Tensor] = []
186
- previous_step_tokens = None # Keep track of the previous step's state
187
 
188
- # Define the hook function *inside* this function to capture state
189
- def live_visualization_hook(step: Optional[int], x: torch.Tensor, logits: Optional[torch.Tensor]) -> torch.Tensor:
190
- nonlocal step_sequence_history, parsed_constraints, prompt_length
191
 
192
- # --- Apply Constraints ---
193
- # Constraints are applied *after* the model proposes tokens but *before* they are finalized for the step
194
- # Note: The hook receives the state *before* the next model call in the next step,
195
- # or the final state after the last step. Let's apply constraints consistently.
196
- # The `diffusion_generate` calls the hook *after* updating x based on sampling.
197
- current_x = x.clone() # Work on a copy
198
 
199
- for rel_pos, word_token_ids in parsed_constraints.items():
200
- abs_start_pos = prompt_length + rel_pos
201
- abs_end_pos = abs_start_pos + len(word_token_ids)
202
 
203
- # Ensure the constraint fits within the generation length
204
- if abs_start_pos < total_length and abs_end_pos <= total_length:
205
- try:
206
- constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=current_x.device)
207
- # Force the constraint tokens onto the sequence
208
- current_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
209
- except IndexError:
210
- print(f"Warning: Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
211
- except Exception as e:
212
- print(f"Warning: Failed to apply constraint at {rel_pos}: {e}")
213
-
214
- # Store the state *after* constraints for visualization
215
- # We only need the generated part
216
- generated_part = current_x[0, prompt_length:].clone().cpu() # Move to CPU to save GPU memory
217
- step_sequence_history.append(generated_part)
218
-
219
- # Return the (potentially modified by constraints) tensor x
220
- return current_x # Pass the constrained version to the next step
221
-
222
- # --- 3. Run Generation ---
223
- final_response_text = ""
224
- try:
225
- print(f"Starting Dream generation: prompt_len={prompt_length}, gen_len={gen_length}, steps={steps}")
226
- start_time = time.time()
227
 
228
- # Initial masked state for visualization
229
- initial_generated_state = torch.full((gen_length,), MASK_ID, dtype=torch.long)
230
- # Apply constraints to the *initial* visual state if they start at pos 0
231
- temp_initial_x = torch.cat((input_ids[0], initial_generated_state.to(device)), dim=0).unsqueeze(0)
232
- initial_vis_x = live_visualization_hook(None, temp_initial_x, None) # Apply constraints via hook logic
233
- step_sequence_history.insert(0, initial_vis_x[0, prompt_length:].cpu()) # Prepend initial state
234
-
235
- output = model.diffusion_generate(
236
- input_ids,
237
- attention_mask=attention_mask,
238
- max_new_tokens=gen_length,
239
- output_history=False, # We capture history via the hook
240
- return_dict_in_generate=True,
241
- steps=steps,
242
- temperature=temperature,
243
- top_p=top_p if top_p is not None and top_p < 1.0 else None, # Ensure top_p < 1 or None
244
- top_k=top_k if top_k is not None and top_k > 0 else None, # Ensure top_k > 0 or None
245
- alg=alg,
246
- alg_temp=alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] else None, # Only relevant for some algs
247
- generation_tokens_hook_func=live_visualization_hook
248
- )
249
  end_time = time.time()
250
  print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")
251
 
252
- # --- 4. Process Final Output ---
253
- final_sequence = output.sequences[0]
254
  response_tokens = final_sequence[prompt_length:]
255
 
256
  # Decode the final response text
@@ -260,83 +606,57 @@ def generate_dream_response(
260
  clean_up_tokenization_spaces=True
261
  ).strip()
262
 
263
- # Update history with the final response
264
- history[-1][1] = final_response_text
265
 
266
- except Exception as e:
267
- print(f"Error during generation or processing: {e}")
268
- import traceback
269
- traceback.print_exc()
270
- yield history, [("Error during generation.", "red")], ""
271
- return
272
-
273
- # --- 5. Stream Visualization ---
274
- print(f"Streaming {len(step_sequence_history)} visualization steps...")
275
- previous_tokens_vis = None
276
- for i, current_tokens_vis in enumerate(step_sequence_history):
277
- # print(f" Step {i}: {current_tokens_vis.tolist()}") # Debug
278
- vis_data = []
279
- current_decoded_tokens = []
280
-
281
- # Compare current step tokens with previous step tokens
282
  for j in range(gen_length):
283
- current_tok_id = current_tokens_vis[j].item()
284
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
285
 
286
- # Decode token - handle potential errors for single IDs if needed
287
  try:
288
- # Use skip_special_tokens=False here to see the actual tokens
289
  decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
290
- # Explicitly handle mask token display
291
- if current_tok_id == MASK_ID:
292
- display_token = MASK_TOKEN
293
- else:
294
- display_token = decoded_token
295
-
296
  except Exception:
297
  display_token = f"[ID:{current_tok_id}]" # Fallback
298
 
299
- # Determine color and handle hiding of special tokens (like LLaDA demo)
300
  color = None
301
  token_to_display = display_token
302
 
303
  if current_tok_id == MASK_ID:
304
- color = "#444444" # Dark Gray for masks
305
- elif previous_tok_id == MASK_ID: # Token was just revealed
306
- # Simple green for newly revealed, no confidence score available from hook
307
- color = "#66CC66" # Light Green
308
- else: # Token was already revealed
309
- color = "#6699CC" # Light Blue
310
-
311
- # LLaDA hiding effect: If it's a special token (EOS/PAD) *and* it was revealed before this step, hide it.
312
- if current_tok_id in {PAD_ID, EOS_ID} and previous_tok_id == current_tok_id:
313
- # Hide by making it empty or using a background color - empty string is simpler
314
  token_to_display = ""
315
- color = "#FFFFFF" # Or just make it blend in
316
 
317
- # Add token and color to visualization data
318
- if token_to_display: # Avoid adding empty strings if hiding
319
- vis_data.append((token_to_display, color))
320
- elif len(vis_data) > 0 and isinstance(vis_data[-1], tuple):
321
- # If hidden, and previous was text, add a space for visual separation?
322
- # This might complicate things, let's omit for now.
323
- pass
324
- # elif len(vis_data) == 0: # If first token is hidden
325
- # vis_data.append(("", None)) # Placeholder?
326
 
327
- # Update previous state for next iteration
328
- previous_tokens_vis = current_tokens_vis
 
329
 
330
- # Yield the current visualization state
331
- yield history, vis_data, final_response_text
332
 
333
- # Pause for the specified delay
334
- time.sleep(visualization_delay)
335
-
336
- print("Visualization streaming complete.")
 
 
 
337
 
338
 
339
- # --- Gradio UI ---
340
  css = '''
341
  .category-legend{display:none}
342
  button{min-height: 60px}
@@ -346,11 +666,13 @@ def create_chatbot_demo():
346
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
347
  gr.Markdown(
348
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
349
- "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
350
  )
351
 
352
  # STATE MANAGEMENT
353
- chat_history = gr.State([])
 
 
354
 
355
  # UI COMPONENTS
356
  with gr.Row():
@@ -359,7 +681,8 @@ def create_chatbot_demo():
359
  label="Conversation",
360
  height=500,
361
  show_copy_button=True,
362
- bubble_full_width=False
 
363
  )
364
 
365
  # Message input
@@ -386,9 +709,17 @@ def create_chatbot_demo():
386
  output_vis = gr.HighlightedText(
387
  label="Denoising Process Visualization",
388
  combine_adjacent=False,
389
- show_legend=True,
390
  )
391
 
 
392
  # Advanced generation settings
393
  with gr.Accordion("Generation Settings", open=False):
394
  with gr.Row():
@@ -403,21 +734,21 @@ def create_chatbot_demo():
403
  with gr.Row():
404
  temperature = gr.Slider(
405
  minimum=0.0, maximum=1.0, value=0.4, step=0.05,
406
- label="Temperature"
407
  )
408
  alg_temp = gr.Slider(
409
  minimum=0.0, maximum=1.0, value=0.1, step=0.05,
410
- label="Remasking Temp (for confidence algs)"
411
  )
412
 
413
  with gr.Row():
414
  top_p = gr.Slider(
415
  minimum=0.0, maximum=1.0, value=0.95, step=0.05,
416
- label="Top-P (0=disabled)"
417
  )
418
  top_k = gr.Slider(
419
  minimum=0, maximum=200, value=0, step=5,
420
- label="Top-K (0=disabled)"
421
  )
422
 
423
  with gr.Row():
@@ -429,76 +760,77 @@ def create_chatbot_demo():
429
 
430
  with gr.Row():
431
  visualization_delay = gr.Slider(
432
- minimum=0.0, maximum=0.5, value=0.02, step=0.01, # Faster default
433
  label="Visualization Delay (seconds)"
434
  )
435
 
436
  # Clear button
437
  clear_btn = gr.Button("Clear Conversation")
438
 
439
- # Current response text box (hidden, maybe useful for debugging)
440
- # current_response = gr.Textbox(visible=False)
441
-
442
  # --- Event Handlers ---
443
 
444
- def add_user_message_to_history(message: str, history: List[List[Optional[str]]]):
445
  """Adds user message, clears input, prepares for bot response."""
446
  if not message.strip():
447
  gr.Warning("Please enter a message.")
448
- return history, history, "", [("Enter a message", "grey")] # Keep vis empty or show prompt
 
449
 
450
  # Add user message with placeholder for bot response
451
- history.append([message, None])
452
- # Return updated history for chatbot, empty input box, empty visualization
453
- return history, history, "", []
454
-
455
 
456
  def clear_conversation():
457
- """Clears the chat history and visualization."""
458
- return [], [], "", []
459
 
460
  # --- Connect UI elements ---
461
 
462
  # Define the inputs for the generation function once
463
  generation_inputs = [
464
- chat_history, gen_length, steps, constraints_input,
465
  temperature, top_p, top_k, remasking_strategy, alg_temp,
466
  visualization_delay
467
  ]
468
  # Define the outputs for the generation function
469
- generation_outputs = [chatbot_ui, output_vis]
 
 
470
 
471
  # Handle Textbox Submission (Enter key)
472
  submit_listener = user_input.submit(
473
  fn=add_user_message_to_history,
474
- inputs=[user_input, chat_history],
475
- outputs=[chat_history, chatbot_ui, user_input, output_vis] # Step 1: Add user msg
476
  )
477
  # Chain the bot response generation after the user message is added
478
  submit_listener.then(
479
  fn=generate_dream_response,
480
  inputs=generation_inputs,
481
- outputs=generation_outputs # Step 2: Generate response and stream vis
 
482
  )
483
 
484
  # Handle Send Button Click
485
  click_listener = send_btn.click(
486
  fn=add_user_message_to_history,
487
- inputs=[user_input, chat_history],
488
- outputs=[chat_history, chatbot_ui, user_input, output_vis] # Step 1: Add user msg
489
  )
490
  # Chain the bot response generation after the user message is added
491
  click_listener.then(
492
  fn=generate_dream_response,
493
  inputs=generation_inputs,
494
- outputs=generation_outputs # Step 2: Generate response and stream vis
 
495
  )
496
 
497
- # Clear Button Action remains the same
498
  clear_btn.click(
499
  clear_conversation,
500
  inputs=[],
501
- outputs=[chat_history, chatbot_ui, user_input, output_vis]
502
  )
503
 
504
  return demo
@@ -507,4 +839,4 @@ def create_chatbot_demo():
507
  if __name__ == "__main__":
508
  demo = create_chatbot_demo()
509
  # Use queue for handling multiple users and streaming
510
- demo.queue().launch(debug=True, share=True) # Add share=True for public link if needed
 
8
  import time
9
  import re
10
  from typing import List, Dict, Tuple, Optional
11
+ import torch.distributions as dists # Added import
12
+
13
+ # --- START: Copied Helper functions from generation_utils.py ---
14
+ # These are needed because we are reimplementing the sampling loop locally.
15
+
16
+ def top_p_logits(logits, top_p=None):
17
+ """ Applies top-p filtering to logits. """
18
+ if top_p is None or top_p >= 1.0:
19
+ return logits
20
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
21
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
22
+ sorted_indices_to_remove = cumulative_probs > top_p
23
+ # Shift the indices to the right to keep the first token above the threshold
24
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
25
+ sorted_indices_to_remove[..., 0] = 0
26
+
27
+ mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
28
+ mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
29
+ logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
30
+ return logits
31
+
32
+ def top_k_logits(logits, top_k=None):
33
+ """ Applies top-k filtering to logits. """
34
+ if top_k is None or top_k <= 0:
35
+ return logits
36
+ top_k = min(top_k, logits.size(-1)) # Safety check
37
+ # Remove all tokens with a probability less than the last token of the top-k
38
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
39
+ logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
40
+ return logits
41
+
42
+ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
43
+ """ Samples tokens based on logits and calculates confidence. """
44
+ if temperature > 0:
45
+ logits = logits / temperature
46
+ if top_p is not None and top_p < 1.0: # Apply top_p if valid
47
+ logits = top_p_logits(logits, top_p)
48
+ if top_k is not None and top_k > 0: # Apply top_k if valid
49
+ logits = top_k_logits(logits, top_k)
50
+
51
+ # Ensure logits are not all -inf after filtering, if so, sample uniformly? Or handle error.
52
+ # For simplicity, assume valid logits after filtering. If not, sampling might fail.
53
+ # Add a small epsilon to prevent log(0) or issues with all -inf logits
54
+ logits = torch.where(logits == torch.finfo(logits.dtype).min, torch.full_like(logits, -1e9), logits)
55
+
56
+
57
+ probs = torch.softmax(logits, dim=-1)
58
+
59
+ if temperature > 0:
60
+ try:
61
+ # Check for NaNs or Infs in probs before sampling
62
+ if torch.isnan(probs).any() or torch.isinf(probs).any():
63
+ print("Warning: NaN or Inf detected in probabilities before sampling. Attempting to recover.")
64
+ # Simple recovery: Sample from uniform distribution or highest prob token
65
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
66
+ if probs.sum() == 0: # If all probabilities became zero
67
+ print("Warning: All probabilities became zero. Sampling uniformly.")
68
+ probs = torch.ones_like(probs) / probs.shape[-1]
69
+ else:
70
+ probs = probs / probs.sum(dim=-1, keepdim=True) # Re-normalize
71
+
72
+ x0 = dists.Categorical(probs=probs).sample()
73
+ confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
74
+ except Exception as e: # Catch broader exceptions during sampling
75
+ print(f"Warning: Error during Categorical sampling: {e}. Falling back to argmax.")
76
+ confidence, x0 = probs.max(dim=-1)
77
+ else:
78
+ confidence, x0 = probs.max(dim=-1)
79
+
80
+ if margin_confidence:
81
+ sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
82
+ # Ensure there are at least 2 probabilities to compare
83
+ top1_probs = sorted_probs[..., 0]
84
+ top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs # Handle case with only 1 possible token
85
+ confidence = top1_probs - top2_probs
86
+
87
+ if neg_entropy:
88
+ epsilon = 1e-10
89
+ log_probs = torch.log(probs + epsilon)
90
+ confidence = torch.sum(probs * log_probs, dim=-1) # Should be negative entropy
91
+
92
+ return confidence, x0
93
+
94
+ # --- END: Copied Helper functions ---
95
+
96
 
97
  # Load model configuration to get special token IDs
98
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
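
For intuition, a standalone sketch of the nucleus (top-p) filtering that the copied top_p_logits helper above implements; the toy logits and the 0.9 threshold are illustrative, not values used by the app:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])               # toy scores for a 4-token vocabulary
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > 0.9                                    # tokens outside the 0.9 probability nucleus
remove[..., 1:] = remove[..., :-1].clone()                  # shift right so the boundary token survives
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_idx, remove)
filtered = logits.masked_fill(mask, float("-inf"))
print(F.softmax(filtered, dim=-1))                          # the low-probability tail is zeroed out

The same masking pattern, built from torch.topk instead of a cumulative sum, gives the top_k_logits variant.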
 
112
  model = AutoModel.from_pretrained(
113
  model_path,
114
  torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32, # Use bfloat16 only on CUDA
115
+ trust_remote_code=True,
116
+ # attn_implementation="flash_attention_2" # Optional: Speed up if FA2 is available
117
  )
118
  model = model.to(device).eval()
119
  print("Model loaded.")
120
 
121
  # Constants from Dream's config/tokenizer
 
122
  MASK_TOKEN = tokenizer.mask_token
123
+ MASK_ID = tokenizer.mask_token_id # Use tokenizer's mask_token_id directly
124
+ PAD_ID = tokenizer.pad_token_id # Use tokenizer's pad_token_id
125
+ EOS_ID = tokenizer.eos_token_id # Use tokenizer's eos_token_id
126
+ # Use attributes from loaded config/tokenizer objects
127
+ # MASK_ID = config.mask_token_id # Can use this too, should be consistent
128
+ # PAD_ID = config.pad_token_id
129
+ # EOS_ID = config.eos_token_id
130
+
131
+ # Ensure mask_token_id is correctly identified
132
+ if MASK_ID is None:
133
+ print("Warning: Mask token ID not found in config/tokenizer. Trying to fetch from tokenizer...")
134
+ # Try getting from tokenizer directly if config doesn't have it or it's None
135
+ mask_token_special = tokenizer.mask_token
136
+ if mask_token_special:
137
+ MASK_ID = tokenizer.convert_tokens_to_ids(mask_token_special)
138
+ print(f"Found MASK_ID from tokenizer: {MASK_ID}")
139
+ else:
140
+ # Fallback or raise error if still not found
141
+ raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")
142
+
143
  # Make sure EOS_ID and PAD_ID are handled correctly; Dream uses the same ID for both
144
  SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
145
  # Add other special tokens defined in tokenizer_config.json if needed for hiding
146
  # Get IDs for im_start, im_end etc. if they should also be hidden/handled specially
147
+ try:
148
+ IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
149
+ IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
150
+ SPECIAL_TOKEN_IDS.add(IM_START_ID)
151
+ SPECIAL_TOKEN_IDS.add(IM_END_ID)
152
+ except KeyError:
153
+ print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
154
+ IM_START_ID = None
155
+ IM_END_ID = None
156
+
157
 
158
  # --- Helper Functions ---
159
 
 
169
 
170
  parts = constraints_text.split(',')
171
  for part in parts:
172
+ part = part.strip() # Remove leading/trailing whitespace from the part itself
173
  if ':' not in part:
174
  continue
175
  pos_str, word = part.split(':', 1)
176
  try:
177
  # Position relative to the start of the *generation*
178
  pos = int(pos_str.strip())
179
+ word = word.strip() # Strip whitespace from word
180
+ # Tokenize the word - Dream tokenizer handles spaces well typically.
181
+ # Let's check if the word starts with a space implicitly or needs one.
182
+ # Standard tokenizers often need a space prefix if not at the start.
183
+ # Test: tokenizer.encode(" world") vs tokenizer.encode("world")
184
+ # Dream often encodes ' world' differently from 'world'.
185
+ # Assume we want the word as it would appear mid-sentence unless pos is 0.
186
+ token_ids = tokenizer.encode(word, add_special_tokens=False)
187
+ # Add space prefix if needed based on position? This is tricky.
188
+ # Let's assume the user provides the word how they want it tokenized,
189
+ # potentially including a leading space if necessary.
190
+ # Example: " 5: word" might be tokenized differently than "5:word".
191
+ # Simplest approach: Tokenize exactly what the user provided.
192
+ # Let's refine: add space prefix automatically if pos > 0, unless word already starts with space?
193
+ # This seems more robust for typical usage.
194
+ if pos > 0 and not word.startswith(" "):
195
+ token_ids_with_space = tokenizer.encode(" " + word, add_special_tokens=False)
196
+ # Check if adding space actually changes tokenization significantly
197
+ # Heuristic: if the first token ID changes, use the space-prefixed version.
198
+ first_token_no_space = tokenizer.encode(word, add_special_tokens=False)[0] if token_ids else None
199
+ first_token_with_space = tokenizer.encode(" " + word, add_special_tokens=False)[0] if token_ids_with_space else None
200
+
201
+ if first_token_no_space != first_token_with_space:
202
+ token_ids = token_ids_with_space
203
+ # If tokenization doesn't change much, maybe stick to original? Less surprising.
204
+ # Let's stick to adding the space if pos > 0 for consistency, like original code.
205
+ token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
206
+
207
+ elif pos == 0:
208
+ token_ids = tokenizer.encode(word, add_special_tokens=False)
209
+
210
 
211
  if token_ids and pos >= 0:
212
  constraints[pos] = token_ids
213
+ elif not token_ids:
214
+ print(f"Warning: Could not tokenize constraint word '{word}'")
215
  except ValueError:
216
+ print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
217
  continue # Ignore malformed constraint parts
218
  except Exception as e:
219
  print(f"Warning: Error processing constraint '{part}': {e}")
220
  continue
221
 
222
+ print(f"Parsed constraints: {constraints}") # Debugging
223
  return constraints
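
The space-prefix handling above matters because BPE-style tokenizers usually encode a word differently at the start of a string than after a space. A quick way to see the effect (illustrative; shown with the GPT-2 tokenizer, which may behave differently from Dream's):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")                 # any BPE tokenizer shows the effect
print(tok.encode("world", add_special_tokens=False))        # id(s) for a string-initial "world"
print(tok.encode(" world", add_special_tokens=False))       # usually different id(s) for " world"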
224
 
225
 
 
235
  Formatted list of message dictionaries for tokenizer.apply_chat_template.
236
  """
237
  messages = []
238
  for user_msg, assistant_msg in history:
239
  if user_msg: # Defensive check
240
  messages.append({"role": "user", "content": user_msg})
241
  # Add assistant message only if it exists (it won't for the last turn before generation)
242
  if assistant_msg:
243
  messages.append({"role": "assistant", "content": assistant_msg})
 
244
  return messages
245
 
246
+ def apply_constraints_to_state(
247
+ x: torch.Tensor,
248
+ prompt_length: int,
249
+ total_length: int,
250
+ parsed_constraints: Dict[int, List[int]],
251
+ current_step: Optional[int] = None # For logging/debugging
252
+ ) -> torch.Tensor:
253
+ """Applies constraints directly to the state tensor `x`."""
254
+ modified_x = x.clone() # Work on a copy to avoid modifying original if needed elsewhere
255
+ for rel_pos, word_token_ids in parsed_constraints.items():
256
+ abs_start_pos = prompt_length + rel_pos
257
+ abs_end_pos = abs_start_pos + len(word_token_ids)
258
+
259
+ # Ensure the constraint fits within the generation length
260
+ if abs_start_pos < total_length and abs_end_pos <= total_length:
261
+ try:
262
+ constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
263
+ # Force the constraint tokens onto the sequence
264
+ modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
265
+ # print(f"Debug (Step {current_step}): Applied constraint {tokenizer.decode(word_token_ids)} at pos {rel_pos}") # Debug
266
+ except IndexError:
267
+ print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
268
+ except Exception as e:
269
+ print(f"Warning (Step {current_step}): Failed to apply constraint at {rel_pos}: {e}")
270
+ return modified_x
271
+
272
+
273
  # --- Core Generation Logic with Live Visualization ---
274
 
275
  @spaces.GPU # Decorator for Hugging Face Spaces GPU usage
276
+ @torch.no_grad() # Ensure no gradients are computed during generation
277
  def generate_dream_response(
278
  history: List[List[Optional[str]]],
279
  gen_length: int,
 
287
  visualization_delay: float
288
  ) -> List[Tuple[str, str]]:
289
  """
290
+ Generates text using the Dream model step-by-step and yields visualization states live.
291
 
292
  Args:
293
  history: Chat history.
 
295
  steps: Number of diffusion steps.
296
  constraints_text: User-provided constraints string.
297
  temperature: Sampling temperature.
298
+ top_p: Top-p sampling nucleus. Clamp to < 1.0 or None.
299
+ top_k: Top-k sampling. Clamp to > 0 or None.
300
  alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
301
  alg_temp: Temperature for confidence-based algorithms.
302
  visualization_delay: Delay between visualization steps.
303
 
304
  Yields:
305
  Tuple[List[List[Optional[str]]], List[Tuple[str, Optional[str]]], str]:
306
+ - Updated history (may be intermediate until final response)
307
+ - Visualization data for HighlightedText for the current step
308
+ - Intermediate or Final response text (yielded repeatedly)
309
  """
310
 
311
  if not history or not history[-1][0]:
 
312
  yield history, [("No input message found.", "red")], ""
313
  return
314
 
 
328
  add_generation_prompt=True # Important for instruct models
329
  )
330
  input_ids = inputs.input_ids.to(device)
331
+ prompt_attention_mask = inputs.attention_mask.to(device) # Mask for the prompt part
332
  prompt_length = input_ids.shape[1]
333
  except Exception as e:
334
  print(f"Error applying chat template: {e}")
335
  yield history, [("Error preparing input.", "red")], ""
336
  return
337
 
338
+ # --- Config parameters for the loop ---
339
+ eps = 1e-3 # Default from DreamGenerationConfig, make configurable if needed
340
+ # Ensure top_p and top_k have valid values for filtering functions
341
+ top_p_val = top_p if top_p is not None and top_p < 1.0 else None
342
+ top_k_val = top_k if top_k is not None and top_k > 0 else None
343
+ alg_temp_val = alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] else None
344
 
345
+ # --- 2. Initialize Generation State ---
346
+ total_length = prompt_length + gen_length
 
 
347
 
348
+ # Initial state: prompt + MASK tokens
349
+ initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
350
+ x = torch.cat((input_ids, initial_generation_part), dim=1)
351
+
352
+ # Prepare full attention mask (assuming full attention over generated part initially)
353
+ generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
354
+ full_attention_mask = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)
355
+
356
+ # Check if model needs specific attention mask format (e.g., causal for prompt?)
357
+ # The original `diffusion_generate` handles this internally. Replicating requires care.
358
+ # Based on `_sample`, it prepares a broadcastable mask if padding exists, else uses "full".
359
+ # Let's assume "full" attention is okay for Dream's purpose here, as mask tokens don't depend on future masks.
360
+ # If the base model *requires* causal masking internally even with diffusion, this might need adjustment.
361
+ # For simplicity, using a full mask (ones) over the whole sequence.
362
+ # The model's internal attention should handle causality if needed.
363
+ # Let's stick to the simpler full mask preparation from the original code when no padding.
364
+ if torch.any(full_attention_mask == 0): # Handle padding if present (shouldn't be with template?)
365
+ tok_idx = full_attention_mask.long().cumsum(-1) - 1
366
+ tok_idx.masked_fill_(full_attention_mask == 0, 0) # Use 0 for padding index? Or 1? Check original. Original used 1.
367
+ tok_idx.masked_fill_(full_attention_mask == 0, 1)
368
+ attention_mask_for_model = torch.logical_and(
369
+ full_attention_mask.unsqueeze(1).unsqueeze(-2),
370
+ full_attention_mask.unsqueeze(1).unsqueeze(-1),
371
+ ) # Shape [B, 1, N, N]
372
+ else:
373
+ tok_idx = None
374
+ attention_mask_for_model = None # Let the model handle full attention if mask is None
375
+
376
+ # Timesteps for diffusion
377
+ timesteps = torch.linspace(1, eps, steps + 1, device=device)
378
+
379
+ # Apply initial constraints (before first step)
380
+ x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1) # Step -1 for initial
381
+
382
+ # --- 3. Visualization Setup ---
383
+ previous_tokens_vis = None # Keep track of the previous step's state for coloring
384
+ final_response_text = "" # Store the final decoded text
385
+ history_copy = [list(item) for item in history] # Make a mutable copy
386
+
387
+ # --- 4. Initial Yield (Masked State) ---
388
+ initial_generated_tokens = x[0, prompt_length:].cpu()
389
+ vis_data_initial = []
390
+ for tok_id in initial_generated_tokens.tolist():
391
+ display_token = MASK_TOKEN
392
+ color = "#444444" # Dark Gray for masks
393
+ vis_data_initial.append((display_token, color))
394
+
395
+ previous_tokens_vis = initial_generated_tokens
396
+ yield history_copy, vis_data_initial, "" # Yield initial state
397
+ time.sleep(visualization_delay)
398
+
399
+ # --- 5. Step-by-Step Diffusion Loop ---
400
+ try:
401
+ start_time = time.time()
402
+ for i in range(steps):
403
+ # --- Model Forward Pass ---
404
+ mask_index = (x == MASK_ID) # Find masks in the *current* state x
405
+ if not mask_index.any(): # Stop if no masks left
406
+ print(f"No mask tokens left at step {i}. Stopping early.")
407
+ break
408
+
409
+ # print(f"Step {i}: Input shape {x.shape}, Mask sum {mask_index.sum()}") # Debug
410
+ # print(f"Step {i}: Input tokens (first/last 10): {x[0, :10].tolist()} ... {x[0, -10:].tolist()}") # Debug
411
+
412
+ # Call the model - ensure attention mask format is correct
413
+ # The model forward expects `attention_mask` usually of shape [B, N] or broadcastable.
414
+ # If we use `attention_mask_for_model = None`, it implies full attention.
415
+ # If we computed `attention_mask_for_model` as [B, 1, N, N], pass that.
416
+ # Let's try passing the [B, N] mask and let the model handle broadcasting/causality internally.
417
+ outputs = model(
418
+ input_ids=x,
419
+ attention_mask=full_attention_mask, # Pass the [B, N] mask
420
+ position_ids=None, # Let model compute default positions
421
+ use_cache=False, # No cache needed for diffusion steps
422
+ return_dict=True
423
+ )
424
+ logits = outputs.logits
425
+
426
+ # Shift logits like in original code? Check `generation_utils.py`.
427
+ # Yes, `logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)`
428
+ # This seems to align logits with the *previous* token's prediction. Is this correct for diffusion?
429
+ # Let's assume the original code did this for a reason, perhaps related to how the model was trained or expects inputs.
430
+ # Update: Looking at standard LM forward pass, logits[t] predicts token[t+1].
431
+ # The shift aligns logits[t] with token[t]. Let's keep it.
432
+ logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
433
+
434
+
435
+ # Select logits for masked positions
436
+ # Ensure mask_index has the same batch dimension size as logits
437
+ # mask_index shape is [B, N], logits shape is [B, N, V]
438
+ # We need to select elements from the last dim of logits where mask is True
439
+ mask_logits = logits[mask_index] # This correctly selects [num_masked_tokens, V]
440
+
441
+ if mask_logits.numel() == 0: # If no masks, logits selection is empty
442
+ print(f"No masked tokens found for logit selection at step {i}. Stopping.")
443
+ break
444
+
445
+ # print(f"Step {i}: mask_logits shape: {mask_logits.shape}") # Debug
446
+
447
+ # --- Sampling / Remasking Logic ---
448
+ t = timesteps[i]
449
+ s = timesteps[i + 1]
450
+
451
+ x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)
452
+
453
+ if alg == 'origin':
454
+ # Original diffusion logic
455
+ p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0 # Ensure float division
456
+ # Sample only for the tokens to be revealed in this step
457
+ num_masked = mask_logits.shape[0]
458
+ transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
459
+ logits_to_sample = mask_logits[transfer_indices_relative]
460
+
461
+ if logits_to_sample.numel() > 0:
462
+ # print(f"Step {i} (origin): Sampling {logits_to_sample.shape[0]} tokens.") # Debug
463
+ _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
464
+ # Place sampled tokens into the correct positions within the masked part
465
+ x_new_masked_part[transfer_indices_relative] = sampled_tokens
466
+ # else:
467
+ # print(f"Step {i} (origin): No tokens to sample (p_transfer={p_transfer}).") # Debug
468
+
469
+ else:
470
+ # Confidence-based algorithms (maskgit_plus, topk_margin, entropy)
471
+ use_margin = (alg == 'topk_margin')
472
+ use_entropy = (alg == 'entropy')
473
+ # print(f"Step {i} ({alg}): Sampling all {mask_logits.shape[0]} masked tokens for confidence.") # Debug
474
+ confidence, x0_candidates = sample_tokens(
475
+ mask_logits,
476
+ temperature=temperature,
477
+ top_p=top_p_val,
478
+ top_k=top_k_val,
479
+ margin_confidence=use_margin,
480
+ neg_entropy=use_entropy
481
+ )
482
+ # print(f"Step {i} ({alg}): Confidence range: [{confidence.min():.2f}, {confidence.max():.2f}]") # Debug
483
+
484
+
485
+ num_mask_token = mask_logits.shape[0]
486
+ # Calculate number to reveal based on time steps, ensure it's an int
487
+ target_num_revealed_float = num_mask_token * (1.0 - s / t)
488
+ number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
489
+
490
+
491
+ if number_transfer_tokens > 0:
492
+ # print(f"Step {i} ({alg}): Need to reveal {number_transfer_tokens} tokens.") # Debug
493
+ if alg_temp_val is None or alg_temp_val <= 0: # Use top-k confidence
494
+ # Sort by confidence (use negative entropy directly if alg='entropy')
495
+ # For entropy, lower (more negative) is higher confidence (less uncertainty)
496
+ sort_metric = confidence if alg != 'entropy' else -confidence
497
+ _, transfer_indices_relative = torch.topk(sort_metric, k=min(number_transfer_tokens, num_mask_token)) # Ensure k is not > num_mask_token
498
+ else: # Use sampling based on confidence temperature
499
+ conf_probs = confidence / alg_temp_val
500
+ # Check for inf/-inf before softmax
501
+ conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9)
502
+ conf_probs = F.softmax(conf_probs, dim=-1)
503
+ # Check probs sum to 1
504
+ if not torch.allclose(conf_probs.sum(), torch.tensor(1.0, device=device), atol=1e-4):
505
+ print(f"Warning step {i}: Confidence probabilities do not sum to 1 after softmax ({conf_probs.sum()}). Re-normalizing.")
506
+ conf_probs = conf_probs / conf_probs.sum(dim=-1, keepdim=True) # Normalize
507
+
508
+ # Ensure num_samples is valid
509
+ num_samples = min(number_transfer_tokens, num_mask_token)
510
+ if conf_probs.numel() > 0 and num_samples > 0:
511
+ try:
512
+ transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
513
+ except RuntimeError as e:
514
+ print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
515
+ # Fallback to top-k if multinomial fails (e.g., due to prob issues)
516
+ sort_metric = confidence if alg != 'entropy' else -confidence
517
+ _, transfer_indices_relative = torch.topk(sort_metric, k=num_samples)
518
+ else:
519
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # No indices if no probs or num_samples=0
520
+
521
+ # Place the selected candidate tokens into the masked part update
522
+ if transfer_indices_relative.numel() > 0:
523
+ x_new_masked_part[transfer_indices_relative] = x0_candidates[transfer_indices_relative].clone()
524
+ # else:
525
+ # print(f"Step {i} ({alg}): No tokens revealed via confidence ({number_transfer_tokens} target).") # Debug
526
+
527
+ # Update the global state `x` only at the masked positions
528
+ x[mask_index] = x_new_masked_part
529
+
530
+ # --- Apply Constraints ---
531
+ # Constraints should be applied *after* sampling/revealing tokens for the step
532
+ x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)
533
+
534
+ # --- Yield Visualization ---
535
+ current_generated_tokens = x[0, prompt_length:].cpu() # Get generated part, move to CPU
536
+ vis_data = []
537
+ for j in range(gen_length):
538
+ current_tok_id = current_generated_tokens[j].item()
539
+ previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
540
 
541
+ try:
542
+ decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
543
+ display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
544
+ except Exception:
545
+ display_token = f"[ID:{current_tok_id}]" # Fallback
 
546
 
547
+ color = None
548
+ token_to_display = display_token
 
549
 
550
+ if current_tok_id == MASK_ID:
551
+ color = "#444444" # Dark Gray for masks
552
+ elif previous_tok_id == MASK_ID: # Token was just revealed
553
+ color = "#66CC66" # Light Green
554
+ else: # Token was already revealed
555
+ color = "#6699CC" # Light Blue
556
+
557
+ # Hide special tokens (PAD/EOS) if they were already revealed (LLaDA effect)
558
+ # Ensure PAD_ID and EOS_ID are not None before checking
559
+ should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
560
+ (EOS_ID is not None and current_tok_id == EOS_ID)
561
+ if should_hide and previous_tok_id == current_tok_id:
562
+ token_to_display = "" # Hide by making empty
563
+ color = None # No color for hidden
564
+
565
+
566
+ if token_to_display:
567
+ vis_data.append((token_to_display, color))
568
+ elif len(vis_data) > 0 and isinstance(vis_data[-1], tuple) and vis_data[-1][0] == " ":
569
+ # Avoid adding multiple spaces if tokens are hidden consecutively
570
+ pass
571
+ elif len(vis_data) > 0 and not isinstance(vis_data[-1], tuple) and vis_data[-1] == " ":
572
+ pass # Already added a space
573
+ elif len(vis_data) > 0 :
574
+ # Add a single space if hiding follows a visible token, improves readability slightly
575
+ # Let's simplify: just omit hidden tokens. Adding spaces might be complex.
576
+ pass
577
+
578
+ # Update previous state for the next iteration
579
+ previous_tokens_vis = current_generated_tokens
580
+
581
+ # Decode intermediate response (might be partial) - skip specials for readability
582
+ intermediate_response_tokens = x[0, prompt_length:]
583
+ intermediate_response_text = tokenizer.decode(
584
+ intermediate_response_tokens,
585
+ skip_special_tokens=True,
586
+ clean_up_tokenization_spaces=True
587
+ ).strip()
588
+
589
+ # Yield current state
590
+ # We yield the *current* history, the vis data for this step, and intermediate text
591
+ # The final text will overwrite the intermediate text in the UI eventually
592
+ yield history_copy, vis_data, intermediate_response_text
593
+ time.sleep(visualization_delay)
594
595
  end_time = time.time()
596
  print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")
597
 
598
+ # --- 6. Final Processing & Yield ---
599
+ final_sequence = x[0]
600
  response_tokens = final_sequence[prompt_length:]
601
 
602
  # Decode the final response text
 
606
  clean_up_tokenization_spaces=True
607
  ).strip()
608
 
609
+ # Update history with the final response *before* the last yield
610
+ history_copy[-1][1] = final_response_text
611
 
612
+ # Yield the final state (which might be the same as the last yielded state if loop finished)
613
+ # Need to format vis_data one last time based on the final `x`
614
+ final_generated_tokens = x[0, prompt_length:].cpu()
615
+ vis_data_final = []
616
  for j in range(gen_length):
617
+ current_tok_id = final_generated_tokens[j].item()
618
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
619
 
 
620
  try:
 
621
  decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
622
+ display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
623
  except Exception:
624
  display_token = f"[ID:{current_tok_id}]" # Fallback
625
 
 
626
  color = None
627
  token_to_display = display_token
628
 
629
  if current_tok_id == MASK_ID:
630
+ color = "#444444"
631
+ elif previous_tok_id == MASK_ID:
632
+ color = "#66CC66"
633
+ else:
634
+ color = "#6699CC"
635
+
636
+ should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or \
637
+ (EOS_ID is not None and current_tok_id == EOS_ID)
638
+ if should_hide and previous_tok_id == current_tok_id:
 
639
  token_to_display = ""
640
+ color = None
641
 
642
+ if token_to_display:
643
+ vis_data_final.append((token_to_display, color))
644
 
645
+ # Yield the final history, final visualization, and final text
646
+ yield history_copy, vis_data_final, final_response_text
647
+ print("Visualization streaming complete.")
648
 
 
 
649
 
650
+ except Exception as e:
651
+ print(f"Error during generation or processing: {e}")
652
+ import traceback
653
+ traceback.print_exc()
654
+ # Update history with error message? Or leave as None? Let's leave as None.
655
+ yield history_copy, [("Error during generation.", "red")], ""
656
+ return
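
The loop above unmasks roughly a 1 - s/t fraction of the still-masked positions at each step (exactly that count in the confidence-based branches, in expectation in the 'origin' branch), and whatever remains is revealed on the final step. A standalone sketch of that bookkeeping, with illustrative steps/gen_length values rather than the app's defaults:

import torch

steps, gen_length, eps = 8, 32, 1e-3                        # illustrative values
timesteps = torch.linspace(1, eps, steps + 1)
masks_left = gen_length
for i in range(steps):
    t, s = timesteps[i], timesteps[i + 1]
    reveal = int(masks_left * (1.0 - s / t)) if i < steps - 1 else masks_left
    masks_left -= reveal
    print(f"step {i}: reveal {reveal:2d} tokens, {masks_left:2d} masks left")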
657
 
658
 
659
+ # --- Gradio UI (Remains largely the same, ensures outputs match yield structure) ---
660
  css = '''
661
  .category-legend{display:none}
662
  button{min-height: 60px}
 
666
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
667
  gr.Markdown(
668
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
669
+ "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]" # Note: Link might be hypothetical
670
  )
671
 
672
  # STATE MANAGEMENT
673
+ # chat_history = gr.State([]) # Use gr.Chatbot's internal state implicitly if possible, or manage manually
674
+ # Let's manage manually with a list for clarity with yielding updates
675
+ _chat_history_store = gr.State([]) # Hidden state to store actual history list
676
 
677
  # UI COMPONENTS
678
  with gr.Row():
 
681
  label="Conversation",
682
  height=500,
683
  show_copy_button=True,
684
+ bubble_full_width=False,
685
+ # value=[] # Initialize chatbot UI empty
686
  )
687
 
688
  # Message input
 
709
  output_vis = gr.HighlightedText(
710
  label="Denoising Process Visualization",
711
  combine_adjacent=False,
712
+ show_legend=True, # Legend isn't very informative here
713
+ interactive=False # Not interactive
714
+ )
715
+ # Add a text box to display the final/intermediate response clearly
716
+ response_text_display = gr.Textbox(
717
+ label="Generated Response",
718
+ interactive=False,
719
+ lines=5 # Show a few lines
720
  )
721
 
722
+
723
  # Advanced generation settings
724
  with gr.Accordion("Generation Settings", open=False):
725
  with gr.Row():
 
734
  with gr.Row():
735
  temperature = gr.Slider(
736
  minimum=0.0, maximum=1.0, value=0.4, step=0.05,
737
+ label="Temperature (0 = greedy)"
738
  )
739
  alg_temp = gr.Slider(
740
  minimum=0.0, maximum=1.0, value=0.1, step=0.05,
741
+ label="Remasking Temp (Confidence Algs)"
742
  )
743
 
744
  with gr.Row():
745
  top_p = gr.Slider(
746
  minimum=0.0, maximum=1.0, value=0.95, step=0.05,
747
+ label="Top-P (<=0 or >=1 disables)" # Clarify disabling condition
748
  )
749
  top_k = gr.Slider(
750
  minimum=0, maximum=200, value=0, step=5,
751
+ label="Top-K (0 disables)"
752
  )
753
 
754
  with gr.Row():
 
760
 
761
  with gr.Row():
762
  visualization_delay = gr.Slider(
763
+ minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Slightly faster default
764
  label="Visualization Delay (seconds)"
765
  )
766
 
767
  # Clear button
768
  clear_btn = gr.Button("Clear Conversation")
769
 
 
 
 
770
  # --- Event Handlers ---
771
 
772
+ def add_user_message_to_history(message: str, history_store: List[List[Optional[str]]]):
773
  """Adds user message, clears input, prepares for bot response."""
774
  if not message.strip():
775
  gr.Warning("Please enter a message.")
776
+ # Return unchanged history, empty vis, empty response text
777
+ return history_store, history_store, "", [], ""
778
 
779
  # Add user message with placeholder for bot response
780
+ history_store.append([message, None])
781
+ # Return updated history store, history for chatbot UI, empty input, empty vis, empty response
782
+ return history_store, history_store, "", [], ""
 
783
 
784
  def clear_conversation():
785
+ """Clears the chat history, visualization, and response text."""
786
+ return [], [], "", [], "" # History store, chatbot UI, input, vis, response text
787
 
788
  # --- Connect UI elements ---
789
 
790
  # Define the inputs for the generation function once
791
  generation_inputs = [
792
+ _chat_history_store, gen_length, steps, constraints_input,
793
  temperature, top_p, top_k, remasking_strategy, alg_temp,
794
  visualization_delay
795
  ]
796
  # Define the outputs for the generation function
797
+ # Now yields: history_copy, vis_data, intermediate_response_text
798
+ # Map these to: chatbot_ui, output_vis, response_text_display
799
+ generation_outputs = [chatbot_ui, output_vis, response_text_display]
800
 
801
  # Handle Textbox Submission (Enter key)
802
  submit_listener = user_input.submit(
803
  fn=add_user_message_to_history,
804
+ inputs=[user_input, _chat_history_store],
805
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display] # Step 1: Add user msg & clear outputs
806
  )
807
  # Chain the bot response generation after the user message is added
808
  submit_listener.then(
809
  fn=generate_dream_response,
810
  inputs=generation_inputs,
811
+ outputs=generation_outputs, # Step 2: Generate response and stream vis/text
812
+ show_progress="hidden" # Hide default progress bar as we have live vis
813
  )
814
 
815
  # Handle Send Button Click
816
  click_listener = send_btn.click(
817
  fn=add_user_message_to_history,
818
+ inputs=[user_input, _chat_history_store],
819
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display] # Step 1: Add user msg & clear outputs
820
  )
821
  # Chain the bot response generation after the user message is added
822
  click_listener.then(
823
  fn=generate_dream_response,
824
  inputs=generation_inputs,
825
+ outputs=generation_outputs, # Step 2: Generate response and stream vis/text
826
+ show_progress="hidden"
827
  )
828
 
829
+ # Clear Button Action
830
  clear_btn.click(
831
  clear_conversation,
832
  inputs=[],
833
+ outputs=[_chat_history_store, chatbot_ui, user_input, output_vis, response_text_display]
834
  )
835
 
836
  return demo
 
839
  if __name__ == "__main__":
840
  demo = create_chatbot_demo()
841
  # Use queue for handling multiple users and streaming
842
+ demo.queue().launch(debug=True, share=False) # Set share=True for public link