multimodalart (HF Staff) committed on
Commit 69595ed · verified · 1 Parent(s): b861a35

Update app.py

Files changed (1)
  1. app.py +195 -360
app.py CHANGED
@@ -1,4 +1,4 @@
-# dream_app.py (Updated)
 
 import torch
 import numpy as np
@@ -11,15 +11,11 @@ import re # Keep for parsing constraints
 
 # Use try-except for space deployment vs local
 try:
-    # Used for spaces deployment with GPU
     gpu_check = spaces.GPU
     print("Running in Gradio Spaces with GPU environment.")
 except AttributeError:
-    # Fallback for local execution or environments without spaces.GPU
     print("Running in local environment or without spaces.GPU.")
-    # Define a dummy decorator if spaces.GPU is not available
-    def gpu_check(func):
-        return func
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(f"Using device: {device}")
@@ -27,39 +23,50 @@ print(f"Using device: {device}")
 # --- Load DREAM Model and Tokenizer ---
 model_path = "Dream-org/Dream-v0-Instruct-7B"
 print(f"Loading model: {model_path}")
-model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-print("Model and tokenizer loaded.")
 
 # --- Constants for DREAM ---
 if tokenizer.mask_token is None:
-    print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
-    tokenizer.add_special_tokens({'mask_token': '[MASK]'})
-    model.resize_token_embeddings(len(tokenizer)) # Important if vocab size changed
-    if tokenizer.mask_token is None or tokenizer.mask_token_id is None:
-        raise ValueError("Could not set or find ID for a mask token for the tokenizer.")
 
 MASK_TOKEN = tokenizer.mask_token
 MASK_ID = tokenizer.mask_token_id
-EOS_TOKEN = tokenizer.eos_token # Get EOS token string
-EOS_ID = tokenizer.eos_token_id # Get EOS token ID
-# Add other special tokens if needed for visualization
-SPECIAL_TOKENS_MAP = {
-    tokenizer.eos_token_id: "[EOS]",
-    tokenizer.bos_token_id: "[BOS]",
-    tokenizer.pad_token_id: "[PAD]",
-    tokenizer.unk_token_id: "[UNK]",
-    MASK_ID: MASK_TOKEN # Map mask ID back to its string representation
-}
-# Add None key to handle cases where token IDs might be None (shouldn't happen with tensors)
-SPECIAL_TOKENS_MAP[None] = "[NONE]"
 
-print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
-print(f"Using EOS_TOKEN='{EOS_TOKEN}' with ID={EOS_ID}")
 
 # --- Helper Functions (Constraint Parsing, History Formatting) ---
-
 def parse_constraints(constraints_text):
     """Parse constraints in format: 'position:word, position:word, ...'"""
     constraints = {}
@@ -107,230 +114,182 @@ def format_chat_history(history):
 
 # --- Core Generation Logic for DREAM with Visualization ---
 
-@gpu_check # Use the potentially dummy decorator
 def dream_generate_response_with_visualization(
     messages,
     gen_length=64,
-    steps=64, # Default based on DREAM examples
     constraints=None,
-    temperature=0.6, # Default based on DREAM examples
-    top_p=0.95, # Default based on DREAM examples
-    alg="entropy", # Default based on DREAM examples
-    alg_temp=0.0, # Default based on DREAM examples
 ):
     """
     Generate text with DREAM model with visualization using the generation hook.
-
-    Args:
-        messages: List of message dictionaries with 'role' and 'content'
-        gen_length: Length of text to generate (max_new_tokens)
-        steps: Number of diffusion steps
-        constraints: Dictionary mapping positions (relative to response start) to words
-        temperature: Sampling temperature
-        top_p: Nucleus sampling p
-        alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
-        alg_temp: Temperature for confidence-based algorithms
-
-    Returns:
-        Tuple: (List of visualization states, final generated text string)
     """
     print("--- Starting DREAM Generation ---")
     print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
     print(f"Constraints: {constraints}")
 
     # --- Input Preparation ---
-    if constraints is None:
-        constraints = {}
 
-    # Convert word constraints to token IDs (handle multi-token words)
     processed_constraints = {}
     print("Processing constraints:")
     for pos, word in constraints.items():
-        # Prepend space for consistent tokenization, similar to LLaDA example
-        # Important: use add_special_tokens=False for constraints
         tokens = tokenizer.encode(" " + word, add_special_tokens=False)
         if not tokens:
             print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
             continue
         print(f" Pos {pos}, Word '{word}' -> Tokens {tokens}")
         for i, token_id in enumerate(tokens):
-            # Ensure we don't overwrite parts of multi-token constraints accidentally
             if pos + i not in processed_constraints:
                 processed_constraints[pos + i] = token_id
             else:
                 print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")
 
-    # Prepare the prompt using chat template
     try:
         inputs = tokenizer.apply_chat_template(
-            messages,
-            return_tensors="pt",
-            return_dict=True,
-            add_generation_prompt=True # Crucial for instruction-tuned models like Dream-Instruct
         )
         input_ids = inputs.input_ids.to(device=device)
-        attention_mask = inputs.attention_mask.to(device=device) # Get attention mask
         prompt_length = input_ids.shape[1]
         print(f"Input prompt length: {prompt_length}")
-        # print(f"Input IDs: {input_ids}") # Keep commented unless debugging
     except Exception as e:
         print(f"Error applying chat template: {e}")
-        return [([("Error applying chat template.", "Error")],)], f"Error: {e}"
 
-
-    if prompt_length + gen_length > 2048: # Check context length (DREAM uses 2048)
         print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
         gen_length = 2048 - prompt_length
         if gen_length <= 0:
             print("Error: Prompt is already too long.")
             return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."
 
-
     # --- State for Visualization Hook ---
     visualization_states = []
-    last_x = None # Store the sequence from the previous step
 
-    # Initial state: Prompt + all masks
     initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
-    # Apply initial constraints to the masked part *before* showing the first state
     for pos, token_id in processed_constraints.items():
-        absolute_pos = pos # Position relative to start of generation
         if 0 <= absolute_pos < gen_length:
-            # Check if the constraint token itself is special
-            if token_id in SPECIAL_TOKENS_MAP:
-                print(f" Note: Constraint at pos {pos} is a special token: {SPECIAL_TOKENS_MAP[token_id]}")
-            initial_x_part[0, absolute_pos] = token_id
 
 
     # --- Define the Hook Function ---
-    # This function will be called at each diffusion step
     def generation_tokens_hook_func(step, x, logits):
-        nonlocal last_x, visualization_states # Allow modification of outer scope variables
-        # print(f"Hook called for step {step}") # Keep commented unless debugging
 
-        current_x = x.clone() # Work on a copy for comparison/modification
-
-        # 1. Apply Constraints *before* generating visualization for this step
-        # Constraints are relative to the start of the *generated* part
         constrained_x = current_x.clone()
-        current_prompt_len = current_x.shape[1] - gen_length # Recalculate actual prompt length
-        if current_prompt_len < 0:
            print("Warning: prompt_len negative in hook, skipping constraints/vis.")
-            return current_x # Return unmodified if something is wrong
 
         for pos, token_id in processed_constraints.items():
-            absolute_pos = current_prompt_len + pos
-            if current_prompt_len <= absolute_pos < current_x.shape[1]:
-                # Apply constraint if the current token doesn't match
                 if constrained_x[0, absolute_pos] != token_id:
                     constrained_x[0, absolute_pos] = token_id
-                    # print(f" Constraint applied at pos {pos} ({absolute_pos}) -> token {token_id}")
-
 
         # 2. Generate Visualization State for *this* step
-        # Compare current_x (output of diffusion for this step, before constraints applied *in this call*)
-        # with last_x (state from *previous* hook call / initial state, *after* constraints were applied then)
         current_state_vis = []
-        gen_part_current = current_x[0, current_prompt_len:]
-        gen_part_last = last_x[0, current_prompt_len:] if last_x is not None else None
 
         for i in range(gen_length):
             current_token_id = gen_part_current[i].item()
-            last_token_id = gen_part_last[i].item() if gen_part_last is not None else MASK_ID # Assume mask initially
 
-            # Determine display string - Handle special tokens explicitly
-            if current_token_id in SPECIAL_TOKENS_MAP:
-                display_token = SPECIAL_TOKENS_MAP[current_token_id]
-            else:
-                # Decode non-special tokens, skipping special tokens in the *output string*
-                # and stripping whitespace
-                display_token = tokenizer.decode([current_token_id],
-                                                 skip_special_tokens=True,
-                                                 clean_up_tokenization_spaces=True).strip()
-                # If decoding results in empty string for a non-special token, use a space perhaps
-                if not display_token:
-                    display_token = " " # Use a single space as placeholder
-
-
-            # Determine category (label) for color mapping
-            category = "Old" # Default assume it was revealed before
             is_constrained = i in processed_constraints
 
             if current_token_id == MASK_ID:
-                category = "Mask"
             elif is_constrained and processed_constraints[i] == current_token_id:
-                # Check if it was *just* constrained or already was correct
-                # We mark as 'Constraint' if it matches the required token, regardless of when it appeared
-                category = "Constraint"
-            elif last_token_id == MASK_ID and current_token_id != MASK_ID:
-                # It was a mask before, now it's not -> Newly revealed
-                # (Unless it's a constraint, handled above)
-                category = "New"
-            # else: category remains "Old"
-
 
-            current_state_vis.append((display_token, category))
 
         visualization_states.append(current_state_vis)
 
         # 3. Update last_x for the *next* step's comparison
-        # Store the state *after* applying constraints for accurate comparison next time
         last_x = constrained_x.clone()
 
-        # 4. Return the sequence with constraints applied for the model's next step
-        return constrained_x # Return the sequence with constraints enforced
-
 
     # --- Run DREAM Generation ---
     try:
         print("Calling model.diffusion_generate...")
-        # Make sure last_x is initialized correctly before the first hook call
-        # It should represent the state *before* the first diffusion step.
-        # Create the initial full sequence (prompt + initial masked/constrained part)
         initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
-        last_x = initial_full_x.clone() # Initialize last_x with the state before step 0
-
-        # Add the very first visualization state (prompt + initial masks/constraints)
-        # This state corresponds to the `last_x` *before* the first hook call.
-        initial_state_vis = []
-        initial_gen_part = initial_full_x[0, prompt_length:]
-        for i in range(gen_length):
-            token_id = initial_gen_part[i].item()
-            category = "Mask"
-            display_token = MASK_TOKEN
-            if token_id != MASK_ID:
-                # This must be an initial constraint
-                category = "Constraint"
-                if token_id in SPECIAL_TOKENS_MAP:
-                    display_token = SPECIAL_TOKENS_MAP[token_id]
-                else:
-                    display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
-                    if not display_token: display_token = " " # Placeholder
-
-            initial_state_vis.append((display_token, category))
-        visualization_states.append(initial_state_vis)
-
 
         output = model.diffusion_generate(
             input_ids,
             attention_mask=attention_mask,
             max_new_tokens=gen_length,
-            output_history=False, # We build history in the hook
             return_dict_in_generate=True,
             steps=steps,
             temperature=temperature,
             top_p=top_p,
             alg=alg,
-            alg_temp=alg_temp if alg != "origin" else 0.0, # alg_temp only for confidence algs
             generation_tokens_hook_func=generation_tokens_hook_func
         )
         print("model.diffusion_generate finished.")
 
-        # Extract final generated sequence (response part only)
         final_sequence = output.sequences[0]
         response_token_ids = final_sequence[prompt_length:]
 
-        # Decode the final response, skipping special tokens for the final output text
         final_text = tokenizer.decode(
             response_token_ids,
             skip_special_tokens=True,
@@ -338,92 +297,55 @@ def dream_generate_response_with_visualization(
         ).strip()
         print(f"Final generated text: {final_text}")
 
-        # The hook should have added the last state, no need for safeguard typically
 
 
     except Exception as e:
         print(f"Error during generation: {e}")
         import traceback
         traceback.print_exc()
-        # Add error message to visualization using the "Error" category
         error_msg = f"Error during generation: {str(e)}"
-        visualization_states.append([("Error", "Error")]) # Use 'Error' category
         final_text = f"Generation failed: {e}"
 
     print("--- DREAM Generation Finished ---")
-    # Return states list (already built by hook) and final text
     return visualization_states, final_text
 
 
 # --- Gradio UI Setup ---
 
 css = '''
-/* Hide the default legend */
-.gradio-container .output-markdown table { display: none !important; }
-
-.small_btn {
-    max-width: 100px; /* Adjust as needed */
-    min-width: 60px; /* Ensure button doesn't collapse */
-    height: 40px; /* Adjust as needed */
-    flex-grow: 0 !important; /* Prevent button from growing */
-    margin-left: 5px !important; /* Add some space */
-    margin-top: auto; /* Align button bottom with textbox */
-    margin-bottom: auto; /* Align button bottom with textbox */
-    line-height: 1; /* Adjust line height if text vertical align is off */
-    padding: 0 10px; /* Adjust padding */
-}
-.chat-input-row {
-    display: flex;
-    align-items: center; /* Vertically align items */
-    margin-bottom: 10px; /* Add space below input row */
-}
-.chat-input-row > * {
-    margin-right: 5px; /* Space between textbox and button */
-}
-.chat-input-row > *:last-child {
-    margin-right: 0;
-}
-/* Style HighlightedText elements */
-.token-hl span {
-    padding: 2px 1px; /* Minimal padding */
-    margin: 0 1px; /* Minimal margin */
-    border-radius: 3px;
-    display: inline-block; /* Ensure background covers token */
-    line-height: 1.2; /* Adjust for better vertical spacing */
-}
-/* Custom legend styling */
-.custom-legend span {
-    display: inline-block;
-    margin-right: 15px;
-    font-size: 0.9em;
-}
-.custom-legend span::before {
-    content: "■";
-    margin-right: 4px;
-    font-size: 1.1em; /* Make square slightly larger */
-    vertical-align: middle; /* Align square with text */
-}
 '''
-# Define color map mapping CATEGORY names to colors
-color_map = {
-    "Mask": "#A0A0A0", # Darker Gray for masks
-    "New": "#77DD77", # Light Green for new tokens
-    "Old": "#AEC6CF", # Light Blue/Gray for old tokens
-    "Constraint": "#C3A0E0", # Purple for constraints
-    "Error": "#FF6961" # Light Red for errors
-}
-
-# Create the custom legend HTML string
-legend_html = "<div class='custom-legend'>"
-for category, color in color_map.items():
-    legend_html += f"<span style='color:{color};'>{category}</span>"
-legend_html += "</div>"
-
-
 def create_chatbot_demo():
     with gr.Blocks(css=css) as demo:
         gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
-        gr.Markdown("A demonstration of the Dream 7B diffusion-based language model. Watch the text generate step-by-step.")
         gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")
 
         # STATE MANAGEMENT
@@ -433,108 +355,75 @@ def create_chatbot_demo():
         with gr.Row():
             with gr.Column(scale=3):
                 chatbot_ui = gr.Chatbot(
-                    label="Conversation",
-                    height=500,
-                    bubble_full_width=False
                 )
-
-                # Message input Row
                 with gr.Row(elem_classes="chat-input-row"):
                     user_input = gr.Textbox(
-                        label="Your Message",
-                        placeholder="Type your message here and press Enter...",
-                        scale=4,
-                        container=False,
-                        show_label=False
                     )
                     send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
 
                 constraints_input = gr.Textbox(
                     label="Word Constraints (Optional)",
-                    info="Force specific words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
-                    placeholder="e.g., 0:Hello, 6:world",
-                    value=""
                 )
             with gr.Column(scale=2):
                 output_vis = gr.HighlightedText(
                     label="Denoising Process Visualization",
-                    combine_adjacent=False, # Keep tokens separate
-                    show_legend=True, # Hide default legend table
-                    #color_map=color_map, # Provide the color map
-                    #elem_classes="token-hl" # Add class for token styling
                 )
-                # Use Markdown to display the custom legend
-                gr.Markdown(legend_html)
 
 
-        # Advanced generation settings
         with gr.Accordion("Generation Settings", open=False):
             with gr.Row():
-                gen_length = gr.Slider(
-                    minimum=16, maximum=512, value=128, step=8,
-                    label="Max New Tokens"
-                )
-                steps = gr.Slider(
-                    minimum=8, maximum=512, value=128, step=8,
-                    label="Diffusion Steps"
-                )
             with gr.Row():
-                temperature = gr.Slider(
-                    minimum=0.0, maximum=1.5, value=0.6, step=0.05,
-                    label="Temperature"
-                )
-                top_p = gr.Slider(
-                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
-                    label="Top-P (Nucleus Sampling)"
-                )
             with gr.Row():
                 remasking_strategy = gr.Radio(
-                    choices=[
-                        ("Random", "origin"),
-                        ("Entropy", "entropy"),
-                        ("MaskGit+", "maskgit_plus"),
-                        ("TopK Margin", "topk_margin"),
-                    ],
-                    value="entropy",
-                    label="Generation Order Strategy (alg)"
                 )
                 alg_temp = gr.Slider(
-                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
-                    label="Order Randomness (alg_temp)",
                     info="Adds randomness to non-Random strategies. Ignored for Random."
                 )
-
             with gr.Row():
-                visualization_delay = gr.Slider(
-                    minimum=0.0, maximum=0.5, value=0.05, step=0.01,
-                    label="Visualization Delay (seconds)"
-                )
 
-        # Clear button
         clear_btn = gr.Button("Clear Conversation")
 
-        # --- Event Handlers ---
-
-        # Helper to add message to history state
         def add_message_to_history(history, message, response):
-            history = history.copy() # Modify copy
-            history.append([message, response])
-            return history
 
-        # Function when user submits message (Enter or Send button)
         def user_message_submitted(message, history):
             print(f"User submitted: '{message}'")
             if not message or not message.strip():
-                print("Empty message submitted, doing nothing.")
-                return history, history, "", [] # history, chatbot_ui, user_input, output_vis
-
             history = add_message_to_history(history, message, None)
             history_for_display = history.copy()
-            message_out = ""
-            vis_clear = [] # Clear visualization when new message submitted
             return history, history_for_display, message_out, vis_clear
 
-        # Function to generate bot response (triggered after user message is processed)
         def bot_response_generator(
             history, gen_length, steps, constraints_text, delay,
             temperature, top_p, alg, alg_temp
@@ -550,91 +439,39 @@ def create_chatbot_demo():
 
             try:
                 vis_states, response_text = dream_generate_response_with_visualization(
-                    messages,
-                    gen_length=gen_length,
-                    steps=steps,
-                    constraints=parsed_constraints,
-                    temperature=temperature,
-                    top_p=top_p,
-                    alg=alg,
-                    alg_temp=alg_temp
                 )
 
-                # Update the history state only ONCE with the final bot response
-                final_history = history.copy() # Create copy to modify
-                final_history[-1][1] = response_text.strip() # Update the last element
-
-                # Yield visualization states one by one
-                # Important: Yield the *original* history for all intermediate steps,
-                # only yield the final_history with the *last* visualization state.
-                num_states = len(vis_states)
-                for i, state in enumerate(vis_states):
-                    current_chatbot_state = history if i < num_states - 1 else final_history
-                    yield current_chatbot_state, state
-                    if delay > 0 and i < num_states - 1: # Don't sleep after last state
                         time.sleep(delay)
 
             except Exception as e:
                 print(f"Error in bot_response_generator: {e}")
-                import traceback
-                traceback.print_exc()
                 error_msg = f"Error: {str(e)}"
-                error_vis = [(error_msg, "Error")] # Use Error category
-                # Update history with error message? Optional.
-                final_history_error = history.copy()
-                final_history_error[-1][1] = error_msg # Add error to chatbot too
-                yield final_history_error, error_vis
 
-        # Function to clear everything
         def clear_conversation():
-            print("Clearing conversation.")
-            return [], [], "", [] # chat_history, chatbot_ui, user_input, output_vis
 
-        # --- Wire UI elements to functions ---
 
-        # Typing in Textbox and pressing Enter
-        submit_event = user_input.submit(
-            fn=user_message_submitted,
-            inputs=[user_input, chat_history],
-            outputs=[chat_history, chatbot_ui, user_input, output_vis],
-            queue=False # Show user message immediately
-        )
-
-        # Clicking the Send button
-        click_event = send_btn.click(
-            fn=user_message_submitted,
-            inputs=[user_input, chat_history],
-            outputs=[chat_history, chatbot_ui, user_input, output_vis],
-            queue=False
-        )
 
-        # Chain the generation after user message is processed (for both submit and click)
-        # Use .then() to trigger the generator
-        generation_inputs = [
-            chat_history, gen_length, steps, constraints_input, visualization_delay,
-            temperature, top_p, remasking_strategy, alg_temp
-        ]
-        generation_outputs = [chatbot_ui, output_vis]
-
-        submit_event.then(
-            fn=bot_response_generator,
-            inputs=generation_inputs,
-            outputs=generation_outputs
-        )
-
-        click_event.then(
-            fn=bot_response_generator,
-            inputs=generation_inputs,
-            outputs=generation_outputs
-        )
-
-        # Clicking the Clear button
-        clear_btn.click(
-            fn=clear_conversation,
-            inputs=[],
-            outputs=[chat_history, chatbot_ui, user_input, output_vis],
-            queue=False
-        )
 
     return demo
 
@@ -643,6 +480,4 @@ if __name__ == "__main__":
     print("Creating Gradio demo...")
     demo = create_chatbot_demo()
     print("Launching Gradio demo...")
-    # Use queue for potentially long generation times
-    # share=True generates a public link (useful for Colab/Spaces)
-    demo.queue().launch(share=True, debug=True) # Add debug=True for more logs
@@ -1,4 +1,4 @@
+# llada_app.py -> dream_app.py (v2)
 
 import torch
 import numpy as np
@@ -11,15 +11,11 @@ import re # Keep for parsing constraints
 
 # Use try-except for space deployment vs local
 try:
     gpu_check = spaces.GPU
     print("Running in Gradio Spaces with GPU environment.")
 except AttributeError:
     print("Running in local environment or without spaces.GPU.")
+    def gpu_check(func): return func
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(f"Using device: {device}")
@@ -27,39 +23,50 @@ print(f"Using device: {device}")
 # --- Load DREAM Model and Tokenizer ---
 model_path = "Dream-org/Dream-v0-Instruct-7B"
 print(f"Loading model: {model_path}")
+try:
+    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    print("Model and tokenizer loaded.")
+except Exception as e:
+    print(f"FATAL: Could not load model/tokenizer. Error: {e}")
+    # Optionally exit or raise
+    raise SystemExit(f"Failed to load model: {e}")
+
 
 # --- Constants for DREAM ---
+# Find mask token and ID
 if tokenizer.mask_token is None:
+    print("Warning: Mask token not explicitly set in tokenizer. Trying to add '[MASK]'.")
+    # This might require retraining/fine-tuning if the model didn't see it.
+    # Check if it exists first before adding
+    if '[MASK]' not in tokenizer.get_vocab():
+        tokenizer.add_special_tokens({'mask_token': '[MASK]'})
+        model.resize_token_embeddings(len(tokenizer)) # Resize model embeddings
+        print("Added '[MASK]' and resized embeddings.")
+    else:
+        tokenizer.mask_token = '[MASK]' # Set it if it exists but wasn't assigned
+        print("Found existing '[MASK]', assigned as mask_token.")
 
 MASK_TOKEN = tokenizer.mask_token
 MASK_ID = tokenizer.mask_token_id
+if MASK_ID is None:
+    raise ValueError("Failed to get MASK_ID after attempting to set mask_token.")
+print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
 
+# Get EOS and PAD token IDs
+EOS_TOKEN_ID = tokenizer.eos_token_id
+PAD_TOKEN_ID = tokenizer.pad_token_id
+print(f"Using EOS_TOKEN_ID={EOS_TOKEN_ID}, PAD_TOKEN_ID={PAD_TOKEN_ID}")
+# Handle cases where they might be None (though unlikely for most models)
+if EOS_TOKEN_ID is None:
+    print("Warning: EOS token ID not found.")
+if PAD_TOKEN_ID is None:
+    print("Warning: PAD token ID not found. Using EOS ID as fallback for hiding.")
+    PAD_TOKEN_ID = EOS_TOKEN_ID # Use EOS as a fallback for hiding logic if PAD is missing
 
 # --- Helper Functions (Constraint Parsing, History Formatting) ---
+# (Keep parse_constraints and format_chat_history functions as they were)
 def parse_constraints(constraints_text):
     """Parse constraints in format: 'position:word, position:word, ...'"""
     constraints = {}
@@ -107,230 +114,182 @@ def format_chat_history(history):
 
 # --- Core Generation Logic for DREAM with Visualization ---
 
+@gpu_check
 def dream_generate_response_with_visualization(
     messages,
     gen_length=64,
+    steps=64,
     constraints=None,
+    temperature=0.6,
+    top_p=0.95,
+    alg="entropy",
+    alg_temp=0.0,
 ):
     """
     Generate text with DREAM model with visualization using the generation hook.
+    Hides special tokens (EOS, PAD) and uses labels for coloring.
     """
     print("--- Starting DREAM Generation ---")
     print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
     print(f"Constraints: {constraints}")
 
     # --- Input Preparation ---
+    if constraints is None: constraints = {}
 
     processed_constraints = {}
     print("Processing constraints:")
     for pos, word in constraints.items():
         tokens = tokenizer.encode(" " + word, add_special_tokens=False)
         if not tokens:
             print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
             continue
         print(f" Pos {pos}, Word '{word}' -> Tokens {tokens}")
         for i, token_id in enumerate(tokens):
             if pos + i not in processed_constraints:
                 processed_constraints[pos + i] = token_id
             else:
                 print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")
 
     try:
         inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
         )
         input_ids = inputs.input_ids.to(device=device)
+        attention_mask = inputs.attention_mask.to(device=device)
         prompt_length = input_ids.shape[1]
         print(f"Input prompt length: {prompt_length}")
     except Exception as e:
         print(f"Error applying chat template: {e}")
+        return [([("Error applying chat template.", "Error")],)], f"Error: {e}" # Use 'Error' label
 
+    # Check context length (DREAM uses 2048)
+    if prompt_length + gen_length > 2048:
         print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
         gen_length = 2048 - prompt_length
         if gen_length <= 0:
             print("Error: Prompt is already too long.")
             return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."
 
     # --- State for Visualization Hook ---
     visualization_states = []
+    last_x = None
 
+    # Initial state: Prompt + all masks + initial constraints
     initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
     for pos, token_id in processed_constraints.items():
+        absolute_pos = pos
         if 0 <= absolute_pos < gen_length:
+            initial_x_part[0, absolute_pos] = token_id
+
+    initial_state_vis = []
+    for i in range(gen_length):
+        token_id = initial_x_part[0, i].item()
+        if token_id == MASK_ID:
+            initial_state_vis.append((MASK_TOKEN, "Mask"))
+        elif token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
+            initial_state_vis.append(("", None)) # Hide special tokens
+        elif i in processed_constraints and processed_constraints[i] == token_id:
+            token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
+            display_token = token_str if token_str else "?"
+            initial_state_vis.append((display_token, "Constraint"))
+        else:
+            # Should only be constraints here, but add fallback
+            token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
+            display_token = token_str if token_str else "?"
+            initial_state_vis.append((display_token, "Old")) # Treat unexpected initial non-masks as 'Old'
+    visualization_states.append(initial_state_vis)
 
     # --- Define the Hook Function ---
     def generation_tokens_hook_func(step, x, logits):
+        nonlocal last_x, visualization_states
+        # print(f"Hook called for step {step}") # Verbose logging
 
+        current_x = x.clone()
         constrained_x = current_x.clone()
+        prompt_len = current_x.shape[1] - gen_length
+        if prompt_len < 0:
             print("Warning: prompt_len negative in hook, skipping constraints/vis.")
+            return current_x
 
+        # 1. Apply Constraints
+        constraints_applied_this_step = False
         for pos, token_id in processed_constraints.items():
+            absolute_pos = prompt_len + pos
+            if prompt_len <= absolute_pos < current_x.shape[1]:
                 if constrained_x[0, absolute_pos] != token_id:
                     constrained_x[0, absolute_pos] = token_id
+                    constraints_applied_this_step = True
 
         # 2. Generate Visualization State for *this* step
         current_state_vis = []
+        gen_part_current = current_x[0, prompt_len:]
+        gen_part_last = last_x[0, prompt_len:] if last_x is not None else None
 
         for i in range(gen_length):
             current_token_id = gen_part_current[i].item()
 
+            # --- Logic to Hide Special Tokens ---
+            if current_token_id == EOS_TOKEN_ID or current_token_id == PAD_TOKEN_ID:
+                # Maybe show on first appearance? For now, always hide.
+                # LLaDA's behavior: "shown once and then disappear"
+                # Let's implement the simpler "always hide" first.
+                current_state_vis.append(("", None)) # Append empty string, no label -> hidden
+                continue # Move to next token
+
+            # --- Decode and Determine Label ---
+            token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
+            display_token = token_str if token_str else MASK_TOKEN if current_token_id == MASK_ID else "?" # Use MASK_TOKEN if decode fails
+
+            label = None # Default label (no color)
             is_constrained = i in processed_constraints
 
             if current_token_id == MASK_ID:
+                label = "Mask"
             elif is_constrained and processed_constraints[i] == current_token_id:
+                label = "Constraint"
+            elif gen_part_last is None or gen_part_last[i].item() == MASK_ID or gen_part_last[i].item() == EOS_TOKEN_ID or gen_part_last[i].item() == PAD_TOKEN_ID:
+                # Newly revealed (was mask or hidden special token in previous step)
+                label = "New"
+            else:
+                # Previously revealed and not masked/hidden/constrained
+                label = "Old"
 
+            current_state_vis.append((display_token, label))
 
         visualization_states.append(current_state_vis)
 
         # 3. Update last_x for the *next* step's comparison
         last_x = constrained_x.clone()
 
+        # 4. Return the sequence with constraints applied
+        return constrained_x
 
     # --- Run DREAM Generation ---
     try:
         print("Calling model.diffusion_generate...")
         initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
+        last_x = initial_full_x.clone() # Initialize last_x *before* the call
 
         output = model.diffusion_generate(
             input_ids,
             attention_mask=attention_mask,
             max_new_tokens=gen_length,
+            output_history=False,
             return_dict_in_generate=True,
             steps=steps,
             temperature=temperature,
             top_p=top_p,
             alg=alg,
+            alg_temp=alg_temp if alg != "origin" else 0.0,
             generation_tokens_hook_func=generation_tokens_hook_func
         )
         print("model.diffusion_generate finished.")
 
         final_sequence = output.sequences[0]
         response_token_ids = final_sequence[prompt_length:]
 
+        # Decode final text, skipping special tokens
         final_text = tokenizer.decode(
             response_token_ids,
             skip_special_tokens=True,
@@ -338,92 +297,55 @@ def dream_generate_response_with_visualization(
         ).strip()
         print(f"Final generated text: {final_text}")
 
+        # Safeguard: Add final state visualization if needed (using the new label logic)
+        if len(visualization_states) <= steps:
+            final_state_vis = []
+            final_gen_part = final_sequence[prompt_length:]
+            for i in range(gen_length):
+                token_id = final_gen_part[i].item()
+                if token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
+                    final_state_vis.append(("", None))
+                    continue
+
+                token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
+                display_token = token_str if token_str else MASK_TOKEN if token_id == MASK_ID else "?"
+                label = None
+                is_constrained = i in processed_constraints
+
+                if token_id == MASK_ID: label = "Mask"
+                elif is_constrained and processed_constraints[i] == token_id: label = "Constraint"
+                else: label = "Old" # Default to 'Old' for final state non-masked tokens
+                final_state_vis.append((display_token, label))
+            visualization_states.append(final_state_vis)
 
     except Exception as e:
         print(f"Error during generation: {e}")
         import traceback
         traceback.print_exc()
         error_msg = f"Error during generation: {str(e)}"
+        # Use 'Error' label for color mapping
+        visualization_states.append([("Error", "Error")])
         final_text = f"Generation failed: {e}"
 
     print("--- DREAM Generation Finished ---")
     return visualization_states, final_text
 
 # --- Gradio UI Setup ---
 
 css = '''
+.category-legend{display:none}
+/* button{height: 60px} */
+.small_btn {max-width: 100px; height: 40px; flex-grow: 0; margin-left: 5px;}
+.chat-input-row {display: flex; align-items: center;}
+.chat-input-row > * {margin-right: 5px;}
+.chat-input-row > *:last-child {margin-right: 0;}
 '''
 def create_chatbot_demo():
     with gr.Blocks(css=css) as demo:
         gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
+        gr.Markdown("Watch the text generate step-by-step. Special tokens (EOS, PAD) are hidden.")
         gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")
 
         # STATE MANAGEMENT
@@ -433,108 +355,75 @@ def create_chatbot_demo():
         with gr.Row():
             with gr.Column(scale=3):
                 chatbot_ui = gr.Chatbot(
+                    label="Conversation", height=500, bubble_full_width=False
                 )
                 with gr.Row(elem_classes="chat-input-row"):
                     user_input = gr.Textbox(
+                        label="Your Message", placeholder="Type your message...",
+                        scale=4, container=False, show_label=False
                     )
                     send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
 
                 constraints_input = gr.Textbox(
                     label="Word Constraints (Optional)",
+                    info="Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon'",
+                    placeholder="e.g., 0:Hello, 6:world", value=""
                 )
             with gr.Column(scale=2):
+                # --- Updated HighlightedText with color_map ---
                 output_vis = gr.HighlightedText(
                     label="Denoising Process Visualization",
+                    combine_adjacent=True, # Combine adjacent tokens with same label
+                    show_legend=False, # Keep legend off
+                    color_map={ # Map labels to colors
+                        "Mask": "#A0A0A0", # Lighter Gray for Mask
+                        "New": "#66CC66", # Light Green
+                        "Old": "#6699CC", # Light Blue
+                        "Constraint": "#B266FF", # Lighter Purple/Violet
+                        "Error": "#FF6666" # Light Red
+                    }
+                )
+                gr.Markdown(
+                    # Update legend text to match labels
+                    "**Color Legend:** <span style='color:#A0A0A0'>■ Mask</span> | <span style='color:#66CC66'>■ New</span> | <span style='color:#6699CC'>■ Old</span> | <span style='color:#B266FF'>■ Constraint</span>"
                 )
 
+        # Advanced generation settings (Keep as before)
         with gr.Accordion("Generation Settings", open=False):
             with gr.Row():
+                gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
+                steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
             with gr.Row():
+                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.6, step=0.05, label="Temperature")
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (Nucleus Sampling)")
             with gr.Row():
                 remasking_strategy = gr.Radio(
+                    choices=[("Random", "origin"), ("Entropy", "entropy"), ("MaskGit+", "maskgit_plus"), ("TopK Margin", "topk_margin")],
+                    value="entropy", label="Generation Order Strategy (alg)"
                 )
                 alg_temp = gr.Slider(
+                    minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Order Randomness (alg_temp)",
                     info="Adds randomness to non-Random strategies. Ignored for Random."
                 )
             with gr.Row():
+                visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Visualization Delay (seconds)")
 
         clear_btn = gr.Button("Clear Conversation")
 
+        # --- Event Handlers (Keep as before) ---
         def add_message_to_history(history, message, response):
+            history = history.copy(); history.append([message, response]); return history
 
         def user_message_submitted(message, history):
             print(f"User submitted: '{message}'")
             if not message or not message.strip():
+                print("Empty message submitted, doing nothing."); return history, history, "", []
             history = add_message_to_history(history, message, None)
             history_for_display = history.copy()
+            message_out = ""; vis_clear = []
             return history, history_for_display, message_out, vis_clear
 
         def bot_response_generator(
             history, gen_length, steps, constraints_text, delay,
             temperature, top_p, alg, alg_temp
@@ -550,91 +439,39 @@ def create_chatbot_demo():
 
             try:
                 vis_states, response_text = dream_generate_response_with_visualization(
+                    messages, gen_length=gen_length, steps=steps, constraints=parsed_constraints,
+                    temperature=temperature, top_p=top_p, alg=alg, alg_temp=alg_temp
                 )
+                history[-1][1] = response_text.strip() # Update history state
 
+                if vis_states:
+                    # Yield initial state first
+                    yield history, vis_states[0] # Update chatbot, update visualization
+                    # Animate remaining states
+                    for state in vis_states[1:]:
                         time.sleep(delay)
+                        yield history, state # Update chatbot (implicitly), update visualization
+                else:
+                    yield history, [("Generation failed.", "Error")] # Use label
 
             except Exception as e:
                 print(f"Error in bot_response_generator: {e}")
+                import traceback; traceback.print_exc()
                 error_msg = f"Error: {str(e)}"
+                error_vis = [(error_msg, "Error")] # Use label
+                yield history, error_vis
 
         def clear_conversation():
+            print("Clearing conversation."); return [], [], "", []
 
+        # --- Wire UI elements (Keep as before) ---
+        user_input.submit(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
+            .then(fn=bot_response_generator, inputs=[history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])
 
+        send_btn.click(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
+            .then(fn=bot_response_generator, inputs=[history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])
 
+        clear_btn.click(fn=clear_conversation, inputs=[], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)
 
     return demo
 
@@ -643,6 +480,4 @@ if __name__ == "__main__":
     print("Creating Gradio demo...")
     demo = create_chatbot_demo()
     print("Launching Gradio demo...")
+    demo.queue().launch(share=True, debug=True)