multimodalart (HF Staff) committed
Commit 0d2292c · verified · 1 Parent(s): 6c9bbe6

Update app.py

Files changed (1):
  1. app.py +203 -166
app.py CHANGED
@@ -1,4 +1,4 @@
- # llada_app.py -> dream_app.py

  import torch
  import numpy as np
@@ -32,21 +32,32 @@ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
  print("Model and tokenizer loaded.")

  # --- Constants for DREAM ---
- # Find the mask token and ID from the DREAM tokenizer
  if tokenizer.mask_token is None:
-     # Handle cases where a mask token might not be explicitly set
-     # You might need to choose a suitable placeholder or investigate further
-     # For now, let's try adding one if it's missing and check its id
-     # This is speculative and might depend on the specific tokenizer setup
-     print("Warning: Mask token not found in tokenizer. Attempting to add.")
      tokenizer.add_special_tokens({'mask_token': '[MASK]'})
      model.resize_token_embeddings(len(tokenizer)) # Important if vocab size changed
-     if tokenizer.mask_token is None:
-         raise ValueError("Could not set a mask token for the tokenizer.")

  MASK_TOKEN = tokenizer.mask_token
  MASK_ID = tokenizer.mask_token_id
  print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")

  # --- Helper Functions (Constraint Parsing, History Formatting) ---

  def parse_constraints(constraints_text):
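The body of parse_constraints falls outside the hunks shown here. For orientation, a minimal sketch of a parser for the 'pos:word, pos:word' format that the UI documents below; this is hypothetical and the committed implementation may differ:

    def parse_constraints(constraints_text):
        """Parse 'pos:word, pos:word' into {position: word}; positions are 0-indexed."""
        constraints = {}
        if not constraints_text or not constraints_text.strip():
            return constraints
        for part in constraints_text.split(","):
            part = part.strip()
            if ":" not in part:
                continue  # Skip malformed entries
            pos_str, word = part.split(":", 1)
            try:
                pos = int(pos_str.strip())
            except ValueError:
                continue  # Skip non-integer positions
            if pos >= 0 and word.strip():
                constraints[pos] = word.strip()
        return constraints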
@@ -136,6 +147,7 @@ def dream_generate_response_with_visualization(
      print("Processing constraints:")
      for pos, word in constraints.items():
          # Prepend space for consistent tokenization, similar to LLaDA example
          tokens = tokenizer.encode(" " + word, add_special_tokens=False)
          if not tokens:
              print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
@@ -149,7 +161,6 @@ def dream_generate_response_with_visualization(
              print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")

      # Prepare the prompt using chat template
-     # Note: DREAM examples use add_generation_prompt=True
      try:
          inputs = tokenizer.apply_chat_template(
              messages,
@@ -161,17 +172,10 @@ def dream_generate_response_with_visualization(
          attention_mask = inputs.attention_mask.to(device=device) # Get attention mask
          prompt_length = input_ids.shape[1]
          print(f"Input prompt length: {prompt_length}")
-         print(f"Input IDs: {input_ids}")
      except Exception as e:
          print(f"Error applying chat template: {e}")
-         # Provide a fallback or raise the error
-         # Fallback: Simple concatenation (less ideal for instruction models)
-         # chat_input = "".join([f"{msg['role']}: {msg['content']}\n" for msg in messages]) + "assistant:"
-         # input_ids = tokenizer(chat_input, return_tensors="pt").input_ids.to(device)
-         # attention_mask = torch.ones_like(input_ids)
-         # prompt_length = input_ids.shape[1]
-         # print(f"Warning: Using basic concatenation due to template error. Prompt length: {prompt_length}")
-         return [([("Error applying chat template.", "red")],)], f"Error: {e}"

      if prompt_length + gen_length > 2048: # Check context length (DREAM uses 2048)
@@ -179,7 +183,7 @@ def dream_generate_response_with_visualization(
          gen_length = 2048 - prompt_length
          if gen_length <= 0:
              print("Error: Prompt is already too long.")
-             return [([("Prompt too long.", "red")],)], "Error: Prompt too long."

      # --- State for Visualization Hook ---
@@ -192,74 +196,80 @@ def dream_generate_response_with_visualization(
      for pos, token_id in processed_constraints.items():
          absolute_pos = pos # Position relative to start of generation
          if 0 <= absolute_pos < gen_length:
-             initial_x_part[0, absolute_pos] = token_id
-
-     initial_state_vis = []
-     for i in range(gen_length):
-         token_id = initial_x_part[0, i].item()
-         if token_id == MASK_ID:
-             initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
-         else:
-             # This must be a constraint applied initially
-             token_str = tokenizer.decode([token_id], skip_special_tokens=True)
-             initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
-     visualization_states.append(initial_state_vis)

      # --- Define the Hook Function ---
      def generation_tokens_hook_func(step, x, logits):
          nonlocal last_x, visualization_states # Allow modification of outer scope variables
-         print(f"Hook called for step {step}")

-         current_x = x.clone() # Work on a copy for comparison

-         # 1. Apply Constraints *before* generating visualization
          # Constraints are relative to the start of the *generated* part
          constrained_x = current_x.clone()
-         prompt_len = current_x.shape[1] - gen_length # Recalculate just in case
-         if prompt_len < 0:
              print("Warning: prompt_len negative in hook, skipping constraints/vis.")
              return current_x # Return unmodified if something is wrong

-         constraints_applied_this_step = False
          for pos, token_id in processed_constraints.items():
-             absolute_pos = prompt_len + pos
-             if prompt_len <= absolute_pos < current_x.shape[1]:
                  if constrained_x[0, absolute_pos] != token_id:
                      constrained_x[0, absolute_pos] = token_id
-                     constraints_applied_this_step = True
                      # print(f" Constraint applied at pos {pos} ({absolute_pos}) -> token {token_id}")

          # 2. Generate Visualization State for *this* step
          current_state_vis = []
-         # Compare current_x (before explicit constraint application in *this* hook call)
-         # with last_x (state from *previous* hook call / initial state)
-         # Generate based on the state *before* reapplying constraints here,
-         # but *after* the model's diffusion step determined current_x.
-         gen_part_current = current_x[0, prompt_len:]
-         gen_part_last = last_x[0, prompt_len:] if last_x is not None else None

          for i in range(gen_length):
              current_token_id = gen_part_current[i].item()
-             token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
-             # Use a placeholder if decoding results in empty string
-             display_token = token_str if token_str else MASK_TOKEN if current_token_id == MASK_ID else "?"

-             # Check if this position is constrained
              is_constrained = i in processed_constraints

              if current_token_id == MASK_ID:
-                 color = "#444444" # Dark Gray for masks
              elif is_constrained and processed_constraints[i] == current_token_id:
-                 color = "#800080" # Purple for correctly constrained tokens
-             elif gen_part_last is None or gen_part_last[i].item() == MASK_ID:
-                 # Newly revealed (was mask in previous step or initial state)
-                 color = "#66CC66" # Light Green
-             else:
-                 # Previously revealed and not masked
-                 color = "#6699CC" # Light Blue

-             current_state_vis.append((display_token, color))

          visualization_states.append(current_state_vis)
@@ -268,7 +278,6 @@ def dream_generate_response_with_visualization(
          last_x = constrained_x.clone()

          # 4. Return the sequence with constraints applied for the model's next step
-         # print(f"Hook returning constrained_x: {constrained_x[:, prompt_len:]}")
          return constrained_x # Return the sequence with constraints enforced

@@ -277,8 +286,30 @@ def dream_generate_response_with_visualization(
          print("Calling model.diffusion_generate...")
          # Make sure last_x is initialized correctly before the first hook call
          # It should represent the state *before* the first diffusion step.
          initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
-         last_x = initial_full_x.clone()

          output = model.diffusion_generate(
              input_ids,
@@ -296,63 +327,55 @@ def dream_generate_response_with_visualization(
          print("model.diffusion_generate finished.")

          # Extract final generated sequence (response part only)
-         # The hook ensures the returned sequence has constraints applied
          final_sequence = output.sequences[0]
          response_token_ids = final_sequence[prompt_length:]

-         # Decode the final response
          final_text = tokenizer.decode(
              response_token_ids,
              skip_special_tokens=True,
-             clean_up_tokenization_spaces=True # Recommended for cleaner output
          ).strip()
          print(f"Final generated text: {final_text}")

-         # Add the very final state to visualization if the hook didn't capture it
-         # (Should be captured, but as a safeguard)
-         if len(visualization_states) <= steps: # Hook might run 'steps' times
-             final_state_vis = []
-             final_gen_part = final_sequence[prompt_length:]
-             for i in range(gen_length):
-                 token_id = final_gen_part[i].item()
-                 token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
-                 display_token = token_str if token_str else MASK_TOKEN if token_id == MASK_ID else "?"
-                 is_constrained = i in processed_constraints
-
-                 if token_id == MASK_ID: color = "#444444"
-                 elif is_constrained and processed_constraints[i] == token_id: color = "#800080"
-                 else: color = "#6699CC" # Default to blue for final state tokens
-                 final_state_vis.append((display_token, color))
-             visualization_states.append(final_state_vis)

      except Exception as e:
          print(f"Error during generation: {e}")
          import traceback
          traceback.print_exc()
-         # Add error message to visualization
          error_msg = f"Error during generation: {str(e)}"
-         visualization_states.append([("Error", "red")])
          final_text = f"Generation failed: {e}"

      print("--- DREAM Generation Finished ---")
      return visualization_states, final_text

  # --- Gradio UI Setup ---

  css = '''
- .category-legend{display:none}
- /* button{height: 60px} */ /* Optional: Adjust button height */
  .small_btn {
      max-width: 100px; /* Adjust as needed */
      height: 40px; /* Adjust as needed */
-     flex-grow: 0; /* Prevent button from growing */
-     margin-left: 5px; /* Add some space */
  }
  .chat-input-row {
      display: flex;
      align-items: center; /* Vertically align items */
  }
  .chat-input-row > * {
      margin-right: 5px; /* Space between textbox and button */
@@ -360,7 +383,43 @@ css = '''
  .chat-input-row > *:last-child {
      margin-right: 0;
  }
  '''

  def create_chatbot_demo():
      with gr.Blocks(css=css) as demo:
          gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
@@ -376,7 +435,7 @@ def create_chatbot_demo():
          chatbot_ui = gr.Chatbot(
              label="Conversation",
              height=500,
-             bubble_full_width=False # Improves layout for shorter messages
          )

@@ -384,8 +443,8 @@ def create_chatbot_demo():
          user_input = gr.Textbox(
              label="Your Message",
              placeholder="Type your message here and press Enter...",
-             scale=4, # Give textbox more space
-             container=False, # Remove container background/padding
              show_label=False
          )
          send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
@@ -394,41 +453,34 @@ def create_chatbot_demo():
              label="Word Constraints (Optional)",
              info="Force specific words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
              placeholder="e.g., 0:Hello, 6:world",
-             value="" # Default empty
          )
          with gr.Column(scale=2):
              output_vis = gr.HighlightedText(
                  label="Denoising Process Visualization",
-                 combine_adjacent=False,
-                 show_legend=False, # Keep legend off as requested
-                 # Color map for legend (though hidden)
-                 # color_map={
-                 #     "Mask": "#444444",
-                 #     "New": "#66CC66",
-                 #     "Old": "#6699CC",
-                 #     "Constraint": "#800080",
-                 #     "Error": "red"
-                 # }
-             )
-             gr.Markdown(
-                 "**Color Legend:** <span style='color:#444444'>■ Mask</span> | <span style='color:#66CC66'>■ Newly Generated</span> | <span style='color:#6699CC'>■ Previously Generated</span> | <span style='color:#800080'>■ Constraint</span>"
              )

          # Advanced generation settings
          with gr.Accordion("Generation Settings", open=False):
              with gr.Row():
                  gen_length = gr.Slider(
-                     minimum=16, maximum=512, value=128, step=8, # Increased max length
                      label="Max New Tokens"
                  )
                  steps = gr.Slider(
-                     minimum=8, maximum=512, value=128, step=8, # Increased max steps
                      label="Diffusion Steps"
                  )
              with gr.Row():
                  temperature = gr.Slider(
-                     minimum=0.0, maximum=1.5, value=0.6, step=0.05, # Wider range for temp
                      label="Temperature"
                  )
                  top_p = gr.Slider(
@@ -436,15 +488,14 @@ def create_chatbot_demo():
                      label="Top-P (Nucleus Sampling)"
                  )
              with gr.Row():
-                 # Map UI choices to DREAM's alg parameters
                  remasking_strategy = gr.Radio(
                      choices=[
-                         ("Random", "origin"), # User friendly name -> actual param
                          ("Entropy", "entropy"),
                          ("MaskGit+", "maskgit_plus"),
                          ("TopK Margin", "topk_margin"),
                      ],
-                     value="entropy", # Default
                      label="Generation Order Strategy (alg)"
                  )
                  alg_temp = gr.Slider(
@@ -462,9 +513,6 @@ def create_chatbot_demo():
          # Clear button
          clear_btn = gr.Button("Clear Conversation")

-         # Hidden textbox to potentially store intermediate response (might not be needed)
-         # current_response = gr.Textbox(visible=False)
-
          # --- Event Handlers ---

          # Helper to add message to history state
@@ -478,22 +526,12 @@ def create_chatbot_demo():
              print(f"User submitted: '{message}'")
              if not message or not message.strip():
                  print("Empty message submitted, doing nothing.")
-                 # Return unchanged state if message is empty
-                 # Need to return values for all outputs of the .submit/.click
                  return history, history, "", [] # history, chatbot_ui, user_input, output_vis

-             # Add user message to history (with None for bot response initially)
              history = add_message_to_history(history, message, None)
-
-             # Prepare updated history for display in Chatbot UI
              history_for_display = history.copy()
-
-             # Clear the input textbox
              message_out = ""
-             # Clear the visualization
-             vis_clear = []
-
-             # Return updated history state, chatbot display, cleared input, cleared visualization
              return history, history_for_display, message_out, vis_clear

          # Function to generate bot response (triggered after user message is processed)
@@ -504,18 +542,13 @@ def create_chatbot_demo():
              print("--- Generating Bot Response ---")
              if not history or history[-1][1] is not None:
                  print("History empty or last message already has response. Skipping generation.")
-                 # Yield current state if called unnecessarily
-                 yield history, [], "No response generated."
                  return

-             # Get the conversation history in the format the model expects
-             messages = format_chat_history(history) # Includes the latest user query
-
-             # Parse constraints from the textbox
              parsed_constraints = parse_constraints(constraints_text)

              try:
-                 # Generate response with visualization
                  vis_states, response_text = dream_generate_response_with_visualization(
                      messages,
                      gen_length=gen_length,
@@ -527,31 +560,30 @@ def create_chatbot_demo():
                      alg_temp=alg_temp
                  )

-                 # Update the history state with the final bot response
-                 history[-1][1] = response_text.strip()
-
-                 # Yield the initial visualization state immediately
-                 if vis_states:
-                     yield history, vis_states[0] # Update chatbot, update visualization
-                 else:
-                     # Handle case where generation failed before first state
-                     yield history, [("Generation failed.", "red")]

-                 # Then animate through the rest of the visualization states
-                 for state in vis_states[1:]:
-                     time.sleep(delay)
-                     yield history, state # Update chatbot (implicitly via history), update visualization

              except Exception as e:
                  print(f"Error in bot_response_generator: {e}")
                  import traceback
                  traceback.print_exc()
                  error_msg = f"Error: {str(e)}"
-                 # Show error in visualization
-                 error_vis = [(error_msg, "red")]
-                 # Update history with error message? Optional.
-                 # history[-1][1] = error_msg
-                 yield history, error_vis

          # Function to clear everything
          def clear_conversation():
@@ -561,34 +593,39 @@ def create_chatbot_demo():
          # --- Wire UI elements to functions ---

          # Typing in Textbox and pressing Enter
-         user_input.submit(
              fn=user_message_submitted,
              inputs=[user_input, chat_history],
-             outputs=[chat_history, chatbot_ui, user_input, output_vis], # Update history state, chatbot display, clear input, clear vis
-             queue=False # Process immediately
-         ).then(
-             fn=bot_response_generator,
-             inputs=[
-                 chat_history, gen_length, steps, constraints_input, visualization_delay,
-                 temperature, top_p, remasking_strategy, alg_temp
-             ],
-             outputs=[chatbot_ui, output_vis] # Update chatbot display (with new response), update visualization
-             # Note: history state is updated implicitly by bot_response_generator modifying its input
          )

          # Clicking the Send button
-         send_btn.click(
              fn=user_message_submitted,
              inputs=[user_input, chat_history],
              outputs=[chat_history, chatbot_ui, user_input, output_vis],
              queue=False
-         ).then(
-             fn=bot_response_generator,
-             inputs=[
                  chat_history, gen_length, steps, constraints_input, visualization_delay,
                  temperature, top_p, remasking_strategy, alg_temp
-             ],
-             outputs=[chatbot_ui, output_vis]
          )

          # Clicking the Clear button
 
+ # dream_app.py (Updated)

  import torch
  import numpy as np
 
  print("Model and tokenizer loaded.")

  # --- Constants for DREAM ---
  if tokenizer.mask_token is None:
+     print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
      tokenizer.add_special_tokens({'mask_token': '[MASK]'})
      model.resize_token_embeddings(len(tokenizer)) # Important if vocab size changed
+     if tokenizer.mask_token is None or tokenizer.mask_token_id is None:
+         raise ValueError("Could not set or find ID for a mask token for the tokenizer.")

  MASK_TOKEN = tokenizer.mask_token
  MASK_ID = tokenizer.mask_token_id
+ EOS_TOKEN = tokenizer.eos_token # Get EOS token string
+ EOS_ID = tokenizer.eos_token_id # Get EOS token ID
+ # Add other special tokens if needed for visualization
+ SPECIAL_TOKENS_MAP = {
+     tokenizer.eos_token_id: "[EOS]",
+     tokenizer.bos_token_id: "[BOS]",
+     tokenizer.pad_token_id: "[PAD]",
+     tokenizer.unk_token_id: "[UNK]",
+     MASK_ID: MASK_TOKEN # Map mask ID back to its string representation
+ }
+ # Add None key to handle cases where token IDs might be None (shouldn't happen with tensors)
+ SPECIAL_TOKENS_MAP[None] = "[NONE]"

  print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
+ print(f"Using EOS_TOKEN='{EOS_TOKEN}' with ID={EOS_ID}")

  # --- Helper Functions (Constraint Parsing, History Formatting) ---

  def parse_constraints(constraints_text):
 
      print("Processing constraints:")
      for pos, word in constraints.items():
          # Prepend space for consistent tokenization, similar to LLaDA example
+         # Important: use add_special_tokens=False for constraints
          tokens = tokenizer.encode(" " + word, add_special_tokens=False)
          if not tokens:
              print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
 
              print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")

      # Prepare the prompt using chat template
      try:
          inputs = tokenizer.apply_chat_template(
              messages,

          attention_mask = inputs.attention_mask.to(device=device) # Get attention mask
          prompt_length = input_ids.shape[1]
          print(f"Input prompt length: {prompt_length}")
+         # print(f"Input IDs: {input_ids}") # Keep commented unless debugging
      except Exception as e:
          print(f"Error applying chat template: {e}")
+         return [([("Error applying chat template.", "Error")],)], f"Error: {e}"

      if prompt_length + gen_length > 2048: # Check context length (DREAM uses 2048)
          gen_length = 2048 - prompt_length
          if gen_length <= 0:
              print("Error: Prompt is already too long.")
+             return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."

      # --- State for Visualization Hook ---
 
      for pos, token_id in processed_constraints.items():
          absolute_pos = pos # Position relative to start of generation
          if 0 <= absolute_pos < gen_length:
+             # Check if the constraint token itself is special
+             if token_id in SPECIAL_TOKENS_MAP:
+                 print(f" Note: Constraint at pos {pos} is a special token: {SPECIAL_TOKENS_MAP[token_id]}")
+             initial_x_part[0, absolute_pos] = token_id

      # --- Define the Hook Function ---
+     # This function will be called at each diffusion step
      def generation_tokens_hook_func(step, x, logits):
          nonlocal last_x, visualization_states # Allow modification of outer scope variables
+         # print(f"Hook called for step {step}") # Keep commented unless debugging

+         current_x = x.clone() # Work on a copy for comparison/modification

+         # 1. Apply Constraints *before* generating visualization for this step
          # Constraints are relative to the start of the *generated* part
          constrained_x = current_x.clone()
+         current_prompt_len = current_x.shape[1] - gen_length # Recalculate actual prompt length
+         if current_prompt_len < 0:
              print("Warning: prompt_len negative in hook, skipping constraints/vis.")
              return current_x # Return unmodified if something is wrong

          for pos, token_id in processed_constraints.items():
+             absolute_pos = current_prompt_len + pos
+             if current_prompt_len <= absolute_pos < current_x.shape[1]:
+                 # Apply constraint if the current token doesn't match
                  if constrained_x[0, absolute_pos] != token_id:
                      constrained_x[0, absolute_pos] = token_id
                      # print(f" Constraint applied at pos {pos} ({absolute_pos}) -> token {token_id}")

          # 2. Generate Visualization State for *this* step
+         # Compare current_x (output of diffusion for this step, before constraints applied *in this call*)
+         # with last_x (state from *previous* hook call / initial state, *after* constraints were applied then)
          current_state_vis = []
+         gen_part_current = current_x[0, current_prompt_len:]
+         gen_part_last = last_x[0, current_prompt_len:] if last_x is not None else None

          for i in range(gen_length):
              current_token_id = gen_part_current[i].item()
+             last_token_id = gen_part_last[i].item() if gen_part_last is not None else MASK_ID # Assume mask initially

+             # Determine display string - Handle special tokens explicitly
+             if current_token_id in SPECIAL_TOKENS_MAP:
+                 display_token = SPECIAL_TOKENS_MAP[current_token_id]
+             else:
+                 # Decode non-special tokens, skipping special tokens in the *output string*
+                 # and stripping whitespace
+                 display_token = tokenizer.decode([current_token_id],
+                                                  skip_special_tokens=True,
+                                                  clean_up_tokenization_spaces=True).strip()
+                 # If decoding results in empty string for a non-special token, use a space perhaps
+                 if not display_token:
+                     display_token = " " # Use a single space as placeholder

+             # Determine category (label) for color mapping
+             category = "Old" # Default assume it was revealed before
              is_constrained = i in processed_constraints

              if current_token_id == MASK_ID:
+                 category = "Mask"
              elif is_constrained and processed_constraints[i] == current_token_id:
+                 # Check if it was *just* constrained or already was correct
+                 # We mark as 'Constraint' if it matches the required token, regardless of when it appeared
+                 category = "Constraint"
+             elif last_token_id == MASK_ID and current_token_id != MASK_ID:
+                 # It was a mask before, now it's not -> Newly revealed
+                 # (Unless it's a constraint, handled above)
+                 category = "New"
+             # else: category remains "Old"

+             current_state_vis.append((display_token, category))

          visualization_states.append(current_state_vis)

          last_x = constrained_x.clone()

          # 4. Return the sequence with constraints applied for the model's next step
          return constrained_x # Return the sequence with constraints enforced
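The per-token category logic inside the hook reduces to a small pure function, which is handy for unit testing; a sketch using the same category names:

    def categorize_token(current_id, last_id, position, constraints, mask_id):
        """Return the visualization category for one generated position."""
        if current_id == mask_id:
            return "Mask"
        if position in constraints and constraints[position] == current_id:
            return "Constraint"  # Matches the required token, whenever it appeared
        if last_id == mask_id:
            return "New"  # Was masked in the previous state, now revealed
        return "Old"  # Already revealed in an earlier step

    assert categorize_token(7, 0, 0, {}, mask_id=0) == "New"
    assert categorize_token(7, 7, 0, {}, mask_id=0) == "Old"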

          print("Calling model.diffusion_generate...")
          # Make sure last_x is initialized correctly before the first hook call
          # It should represent the state *before* the first diffusion step.
+         # Create the initial full sequence (prompt + initial masked/constrained part)
          initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
+         last_x = initial_full_x.clone() # Initialize last_x with the state before step 0
+
+         # Add the very first visualization state (prompt + initial masks/constraints)
+         # This state corresponds to the `last_x` *before* the first hook call.
+         initial_state_vis = []
+         initial_gen_part = initial_full_x[0, prompt_length:]
+         for i in range(gen_length):
+             token_id = initial_gen_part[i].item()
+             category = "Mask"
+             display_token = MASK_TOKEN
+             if token_id != MASK_ID:
+                 # This must be an initial constraint
+                 category = "Constraint"
+                 if token_id in SPECIAL_TOKENS_MAP:
+                     display_token = SPECIAL_TOKENS_MAP[token_id]
+                 else:
+                     display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
+                     if not display_token: display_token = " " # Placeholder
+
+             initial_state_vis.append((display_token, category))
+         visualization_states.append(initial_state_vis)

          output = model.diffusion_generate(
              input_ids,
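The keyword arguments of this call are cut off by the hunk. Going by the published DREAM usage examples, the full call presumably looks roughly like the following; treat the argument list as an assumption rather than the exact committed code:

    output = model.diffusion_generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=gen_length,
        output_history=True,            # Keep intermediate states
        return_dict_in_generate=True,   # So output.sequences exists, as used below
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        alg=alg,                        # "origin" / "entropy" / "maskgit_plus" / "topk_margin"
        alg_temp=alg_temp,
        generation_tokens_hook_func=generation_tokens_hook_func,
    )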
 
          print("model.diffusion_generate finished.")

          # Extract final generated sequence (response part only)
          final_sequence = output.sequences[0]
          response_token_ids = final_sequence[prompt_length:]

+         # Decode the final response, skipping special tokens for the final output text
          final_text = tokenizer.decode(
              response_token_ids,
              skip_special_tokens=True,
+             clean_up_tokenization_spaces=True
          ).strip()
          print(f"Final generated text: {final_text}")

+         # The hook should have added the last state, no need for safeguard typically

      except Exception as e:
          print(f"Error during generation: {e}")
          import traceback
          traceback.print_exc()
+         # Add error message to visualization using the "Error" category
          error_msg = f"Error during generation: {str(e)}"
+         visualization_states.append([("Error", "Error")]) # Use 'Error' category
          final_text = f"Generation failed: {e}"

      print("--- DREAM Generation Finished ---")
+     # Return states list (already built by hook) and final text
      return visualization_states, final_text

  # --- Gradio UI Setup ---

  css = '''
+ /* Hide the default legend */
+ .gradio-container .output-markdown table { display: none !important; }
+
  .small_btn {
      max-width: 100px; /* Adjust as needed */
+     min-width: 60px; /* Ensure button doesn't collapse */
      height: 40px; /* Adjust as needed */
+     flex-grow: 0 !important; /* Prevent button from growing */
+     margin-left: 5px !important; /* Add some space */
+     margin-top: auto; /* Align button bottom with textbox */
+     margin-bottom: auto; /* Align button bottom with textbox */
+     line-height: 1; /* Adjust line height if text vertical align is off */
+     padding: 0 10px; /* Adjust padding */
  }
  .chat-input-row {
      display: flex;
      align-items: center; /* Vertically align items */
+     margin-bottom: 10px; /* Add space below input row */
  }
  .chat-input-row > * {
      margin-right: 5px; /* Space between textbox and button */

  .chat-input-row > *:last-child {
      margin-right: 0;
  }
+ /* Style HighlightedText elements */
+ .token-hl span {
+     padding: 2px 1px; /* Minimal padding */
+     margin: 0 1px; /* Minimal margin */
+     border-radius: 3px;
+     display: inline-block; /* Ensure background covers token */
+     line-height: 1.2; /* Adjust for better vertical spacing */
+ }
+ /* Custom legend styling */
+ .custom-legend span {
+     display: inline-block;
+     margin-right: 15px;
+     font-size: 0.9em;
+ }
+ .custom-legend span::before {
+     content: "■";
+     margin-right: 4px;
+     font-size: 1.1em; /* Make square slightly larger */
+     vertical-align: middle; /* Align square with text */
+ }
  '''
+ # Define color map mapping CATEGORY names to colors
+ color_map = {
+     "Mask": "#A0A0A0", # Darker Gray for masks
+     "New": "#77DD77", # Light Green for new tokens
+     "Old": "#AEC6CF", # Light Blue/Gray for old tokens
+     "Constraint": "#C3A0E0", # Purple for constraints
+     "Error": "#FF6961" # Light Red for errors
+ }
+
+ # Create the custom legend HTML string
+ legend_html = "<div class='custom-legend'>"
+ for category, color in color_map.items():
+     legend_html += f"<span style='color:{color};'>{category}</span>"
+ legend_html += "</div>"
422
+
423
  def create_chatbot_demo():
424
  with gr.Blocks(css=css) as demo:
425
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
 
435
  chatbot_ui = gr.Chatbot(
436
  label="Conversation",
437
  height=500,
438
+ bubble_full_width=False
439
  )
440
 
441
  # Message input Row
 
443
  user_input = gr.Textbox(
444
  label="Your Message",
445
  placeholder="Type your message here and press Enter...",
446
+ scale=4,
447
+ container=False,
448
  show_label=False
449
  )
450
  send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
 
453
  label="Word Constraints (Optional)",
454
  info="Force specific words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
455
  placeholder="e.g., 0:Hello, 6:world",
456
+ value=""
457
  )
458
  with gr.Column(scale=2):
459
  output_vis = gr.HighlightedText(
460
  label="Denoising Process Visualization",
461
+ combine_adjacent=False, # Keep tokens separate
462
+ show_legend=False, # Hide default legend table
463
+ color_map=color_map, # Provide the color map
464
+ elem_classes="token-hl" # Add class for token styling
 
 
 
 
 
 
 
 
 
465
  )
466
+ # Use Markdown to display the custom legend
467
+ gr.Markdown(legend_html)

          # Advanced generation settings
          with gr.Accordion("Generation Settings", open=False):
              with gr.Row():
                  gen_length = gr.Slider(
+                     minimum=16, maximum=512, value=128, step=8,
                      label="Max New Tokens"
                  )
                  steps = gr.Slider(
+                     minimum=8, maximum=512, value=128, step=8,
                      label="Diffusion Steps"
                  )
              with gr.Row():
                  temperature = gr.Slider(
+                     minimum=0.0, maximum=1.5, value=0.6, step=0.05,
                      label="Temperature"
                  )
                  top_p = gr.Slider(

                      label="Top-P (Nucleus Sampling)"
                  )
              with gr.Row():
                  remasking_strategy = gr.Radio(
                      choices=[
+                         ("Random", "origin"),
                          ("Entropy", "entropy"),
                          ("MaskGit+", "maskgit_plus"),
                          ("TopK Margin", "topk_margin"),
                      ],
+                     value="entropy",
                      label="Generation Order Strategy (alg)"
                  )
                  alg_temp = gr.Slider(

          # Clear button
          clear_btn = gr.Button("Clear Conversation")

          # --- Event Handlers ---

          # Helper to add message to history state
 
              print(f"User submitted: '{message}'")
              if not message or not message.strip():
                  print("Empty message submitted, doing nothing.")
                  return history, history, "", [] # history, chatbot_ui, user_input, output_vis

              history = add_message_to_history(history, message, None)
              history_for_display = history.copy()
              message_out = ""
+             vis_clear = [] # Clear visualization when new message submitted
              return history, history_for_display, message_out, vis_clear

          # Function to generate bot response (triggered after user message is processed)
 
              print("--- Generating Bot Response ---")
              if not history or history[-1][1] is not None:
                  print("History empty or last message already has response. Skipping generation.")
+                 yield history, [], "No response generated." # Yield current state if called unnecessarily
                  return

+             messages = format_chat_history(history)
              parsed_constraints = parse_constraints(constraints_text)

              try:
                  vis_states, response_text = dream_generate_response_with_visualization(
                      messages,
                      gen_length=gen_length,

                      alg_temp=alg_temp
                  )

+                 # Update the history state only ONCE with the final bot response
+                 final_history = history.copy() # Create copy to modify
+                 final_history[-1][1] = response_text.strip() # Update the last element

+                 # Yield visualization states one by one
+                 # Important: Yield the *original* history for all intermediate steps,
+                 # only yield the final_history with the *last* visualization state.
+                 num_states = len(vis_states)
+                 for i, state in enumerate(vis_states):
+                     current_chatbot_state = history if i < num_states - 1 else final_history
+                     yield current_chatbot_state, state
+                     if delay > 0 and i < num_states - 1: # Don't sleep after last state
+                         time.sleep(delay)

              except Exception as e:
                  print(f"Error in bot_response_generator: {e}")
                  import traceback
                  traceback.print_exc()
                  error_msg = f"Error: {str(e)}"
+                 error_vis = [(error_msg, "Error")] # Use Error category
+                 # Update history with error message? Optional.
+                 final_history_error = history.copy()
+                 final_history_error[-1][1] = error_msg # Add error to chatbot too
+                 yield final_history_error, error_vis
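bot_response_generator is a Python generator: every yield pushes a fresh (chatbot, visualization) pair to the bound outputs, which is what animates the denoising in the UI. The streaming pattern in isolation, as a sketch:

    import time

    def stream_states(history, states, delay):
        # Each yield updates the two bound outputs (chatbot_ui, output_vis).
        for i, state in enumerate(states):
            yield history, state
            if delay > 0 and i < len(states) - 1:
                time.sleep(delay)  # Pace the animation; no sleep after the last frame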

          # Function to clear everything
          def clear_conversation():
 
          # --- Wire UI elements to functions ---

          # Typing in Textbox and pressing Enter
+         submit_event = user_input.submit(
              fn=user_message_submitted,
              inputs=[user_input, chat_history],
+             outputs=[chat_history, chatbot_ui, user_input, output_vis],
+             queue=False # Show user message immediately
          )

          # Clicking the Send button
+         click_event = send_btn.click(
              fn=user_message_submitted,
              inputs=[user_input, chat_history],
              outputs=[chat_history, chatbot_ui, user_input, output_vis],
              queue=False
+         )
+
+         # Chain the generation after user message is processed (for both submit and click)
+         # Use .then() to trigger the generator
+         generation_inputs = [
              chat_history, gen_length, steps, constraints_input, visualization_delay,
              temperature, top_p, remasking_strategy, alg_temp
+         ]
+         generation_outputs = [chatbot_ui, output_vis]
+
+         submit_event.then(
+             fn=bot_response_generator,
+             inputs=generation_inputs,
+             outputs=generation_outputs
+         )
+
+         click_event.then(
+             fn=bot_response_generator,
+             inputs=generation_inputs,
+             outputs=generation_outputs
          )

          # Clicking the Clear button
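The diff is truncated before the clear-button wiring and the launch code. Note that generator handlers only stream when Gradio's queue is enabled, so the file presumably ends along these lines; this is a sketch of the expected tail, not the committed code:

    # ...inside create_chatbot_demo(), after the other event wiring:
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )
        return demo

    # Module level:
    demo = create_chatbot_demo()
    demo.queue().launch()  # Queue is required for streaming generator outputs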