multimodalart (HF Staff) committed
Commit 168a7c1 · verified · 1 parent: 5713ed1

Update app.py

Files changed (1)
  1. app.py +614 -420
app.py CHANGED
@@ -3,58 +3,31 @@ import torch
  import numpy as np
  import gradio as gr
  import spaces
  import time
  import re
- from transformers import AutoModel, AutoTokenizer
- from threading import Lock
- from queue import Queue
-
- # --- Configuration ---
- MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
- print(f"Using device: {DEVICE}")
-
- # --- Load Model and Tokenizer ---
- print("Loading model and tokenizer...")
- # Need configuration files for trust_remote_code
- # Make sure config.json, configuration_dream.py, modeling_dream.py,
- # generation_utils.py, generation_config.json are in the same directory
- # or accessible in the Hugging Face cache.
- model = AutoModel.from_pretrained(
-     MODEL_PATH,
-     torch_dtype=torch.bfloat16,
-     trust_remote_code=True
- ).to(DEVICE).eval()
- tokenizer = AutoTokenizer.from_pretrained(
-     MODEL_PATH,
-     trust_remote_code=True
- )
- print("Model and tokenizer loaded.")

  # --- Constants ---
- # Get IDs from tokenizer/config if possible, otherwise hardcode from provided files
- MASK_TOKEN = tokenizer.mask_token  # Should be "<|mask|>"
- try:
-     MASK_ID = tokenizer.mask_token_id  # Should be 151666
-     if MASK_ID is None: raise AttributeError  # Handle case where it might not be set directly
- except AttributeError:
-     print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
-     MASK_ID = 151666
-
- try:
-     EOS_ID = tokenizer.eos_token_id  # Should be 151643
-     PAD_ID = tokenizer.pad_token_id  # Should be 151643
-     if EOS_ID is None or PAD_ID is None: raise AttributeError
- except AttributeError:
-     print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
-     EOS_ID = 151643
-     PAD_ID = 151643
-
- # Ensure MASK_TOKEN and MASK_ID are valid
- if MASK_TOKEN is None or MASK_ID is None:
-     raise ValueError("Mask token or ID is not defined correctly.")
- if EOS_ID is None or PAD_ID is None:
-     raise ValueError("EOS/PAD token or ID is not defined correctly.")

  # --- Helper Functions ---

@@ -71,13 +44,18 @@ def parse_constraints(constraints_text):
          try:
              pos_str, word = part.split(':', 1)
              pos = int(pos_str.strip())
              word = word.strip()
              if word and pos >= 0:
                  # Tokenize the word - handle potential multi-token words
-                 # Add space prefix for consistency, similar to how model might see words mid-sentence
-                 tokens = tokenizer.encode(" " + word, add_special_tokens=False)
                  for i, token_id in enumerate(tokens):
-                     constraints[pos + i] = token_id
          except ValueError:
              continue
          except Exception as e:
@@ -86,280 +64,459 @@ def parse_constraints(constraints_text):

      return constraints

  def format_chat_history(history):
      """
-     Format chat history for the Dream model using its chat template logic.

      Args:
          history: List of [user_message, assistant_message] pairs

      Returns:
-         Formatted list of message dictionaries for the model
      """
      messages = []
-     # Add system prompt if history is empty or doesn't start with system
-     if not history or history[0][0].lower() != 'system':
-         # Check if the tokenizer's template expects an explicit system message
-         # The template provided in tokenizer_config.json handles adding a default one
-         pass  # Let apply_chat_template handle the default system message
-
-     for user_msg, assistant_msg in history:
-         if user_msg:  # Handle potential initial system message possibility if needed
-             messages.append({"role": "user", "content": user_msg})
          if assistant_msg is not None:  # Skip if None (for the latest user message)
              messages.append({"role": "assistant", "content": assistant_msg})

      return messages

- # --- Core Generation Logic with Visualization ---

- # Use a thread-safe queue to pass visualization states from the hook
- vis_queue = Queue()
- # Lock to prevent race conditions when accessing shared state like previous_x
- state_lock = Lock()
- # Store the previous state for comparison in the hook
- previous_x_shared = None

  @spaces.GPU
- def generate_response_with_visualization(
-     messages,  # List of message dicts from format_chat_history
-     max_new_tokens=64,
-     steps=64,  # Default steps based on README example
-     constraints=None,
-     temperature=0.6,  # Default from demo_token_control
-     top_p=0.95,  # Default from demos
-     alg="entropy",  # Default from demos
-     alg_temp=0.1,  # Default from demo_multiturn_chat
- ):
      """
-     Generate text with Dream model and capture visualization states using a hook.

      Args:
-         messages: List of message dictionaries with 'role' and 'content'.
-         max_new_tokens: Max tokens to generate.
-         steps: Diffusion steps.
-         constraints: Dictionary mapping positions (relative to response start) to token IDs.
-         temperature: Sampling temperature.
-         top_p: Nucleus sampling p.
-         alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
-         alg_temp: Temperature for confidence-based algorithms.

      Returns:
-         Tuple: (List of visualization states, final generated text string)
      """
-     global previous_x_shared, vis_queue
      if constraints is None:
-         constraints = {}

-     visualization_states = []
-
-     # Clear the queue for a new generation
-     while not vis_queue.empty():
-         try:
-             vis_queue.get_nowait()
-         except Queue.Empty:
-             break

-     # Prepare the prompt using chat template
-     # The template automatically adds the generation prompt like "<|im_start|>assistant\n"
-     try:
-         inputs = tokenizer.apply_chat_template(
-             messages,
-             return_tensors="pt",
-             add_generation_prompt=True,
-             return_dict=True
-         )
-         input_ids = inputs.input_ids.to(device=DEVICE)
-         # Dream doesn't seem to explicitly use attention_mask in simple demos,
-         # but it's good practice if padding were involved.
-         # For now, assume no padding in this interactive demo.
-         attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None
-
-     except Exception as e:
-         print(f"Error applying chat template: {e}")
-         # Provide a fallback or error state
-         error_state = [("Error in chat formatting.", "red")]
-         return [error_state], f"Error: Could not format chat history. {e}"

-     prompt_length = input_ids.shape[1]
-     total_length = prompt_length + max_new_tokens
-
-     # --- Define the Hook Function ---
-     def generation_tokens_hook_func(step, x, logits):
-         global previous_x_shared, vis_queue
-         with state_lock:  # Ensure thread safety if needed, though hooks might run sequentially
-             current_x = x.clone()  # Shape: (batch_size, total_length)
-
-             # --- Apply Constraints ---
-             # Constraints are relative to the start of the *response*
-             for rel_pos, token_id in constraints.items():
-                 abs_pos = prompt_length + rel_pos
-                 if 0 <= abs_pos < current_x.shape[1]:
-                     # Ensure constraint application doesn't go out of bounds
-                     # Apply constraint for the first batch element (batch size is 1 here)
-                     current_x[0, abs_pos] = token_id
-
-             # --- Create Visualization State ---
-             current_vis_state = []
-             x_response = current_x[0, prompt_length:]  # Get the response part for batch 0
-             prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None
-
-             for i in range(max_new_tokens):
-                 current_token_id = x_response[i].item()
-                 token_str = tokenizer.decode([current_token_id], skip_special_tokens=False)  # Keep special tokens for vis
-
-                 # Clean up visual representation of special tokens
-                 if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
-                     token_str = "[EOS/PAD]"  # Make it visually distinct
-                 elif token_str == tokenizer.mask_token:
-                     token_str = "[MASK]"
-                 elif token_str.strip() == "":  # Handle empty strings from decoding potentially odd tokens
-                     token_str = "[UNK/SPACE]"
-
-                 color = "#DDDDDD"  # Default background
-
-                 if current_token_id == MASK_ID:
-                     color = "#444444"  # Dark gray for masks
-                 elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
-                     # Token was mask, now it's revealed in this step
-                     # Use green for newly revealed
-                     color = "#66CC66"  # Light green
-                 else:
-                     # Token was already revealed in a previous step or is a constraint
-                     # Check if it's a constraint applied *now*
-                     is_constraint = (prompt_length + i - prompt_length) in constraints and \
-                                     constraints[prompt_length + i - prompt_length] == current_token_id
-
-                     if is_constraint:
-                         color = "#FFD700"  # Gold for constraints
-                     else:
-                         color = "#6699CC"  # Light blue for previously revealed
-
-                 current_vis_state.append((token_str, color))
-
-             # --- Update shared state and put vis state in queue ---
-             previous_x_shared = current_x.clone()  # Update for the *next* step's comparison
-             vis_queue.put(current_vis_state)
-
-         # The hook must return the potentially modified tensor `x`
-         return current_x
-     # --- End of Hook Function ---
-
-     # Initialize previous_x_shared before generation starts
-     # Create initial masked state for visualization
-     initial_x = input_ids.clone()
-     if initial_x.shape[1] < total_length:
-         padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
-         initial_x = torch.cat([initial_x, padding], dim=1)
-     else:
-         initial_x = initial_x[:, :total_length]  # Truncate if prompt is too long

-     # Apply initial constraints to the starting state
      for rel_pos, token_id in constraints.items():
          abs_pos = prompt_length + rel_pos
-         if 0 <= abs_pos < initial_x.shape[1]:
-             initial_x[0, abs_pos] = token_id
-
-     with state_lock:
-         previous_x_shared = initial_x.clone()
-
-     # Add the initial all-masked state (or with constraints) to the visualization queue
-     initial_vis_state = []
-     initial_x_response = initial_x[0, prompt_length:]
-     for i in range(max_new_tokens):
-         token_id = initial_x_response[i].item()
-         if token_id == MASK_ID:
-             initial_vis_state.append((MASK_TOKEN, "#444444"))
-         else:
-             # Must be a pre-applied constraint
-             token_str = tokenizer.decode([token_id], skip_special_tokens=False)
-             if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
-                 token_str = "[EOS/PAD]"
-             elif token_str.strip() == "":
-                 token_str = "[UNK/SPACE]"
-             initial_vis_state.append((token_str, "#FFD700"))  # Gold for constraints
-     vis_queue.put(initial_vis_state)
-
-
-     # --- Run Generation ---
-     try:
-         # output_history=False because the hook handles state capture
-         # return_dict_in_generate=True to get the GenerationOutput object
-         output = model.diffusion_generate(
-             initial_x,  # Start with the potentially constraint-applied tensor
-             attention_mask=None,  # Assuming no padding needed for interactive use
-             max_new_tokens=max_new_tokens,  # This might not be strictly needed if total_length is fixed
-             output_history=False,
-             return_dict_in_generate=True,
-             steps=steps,
-             temperature=temperature,
-             top_p=top_p,
-             alg=alg,
-             alg_temp=alg_temp if alg != 'origin' else None,  # alg_temp only for confidence algs
-             generation_tokens_hook_func=generation_tokens_hook_func
-         )

-         final_sequence = output.sequences[0]  # Batch size 1
-
-         # Decode the final response text, cleaning up special tokens
-         response_tokens = final_sequence[prompt_length:]
-         # Filter out EOS/PAD tokens for the final text display
-         response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
-         final_text = tokenizer.decode(response_tokens_filtered,
-                                       skip_special_tokens=True,
-                                       clean_up_tokenization_spaces=True)  # Standard cleanup
-
-     except Exception as e:
-         print(f"Error during generation: {e}")
-         import traceback
-         traceback.print_exc()
-         # Provide error state
-         error_state = [("Generation Error.", "red")]
-         visualization_states.append(error_state)
-         final_text = f"Error: Generation failed. {e}"
-         # Add any states captured before the error
-         while not vis_queue.empty():
-             try:
-                 visualization_states.append(vis_queue.get_nowait())
-             except Queue.Empty:
-                 break
-         return visualization_states, final_text

-     # Retrieve all visualization states captured by the hook
-     while not vis_queue.empty():
-         try:
-             visualization_states.append(vis_queue.get_nowait())
-         except Queue.Empty:
              break

-     # If somehow no states were captured, add the initial one
-     if not visualization_states:
-         visualization_states.append(initial_vis_state)


-     return visualization_states, final_text.strip()


- # --- Gradio UI ---

  css = '''
  .category-legend{display:none}
  button{height: 60px}
  '''
  def create_chatbot_demo():
-     with gr.Blocks(css=css) as demo:
          gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
-         gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
-         gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")

          # STATE MANAGEMENT
          chat_history = gr.State([])

          # UI COMPONENTS
          with gr.Row():
              with gr.Column(scale=3):
-                 chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])

                  # Message input
                  with gr.Group():
@@ -367,192 +524,229 @@ def create_chatbot_demo():
                      user_input = gr.Textbox(
                          label="Your Message",
                          placeholder="Type your message here...",
-                         show_label=False,
-                         scale=9
                      )
                      send_btn = gr.Button("Send", scale=1)

                  constraints_input = gr.Textbox(
-                     label="Word Constraints (Optional)",
-                     info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
-                     placeholder="0:Once, 5:upon, 10:a",
                      value=""
                  )
              with gr.Column(scale=2):
                  output_vis = gr.HighlightedText(
-                     label="Diffusion Process Visualization",
-                     combine_adjacent=False,
-                     show_legend=True,  # Keep legend hidden via CSS if desired
-                 )
-                 # Legend (colors defined in generate_response_with_visualization)
-                 gr.Markdown(
-                     "<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
-                     " <span style='background-color:#66CC66;'>Newly Revealed</span>"
-                     " <span style='background-color:#6699CC;'>Previously Revealed</span>"
-                     " <span style='background-color:#FFD700;'>Constraint</span>"
-                     " <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
                  )

          # Advanced generation settings
          with gr.Accordion("Generation Settings", open=False):
-             max_new_tokens_slider = gr.Slider(
-                 minimum=16, maximum=512, value=128, step=16,  # Increased default/max
-                 label="Max New Tokens (Generation Length)"
-             )
-             steps_slider = gr.Slider(
-                 minimum=8, maximum=512, value=128, step=8,  # Increased default/max
-                 label="Diffusion Steps"
-             )
-             temp_slider = gr.Slider(
-                 minimum=0.0, maximum=1.0, value=0.6, step=0.05,  # Finer steps for temp
-                 label="Temperature"
-             )
-             top_p_slider = gr.Slider(
-                 minimum=0.0, maximum=1.0, value=0.95, step=0.05,
-                 label="Top-P (Nucleus Sampling)"
-             )
-             alg_radio = gr.Radio(
-                 # Choices from README
-                 choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
-                 value='entropy',
-                 label="Remasking Algorithm"
-             )
-             alg_temp_slider = gr.Slider(
-                 minimum=0.0, maximum=1.0, value=0.1, step=0.05,
-                 label="Algorithm Temperature (for confidence-based algs)"
-             )
-             vis_delay_slider = gr.Slider(
-                 minimum=0.0, maximum=0.5, value=0.03, step=0.01,  # Faster default delay
-                 label="Visualization Delay (seconds)"
-             )

          # Clear button
          clear_btn = gr.Button("Clear Conversation")

-         # HELPER FUNCTIONS (UI Logic)
          def add_message_to_history(history, message, response):
              """Add a message pair to the history state"""
-             new_history = history + [[message, response]]
-             return new_history

-         def user_message_submitted(message, history):
-             """ Handle user sending a message: update history, clear input """
              if not message or message.strip() == "":
-                 return history, history, "", []  # No change if empty
-
-             # Add user message, response is initially None
-             new_history = add_message_to_history(history, message, None)
-
-             # Prepare display version (immediately shows user message)
-             display_history = new_history
-
-             # Clear input box
-             message_out = ""
-
-             # Clear visualization
-             vis_out = []
-
-             return new_history, display_history, message_out, vis_out
-
-         def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
-             """ Generator function to stream bot response and visualization """
-             if not history or history[-1][1] is not None:  # Ensure there's a user msg waiting for response
-                 print("Warning: Bot response triggered without pending user message.")
-                 yield history, [], "Error: No user message to respond to."  # Send error state back?
                  return

-             # Get the full conversation history formatted for the model
              last_user_message = history[-1][0]
-             messages_for_model = format_chat_history(history[:-1])  # History *before* the last user msg
-             messages_for_model.append({"role": "user", "content": last_user_message})

-             # Parse constraints
              try:
                  parsed_constraints = parse_constraints(constraints_str)
-             except Exception as e:
-                 print(f"Error parsing constraints: {e}")
-                 yield history, [("Constraint Error", "red")], f"Error: Failed to parse constraints: {e}"
-                 return

-             # Generate response and visualization states
-             try:
-                 vis_states, final_response_text = generate_response_with_visualization(
-                     messages_for_model,
-                     max_new_tokens=max_tokens,
                      steps=steps,
                      constraints=parsed_constraints,
-                     temperature=temp,
-                     top_p=top_p,
                      alg=alg,
-                     alg_temp=alg_temp
-                 )
              except Exception as e:
-                 print(f"Error in generate_response_with_visualization: {e}")
                  import traceback
                  traceback.print_exc()
-                 yield history, [("Generation Error", "red")], f"Error: Generation failed: {e}"
-                 return
-
-             # Update the history state with the final response *once*
-             history[-1][1] = final_response_text  # Update the None placeholder

-             # Yield initial state immediately
-             if vis_states:
-                 yield history, vis_states[0]
-             else:
-                 yield history, []  # Should not happen if generation worked
-
-             # Stream intermediate visualization states
-             for state in vis_states[1:]:
-                 time.sleep(delay)
-                 yield history, state
-
-             # Final yield ensures the chatbot UI has the complete history
-             # The last state in vis_states should already be yielded by the loop
-             # yield history, vis_states[-1] if vis_states else []
-
-
-         def clear_conversation():
-             """Clear the conversation history and visualization"""
-             return [], [], "", []  # history, chatbot_ui, user_input, output_vis
-
-         # EVENT HANDLERS
-
-         # User presses Enter or Send button
-         submit_args = {
-             "fn": user_message_submitted,
-             "inputs": [user_input, chat_history],
-             "outputs": [chat_history, chatbot_ui, user_input, output_vis]
-         }
-         user_input.submit(**submit_args)
-         send_btn.click(**submit_args)
-
-         # After user message is submitted, trigger bot response generation
-         generate_args = {
-             "fn": bot_response_generator,
-             "inputs": [
-                 chat_history, constraints_input, max_new_tokens_slider, steps_slider,
-                 temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
-             ],
-             "outputs": [chatbot_ui, output_vis]  # Update chatbot history and visualization
-         }
-         # Trigger generation after submit OR click
-         user_input.submit(None, None, None, queue=True).then(**generate_args)
-         send_btn.click(None, None, None, queue=True).then(**generate_args)
-
-
-         # Clear button handler
          clear_btn.click(
-             fn=clear_conversation,
              inputs=[],
-             outputs=[chat_history, chatbot_ui, user_input, output_vis]
          )

      return demo

- # Launch the demo
  if __name__ == "__main__":
      demo = create_chatbot_demo()
-     # queue() allows streaming and handling multiple users
-     # share=True creates a public link (use with caution)
-     demo.queue().launch(share=True, debug=True)

  import numpy as np
  import gradio as gr
  import spaces
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModel
+ from transformers.generation.configuration_utils import GenerationConfig
  import time
  import re
+ import torch.distributions as dists  # Import dists for sampling logic
+
+ # --- Model Loading ---
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print(f"Using device: {device}")
+
+ # Load Dream model and tokenizer
+ model_path = "Dream-org/Dream-v0-Instruct-7B"
+ # Load configuration first to get token IDs
+ config = DreamConfig.from_pretrained(model_path)  # Assuming configuration_dream.py is present
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
+ model = model.to(device).eval()
+ print("Model and Tokenizer loaded.")

  # --- Constants ---
+ MASK_TOKEN = tokenizer.mask_token  # "<|mask|>"
+ MASK_ID = config.mask_token_id  # Get from config (e.g., 151666)
+ EOS_ID = config.eos_token_id  # Get from config (e.g., 151643)
+ PAD_ID = config.pad_token_id  # Get from config (e.g., 151643)

  # --- Helper Functions ---

          try:
              pos_str, word = part.split(':', 1)
              pos = int(pos_str.strip())
+             # Use strip() and lower() for robustness if needed, but preserve case for now
              word = word.strip()
              if word and pos >= 0:
                  # Tokenize the word - handle potential multi-token words
+                 # Add space prefix typical for non-leading words if pos > 0
+                 prefix = " " if pos > 0 else ""
+                 tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
                  for i, token_id in enumerate(tokens):
+                     # Only add if the token is not a special token id already
+                     # (This prevents accidental replacement of things like MASK_ID)
+                     if token_id not in [MASK_ID, EOS_ID, PAD_ID]:
+                         constraints[pos + i] = token_id
          except ValueError:
              continue
          except Exception as e:

      return constraints

+
  def format_chat_history(history):
      """
+     Format chat history for the Dream model (using ChatML format potentially)

      Args:
          history: List of [user_message, assistant_message] pairs

      Returns:
+         Formatted conversation for the model (list of message dicts)
      """
      messages = []
+     # Check if the first message is a system prompt
+     if history and history[0][0].lower().startswith("system:"):
+         # Special handling if needed, or just treat as user
+         # For now, let's assume standard user/assistant alternation
+         pass  # Or handle system prompt separately if template requires
+
+     for i, (user_msg, assistant_msg) in enumerate(history):
+         # Basic user/assistant structure
+         messages.append({"role": "user", "content": user_msg})
          if assistant_msg is not None:  # Skip if None (for the latest user message)
              messages.append({"role": "assistant", "content": assistant_msg})

      return messages

+ # --- Core Generation Logic (Adapted from Dream's _sample) ---
+
+ def sample_tokens_for_vis(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
+     """
+     Simplified version of Dream's sample_tokens to get both token and confidence.
+     Returns confidence and chosen token ID.
+     """
+     # Apply temperature
+     if temperature > 0:
+         logits = logits / temperature
+
+     # Apply Top-P
+     if top_p is not None and top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+         sorted_indices_to_remove = cumulative_probs > top_p
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+         indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_remove)
+         logits = logits.masked_fill(indices_to_remove, float('-inf'))
+
+     # Apply Top-K
+     if top_k is not None and top_k > 0:
+         top_k = min(top_k, logits.size(-1))
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits = logits.masked_fill(indices_to_remove, float('-inf'))
+
+     # Calculate probabilities
+     probs = torch.softmax(logits, dim=-1)
+
+     # Sample or Argmax
+     if temperature > 0:
+         # Use torch distributions for robust sampling
+         dist = dists.Categorical(probs=probs)
+         x0 = dist.sample()
+         # Gather confidence for the sampled token
+         confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
+     else:
+         # Argmax for deterministic generation
+         confidence, x0 = torch.max(probs, dim=-1)
+
+     # --- Calculate specific confidence metrics if requested ---
+     # Note: These modify the 'confidence' variable *after* sampling x0
+     if margin_confidence:
+         if probs.shape[-1] >= 2:
+             # Ensure logits weren't completely masked, handle edge cases
+             if not torch.isinf(logits).all(dim=-1).any():
+                 # Sort probabilities to get top1 and top2
+                 sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
+                 top1_probs = sorted_probs[..., 0]
+                 top2_probs = sorted_probs[..., 1]
+                 confidence = top1_probs - top2_probs
+             else:
+                 # Fallback if all logits are -inf (shouldn't normally happen)
+                 confidence.fill_(0.0)  # Or some other indicator
+         else:
+             # Only one possible token, margin is undefined or 1? Set to top1 prob.
+             confidence, _ = torch.max(probs, dim=-1)
+
+     elif neg_entropy:
+         epsilon = 1e-9  # Slightly smaller epsilon
+         log_probs = torch.log(probs + epsilon)
+         # Negative entropy is sum(p * log(p))
+         confidence = torch.sum(probs * log_probs, dim=-1)  # Lower value (more negative) is higher confidence
+
+     return confidence, x0

  @spaces.GPU
+ @torch.no_grad()
+ def generate_response_with_visualization_dream(
+     messages, gen_length=64, steps=64,
+     constraints=None, temperature=0.2, top_p=0.95, top_k=None,  # Added top_k
+     alg="entropy", alg_temp=0.1,  # Dream specific params
+     yield_intermediate=True  # Control yielding behavior
+ ):
      """
+     Generate text with Dream model with real-time visualization.
+     Adapts logic from Dream's _sample method.

      Args:
+         messages: List of message dictionaries with 'role' and 'content'
+         gen_length: Max new tokens to generate
+         steps: Number of diffusion steps
+         constraints: Dictionary mapping positions to *token IDs*
+         temperature: Sampling temperature
+         top_p: Nucleus sampling probability
+         top_k: Top-k sampling
+         alg: Remasking strategy ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
+         alg_temp: Temperature for confidence-based remasking randomness
+         yield_intermediate: Whether to yield intermediate states for visualization

      Returns:
+         Yields visualization states or returns final state list, and final text.
      """
      if constraints is None:
+         constraints = {}  # keys are positions relative to start of response

+     # --- Prepare Input ---
+     chat_input_text = tokenizer.apply_chat_template(
+         messages, add_generation_prompt=True, tokenize=False
+     )
+     input_ids = tokenizer(chat_input_text, return_tensors="pt")['input_ids'].to(device)
+     prompt_length = input_ids.shape[1]
+     max_length = prompt_length + gen_length
+
+     # Clamp max_length if it exceeds model capacity (use config value if available)
+     model_max_len = getattr(config, 'max_position_embeddings', 2048)  # Default fallback
+     if max_length > model_max_len:
+         print(f"Warning: Requested length ({max_length}) exceeds model max ({model_max_len}). Clamping.")
+         max_length = model_max_len
+         gen_length = max_length - prompt_length
+         if gen_length <= 0:
+             print("Warning: Prompt is already at or exceeding model max length. Cannot generate.")
+             if yield_intermediate:
+                 yield [], "Error: Prompt too long."
+                 return
+             else:
+                 return [], "Error: Prompt too long."

+     # Initialize sequence 'x' with input_ids and padding with MASK_ID
+     x = torch.full((1, max_length), MASK_ID, dtype=torch.long, device=device)
+     x[:, :prompt_length] = input_ids.clone()

+     # Apply initial constraints to x (relative position -> absolute position)
      for rel_pos, token_id in constraints.items():
          abs_pos = prompt_length + rel_pos
+         if abs_pos < max_length:
+             # Ensure we don't overwrite prompt or special tokens accidentally
+             if token_id not in [MASK_ID, EOS_ID, PAD_ID]:
+                 x[:, abs_pos] = token_id
+             else:
+                 print(f"Warning: Skipped constraint for special token ID {token_id} at pos {rel_pos}")

+     # --- Visualization Setup ---
+     visualization_states = []
+     revealed_eos_pad = set()  # Track positions where EOS/PAD was shown once
+
+     def get_vis_state(current_x, old_x, step_confidences=None):
+         nonlocal revealed_eos_pad
+         state = []
+         newly_revealed_in_step = False  # Flag if any token changed from MASK
+         current_revealed_eos_pad = set()  # Track EOS/PAD revealed *in this step*
+
+         for i in range(gen_length):
+             abs_pos = prompt_length + i
+             current_token_id = current_x[0, abs_pos].item()
+             old_token_id = old_x[0, abs_pos].item()
+
+             is_eos_or_pad = (current_token_id == EOS_ID or current_token_id == PAD_ID)
+
+             # Handle EOS/PAD hiding: Show once, then hide
+             if is_eos_or_pad and abs_pos in revealed_eos_pad:
+                 state.append(("", "#FFFFFF"))  # Make it invisible (white on white/transparent)
+                 continue  # Skip rest of logic for this pos
+
+             token_str = tokenizer.decode([current_token_id], skip_special_tokens=False)  # Decode even specials initially
+
+             if current_token_id == MASK_ID:
+                 color = "#444444"  # Dark Gray for Mask
+                 token_str = MASK_TOKEN  # Display mask token string
+             elif old_token_id == MASK_ID:  # Newly revealed in this step
+                 newly_revealed_in_step = True
+                 confidence = step_confidences.get(abs_pos, 0.5)  # Get confidence if available, default 0.5
+
+                 # Color based on confidence (adjust thresholds as needed)
+                 # Note: Entropy confidence is negative, more negative = higher confidence
+                 if alg == 'entropy':
+                     # Example thresholds for negative entropy (adjust based on observation)
+                     if confidence > -1.0:  # Low confidence (high entropy)
+                         color = "#FF6666"  # Light Red
+                     elif confidence > -3.0:  # Medium confidence
+                         color = "#FFAA33"  # Orange
+                     else:  # High confidence (low entropy)
+                         color = "#66CC66"  # Light Green
+                 else:  # Standard confidence (probability or margin)
+                     if confidence < 0.3:
+                         color = "#FF6666"  # Light Red
+                     elif confidence < 0.7:
+                         color = "#FFAA33"  # Orange
+                     else:
+                         color = "#66CC66"  # Light Green
+
+                 # If it's EOS/PAD revealed now, mark for future hiding
+                 if is_eos_or_pad:
+                     current_revealed_eos_pad.add(abs_pos)
+
+             else:  # Previously revealed
+                 color = "#6699CC"  # Light Blue
+
+             # Clean up token string for display (optional)
+             # token_str = token_str.replace(" ", " ")  # Keep spaces visible
+
+             state.append((token_str, color))
+
+         # Update the global set of revealed EOS/PAD positions
+         revealed_eos_pad.update(current_revealed_eos_pad)
+
+         return state, newly_revealed_in_step
+
+     # Add initial state (all masked, constraints applied)
+     initial_vis_state, _ = get_vis_state(x, x)  # Pass x as old_x initially
+     visualization_states.append(initial_vis_state)
+     if yield_intermediate:
+         yield initial_vis_state  # Yield the starting state
+
+     # --- Diffusion Loop ---
+     timesteps = torch.linspace(1.0, 1e-3, steps + 1, device=device)  # Use epsilon from Dream's defaults if needed
+
+     # Store the state before the loop starts
+     old_x = x.clone()
+
+     for i in range(steps):
+         # --- Core Dream Step ---
+         mask_index = (x == MASK_ID)
+         if not mask_index.any():  # Stop if no masks left
+             print(f"No masks left at step {i}. Stopping generation.")
              break

+         # Prepare attention mask (full attention for Dream unless specified otherwise)
+         # Dream's modeling code handles standard causal masking internally based on position_ids
+         # For diffusion, we typically allow attending to everything (masked or not)
+         # The `model` forward pass expects a standard causal mask or None
+         # Let's use None, assuming the model handles positions correctly
+         attention_mask = None  # Or potentially create a full mask: torch.ones_like(x)
+
+         # Create position_ids (simple range for now)
+         position_ids = torch.arange(0, x.shape[1], device=device).unsqueeze(0)
+
+         # Model forward pass
+         outputs = model(input_ids=x, attention_mask=attention_mask, position_ids=position_ids)
+         logits = outputs.logits
+         # logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)  # Dream applies shift in utils, replicate if needed
+
+         # Select logits for masked positions ONLY
+         # Need to handle batch dimension (which is 1 here)
+         current_mask_indices_flat = torch.where(mask_index.flatten())[0]
+         if len(current_mask_indices_flat) == 0:
+             print(f"No mask indices found flat at step {i}. Stopping generation.")
+             break

+         # Use advanced indexing to get logits for masked positions
+         # Logits shape: [batch_size, seq_len, vocab_size]
+         # Mask_index shape: [batch_size, seq_len]
+         # We need logits corresponding to True values in mask_index
+         # Example: batch_idx = torch.where(mask_index)[0], seq_idx = torch.where(mask_index)[1]
+         # mask_logits = logits[batch_idx, seq_idx]
+         batch_indices, seq_indices = torch.where(mask_index)
+         mask_logits = logits[batch_indices, seq_indices]  # Shape: [num_masked_tokens, vocab_size]
+
+         if mask_logits.numel() == 0:  # Double check after indexing
+             print(f"No mask logits selected at step {i}. Stopping generation.")
+             break

+         t = timesteps[i]
+         s = timesteps[i + 1]
+
+         # --- Remasking Logic (Simplified from Dream's _sample) ---
+         step_confidences = {}  # Store confidences for revealed tokens in this step {abs_pos: confidence}
+
+         if alg == 'origin':
+             p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
+             # Sample for all masked positions
+             confidence, x0_masked = sample_tokens_for_vis(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
+             # Decide which ones to transfer based on random probability
+             transfer_mask = torch.rand(x0_masked.shape, device=device) < p_transfer
+             # Create a tensor of MASK_IDs, and fill in the transferred tokens
+             updates_for_masked_pos = torch.full_like(x0_masked, MASK_ID)
+             updates_for_masked_pos[transfer_mask] = x0_masked[transfer_mask]
+             # Update x at the masked positions
+             x[mask_index] = updates_for_masked_pos
+
+             # Store confidences for the *transferred* tokens for visualization
+             transferred_indices_flat = current_mask_indices_flat[transfer_mask]
+             transferred_confidences = confidence[transfer_mask]
+             for flat_idx, conf in zip(transferred_indices_flat, transferred_confidences):
+                 abs_pos = flat_idx.item()  # Convert flat index back to seq position (assuming batch=1)
+                 step_confidences[abs_pos] = conf.item()
+
+         else:  # Confidence-based algorithms ('maskgit_plus', 'topk_margin', 'entropy')
+             use_margin = (alg == 'topk_margin')
+             use_entropy = (alg == 'entropy')
+             # Sample potential replacements for ALL masked positions first
+             confidence, x0_masked = sample_tokens_for_vis(
+                 mask_logits,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 margin_confidence=use_margin,
+                 neg_entropy=use_entropy
+             )

+             num_mask_tokens = mask_index.sum().item()
+             # Calculate how many tokens to unmask/transfer in this step
+             num_transfer_tokens = int(num_mask_tokens * (1.0 - s / t)) if i < steps - 1 else num_mask_tokens
+
+             if num_transfer_tokens > 0 and confidence.numel() > 0:
+                 transfer_indices_relative = None  # Indices relative to the masked tokens
+                 if alg_temp is None or alg_temp <= 0:
+                     # Deterministic: Select top-k confidence scores among masked tokens
+                     # Ensure k is not larger than the number of masked tokens
+                     k = min(num_transfer_tokens, confidence.shape[0])
+                     if k > 0:
+                         _, transfer_indices_relative = torch.topk(confidence, k)
+                 else:
+                     # Stochastic: Sample based on confidence scores
+                     # Ensure probabilities are valid
+                     conf_probs = F.softmax(confidence / alg_temp, dim=-1)
+                     if not torch.isnan(conf_probs).any() and not torch.isinf(conf_probs).any() and conf_probs.sum() > 1e-6:
+                         # Ensure k is not larger than the number of masked tokens
+                         k = min(num_transfer_tokens, confidence.shape[0])
+                         if k > 0:
+                             transfer_indices_relative = torch.multinomial(conf_probs, num_samples=k, replacement=False)
+                     else:
+                         print(f"Warning: Invalid confidence probabilities at step {i}. Falling back to top-k.")
+                         # Fallback to deterministic if sampling fails
+                         k = min(num_transfer_tokens, confidence.shape[0])
+                         if k > 0:
+                             _, transfer_indices_relative = torch.topk(confidence, k)
+
+                 if transfer_indices_relative is not None and transfer_indices_relative.numel() > 0:
+                     # Create updates, initially all MASK_ID
+                     updates_for_masked_pos = torch.full_like(x0_masked, MASK_ID)
+                     # Place the selected sampled tokens into the updates tensor
+                     updates_for_masked_pos[transfer_indices_relative] = x0_masked[transfer_indices_relative]
+                     # Update x at the original masked positions
+                     x[mask_index] = updates_for_masked_pos
+
+                     # Store confidences for the *transferred* tokens for visualization
+                     selected_confidences = confidence[transfer_indices_relative]
+                     # Get the absolute positions corresponding to these relative indices
+                     original_indices_flat = current_mask_indices_flat[transfer_indices_relative]
+                     for flat_idx, conf in zip(original_indices_flat, selected_confidences):
+                         abs_pos = flat_idx.item()
+                         step_confidences[abs_pos] = conf.item()

+                 else:
+                     # No tokens were selected to transfer, x remains unchanged for masked parts
+                     pass  # x[mask_index] remains MASK_ID essentially
+
+             else:
+                 # If num_transfer_tokens is 0, x remains unchanged for masked parts
+                 pass
+
+         # --- Apply Constraints and Finalize Step ---
+         # Ensure constraints are always maintained AFTER updates
+         for rel_pos, token_id in constraints.items():
+             abs_pos = prompt_length + rel_pos
+             if abs_pos < max_length:
+                 # Check if the position was masked before applying constraint
+                 # if mask_index[0, abs_pos]:  # Only apply if it *was* a mask, maybe? Optional.
+                 x[:, abs_pos] = token_id
+
+         # --- Visualization Update ---
+         current_vis_state, newly_revealed = get_vis_state(x, old_x, step_confidences)
+
+         # Only add/yield if something actually changed or if it's the last step
+         if newly_revealed or i == steps - 1:
+             visualization_states.append(current_vis_state)
+             if yield_intermediate:
+                 yield current_vis_state
+
+         # Update old_x for the next iteration
+         old_x = x.clone()
+
+     # --- Final Output ---
+     response_tokens = x[0, prompt_length:]
+     # Decode, cleaning up potential special tokens unless they are intended
+     final_text = tokenizer.decode(response_tokens,
+                                   skip_special_tokens=True,  # Skip things like <|mask|> in final output
+                                   clean_up_tokenization_spaces=True)
+
+     # If not yielding intermediates, return the full list now
+     if not yield_intermediate:
+         return visualization_states, final_text
+     else:
+         # If yielding intermediates, we still need a way to signal completion
+         # and return the final text. Gradio's yield typically handles this if
+         # the last yielded value is the final one. We'll return the final text
+         # separately after the loop finishes in the calling function.
+         # The loop yields states, the calling function returns the final text.
+         pass  # Final text is handled outside the generator function

+ # --- Gradio UI ---
  css = '''
  .category-legend{display:none}
  button{height: 60px}
+ .token-revealed { transition: background-color 0.5s ease; } /* Optional: Add transition effect */
+ .token-masked { background-color: #444444; color: white; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
+ .token-new-high { background-color: #66CC66; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
+ .token-new-mid { background-color: #FFAA33; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
+ .token-new-low { background-color: #FF6666; color: black; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
+ .token-old { background-color: #6699CC; color: white; padding: 1px 2px; margin: 1px; border-radius: 3px; display: inline-block; }
+ .token-hidden { display: none; } /* Hide EOS/PAD after first reveal */
  '''
+
  def create_chatbot_demo():
+     with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
          gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
+         gr.Markdown(
+             "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
+             "[[Blog](https://hkunlp.github.io/blog/2025/dream/)] "
+             "[[Original LLaDA Demo Inspiration](https://huggingface.co/spaces/GSAI-ML/LLaDA-demo)]"
+         )
+         gr.Markdown(
+             "**Note:** This demo visualizes the diffusion process in real-time. "
+             "Tokens start masked (<font color='#444444'>[MASK]</font>) and are revealed step-by-step. "
+             "Colors indicate confidence: <font color='#66CC66'>High</font>, "
+             "<font color='#FFAA33'>Medium</font>, <font color='#FF6666'>Low</font>. "
+             "Previously revealed tokens are <font color='#6699CC'>blue</font>. "
+             f"EOS/PAD tokens ({tokenizer.decode([EOS_ID])}) are hidden after appearing once."
+         )

          # STATE MANAGEMENT
          chat_history = gr.State([])
+         current_response_text = gr.State("")  # Store the final text separately

          # UI COMPONENTS
          with gr.Row():
              with gr.Column(scale=3):
+                 chatbot_ui = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)

                  # Message input
                  with gr.Group():
                      user_input = gr.Textbox(
                          label="Your Message",
                          placeholder="Type your message here...",
+                         scale=7,
+                         show_label=False
                      )
                      send_btn = gr.Button("Send", scale=1)

                  constraints_input = gr.Textbox(
+                     label="Word Constraints (Relative Position)",
+                     info="Place words at specific 0-indexed positions in the *response*. Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
+                     placeholder="0:Hello, 10:world",
                      value=""
                  )
              with gr.Column(scale=2):
+                 # Use HighlightedText with specific classes for better styling control
                  output_vis = gr.HighlightedText(
+                     label="Denoising Process Visualization",
+                     # Show legend mapping colors to confidence might be useful if classes aren't self-explanatory
+                     # For now, using the description markdown above.
+                     show_legend=False,
+                     # Use custom classes defined in CSS
+                     # color_map={  # This might not work directly with dynamic classes, CSS is better
+                     #     "MASK": "#444444", "NEW_H": "#66CC66", "NEW_M": "#FFAA33",
+                     #     "NEW_L": "#FF6666", "OLD": "#6699CC", "HIDDEN": "#FFFFFF"
+                     # }
+                     combine_adjacent=False,  # Keep tokens separate
+                     height=550,  # Adjust height as needed
                  )

+
          # Advanced generation settings
          with gr.Accordion("Generation Settings", open=False):
+             with gr.Row():
+                 gen_length = gr.Slider(
+                     minimum=16, maximum=512, value=64, step=8,  # Increased max length
+                     label="Max New Tokens"
+                 )
+                 steps = gr.Slider(
+                     minimum=8, maximum=512, value=64, step=4,  # Allow more steps
+                     label="Diffusion Steps"
+                 )
+             with gr.Row():
+                 temperature = gr.Slider(
+                     minimum=0.0, maximum=1.5, value=0.2, step=0.05,  # Wider range for temp
+                     label="Temperature"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.0, maximum=1.0, value=0.95, step=0.05,
+                     label="Top-P (Nucleus Sampling)"
+                 )
+                 # top_k = gr.Slider(
+                 #     minimum=0, maximum=200, value=0, step=5,  # Allow Top-K=0 (disabled)
+                 #     label="Top-K (0 to disable)"
+                 # )
+             with gr.Row():
+                 # Dream specific algorithm choice
+                 alg_strategy = gr.Radio(
+                     choices=["entropy", "maskgit_plus", "topk_margin", "origin"],
+                     value="entropy",
+                     label="Algorithm (`alg`)"
+                 )
+                 alg_temp = gr.Slider(
+                     minimum=0.0, maximum=1.0, value=0.1, step=0.01,
+                     label="Algorithm Temp (`alg_temp`)"
+                 )
+             with gr.Row():
+                 visualization_delay = gr.Slider(
+                     minimum=0.0, maximum=0.5, value=0.03, step=0.01,  # Faster default delay
+                     label="Visualization Delay (seconds)"
+                 )

          # Clear button
          clear_btn = gr.Button("Clear Conversation")

+         # --- Helper Functions for UI ---
          def add_message_to_history(history, message, response):
              """Add a message pair to the history state"""
+             history.append([message, response])
+             return history

+         def user_message_action(message, history):
+             """Handles user sending a message: updates history, clears input."""
              if not message or message.strip() == "":
+                 return history, history, "", [], ""  # Return empty vis, empty response
+
+             # Add user message with None response placeholder
+             history = add_message_to_history(history, message, None)
+             # Return updated history for chatbot display, clear input box
+             return history, history, "", [], ""  # Clear vis and response text state too
+
+         def bot_response_generator(
+             history, gen_length, steps, constraints_str, delay,
+             temperature, top_p,  # top_k,
+             alg, alg_temp
+         ):
+             """Generates bot response and yields visualization states."""
+             if not history or history[-1][1] is not None:  # Check if last message already has a response
+                 print("History empty or last message already processed.")
+                 yield history, [], ""  # Yield empty state if no work to do
                  return

              last_user_message = history[-1][0]
+             print(f"Generating response for: {last_user_message}")

              try:
+                 # Format history for the model (excluding the last None response)
+                 messages = format_chat_history(history[:-1])
+                 # Add the current user message
+                 messages.append({"role": "user", "content": last_user_message})
+
+                 # Parse constraints into token IDs
                  parsed_constraints = parse_constraints(constraints_str)
+                 print(f"Parsed constraints: {parsed_constraints}")

+                 final_text = ""  # Initialize final_text
+
+                 # Use the generator function
+                 response_generator = generate_response_with_visualization_dream(
+                     messages,
+                     gen_length=gen_length,
                      steps=steps,
                      constraints=parsed_constraints,
+                     temperature=temperature,
+                     top_p=top_p if top_p > 0 else None,  # Pass None if 0
+                     top_k=None,  # Pass None if 0 top_k if top_k > 0 else None,
                      alg=alg,
+                     alg_temp=alg_temp if alg_temp > 0 else None,  # Pass None if 0
+                     yield_intermediate=True
+                 )
+
+                 # Iterate through the yielded visualization states
+                 last_state = None
+                 for vis_state in response_generator:
+                     last_state = vis_state
+                     # Update chatbot with placeholder during generation
+                     history[-1][1] = "..."  # Indicate thinking
+                     yield history, vis_state, "..."  # Yield history, current vis state, placeholder text
+                     if delay > 0:
+                         time.sleep(delay)
+
+                 # --- Generation Finished ---
+                 # Extract final text (needs to be done *after* the generator is exhausted)
+                 # Re-run the generation without yielding intermediates to get the final text reliably
+                 # (Or modify the generator to return it, but this is simpler for now)
+                 # TODO: Optimize this - maybe the generator could return the final text at the end?
+
+                 print("Re-generating final text (non-streaming)...")
+                 final_vis_states, final_text = generate_response_with_visualization_dream(
+                     messages, gen_length, steps, parsed_constraints, temperature,
+                     top_p if top_p > 0 else None, None,  # top_k if top_k > 0 else None,
+                     alg, alg_temp if alg_temp > 0 else None,
+                     yield_intermediate=False  # Get final result only
+                 )
+                 print(f"Final Text: {final_text}")
+
+                 # Update the history with the actual final response
+                 history[-1][1] = final_text.strip() if final_text else "[No response]"
+
+                 # Yield the final state one last time
+                 yield history, final_vis_states[-1] if final_vis_states else [], final_text.strip()
+
              except Exception as e:
                  import traceback
+                 print(f"Error during generation: {e}")
                  traceback.print_exc()
+                 error_msg = f"Error: {str(e)}"
+                 history[-1][1] = error_msg  # Show error in chat
+                 # Show error in visualization (red text)
+                 error_vis = [(error_msg, "#FF0000")]
+                 yield history, error_vis, error_msg
+
+         def clear_conversation_action():
+             """Clears chat history, visualization, and response text."""
+             return [], [], "", []  # History, Chatbot UI, Response Text, Visualization
+
+         # --- Event Wiring ---
+
+         # 1. User Submits Message (Textbox Enter or Button Click)
+         submit_triggers = [user_input.submit, send_btn.click]
+         for trigger in submit_triggers:
+             trigger.then(
+                 fn=user_message_action,
+                 inputs=[user_input, chat_history],
+                 outputs=[chat_history, chatbot_ui, user_input, output_vis, current_response_text],  # Update history state, chatbot UI, clear input, clear vis, clear response state
+                 queue=True  # Enable queue for handling multiple users
+             ).then(
+                 # 2. Trigger Bot Response Generation (Generator Function)
+                 fn=bot_response_generator,
+                 inputs=[
+                     chat_history, gen_length, steps, constraints_input, visualization_delay,
+                     temperature, top_p,  # top_k,
+                     alg_strategy, alg_temp
+                 ],
+                 outputs=[chatbot_ui, output_vis, current_response_text]  # Stream updates to chatbot, visualization, and store final text
+             )

+         # Clear Button Action
          clear_btn.click(
+             fn=clear_conversation_action,
              inputs=[],
+             outputs=[chat_history, chatbot_ui, current_response_text, output_vis],
+             queue=False  # No need to queue clear action
          )

      return demo

+ # --- Launch ---
  if __name__ == "__main__":
+     # Make sure the necessary Dream model files (modeling_dream.py, configuration_dream.py etc.)
+     # are in the same directory or accessible in the Python path.
+     # Also ensure 'generation_utils.py' is available if needed by the model loading/config.
+     # Check if 'modeling_dream.py' exists before launching
+     import os
+     if not os.path.exists("modeling_dream.py") or not os.path.exists("configuration_dream.py"):
+         print("\nERROR: Could not find 'modeling_dream.py' and/or 'configuration_dream.py'.")
+         print("Please make sure these files (from the 'dream_model.txt' source) are in the same directory as this script.")
+         print("You might need to extract them from the provided text file.")
+         # exit()  # Optional: stop execution if files are missing
+
+     print("Creating Gradio Demo...")
      demo = create_chatbot_demo()
+     print("Launching Gradio Demo...")
+     # Use queueing for better user experience with potentially long generation times
+     demo.queue().launch(share=True, debug=True)  # Enable debug for more detailed logs