multimodalart HF Staff committed on
Commit f4ff30a · verified · 1 Parent(s): cfc13ea

Update app.py

Files changed (1)
  1. app.py +329 -389
app.py CHANGED
@@ -2,57 +2,62 @@
2
  import torch
3
  import numpy as np
4
  import gradio as gr
5
- import spaces
6
  import torch.nn.functional as F
7
  from transformers import AutoTokenizer, AutoModel, AutoConfig
8
  import time
9
- import copy
 
 
 
 
 
 
 
10
 
11
  # Determine device
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
  print(f"Using device: {device}")
14
 
15
- # --- Model and Tokenizer Loading ---
16
- model_path = "Dream-org/Dream-v0-Instruct-7B"
17
-
18
- print(f"Loading tokenizer from {model_path}...")
19
- # Load configuration first to get special token IDs
20
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
21
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
22
-
23
- print(f"Loading model from {model_path}...")
24
  model = AutoModel.from_pretrained(
25
  model_path,
26
- torch_dtype=torch.bfloat16,
27
  trust_remote_code=True
28
- ).to(device).eval()
29
- print("Model loaded successfully.")
 
30
 
31
- # --- Constants from Dream Model ---
32
- # Get IDs directly from config or tokenizer if available
33
  MASK_TOKEN = tokenizer.mask_token
34
- MASK_ID = config.mask_token_id if hasattr(config, 'mask_token_id') else tokenizer.mask_token_id
35
- EOS_ID = config.eos_token_id if hasattr(config, 'eos_token_id') else tokenizer.eos_token_id
36
- PAD_ID = config.pad_token_id if hasattr(config, 'pad_token_id') else tokenizer.pad_token_id # Often same as EOS
37
-
38
- print(f"MASK_TOKEN: '{MASK_TOKEN}', MASK_ID: {MASK_ID}")
39
- print(f"EOS_ID: {EOS_ID}, PAD_ID: {PAD_ID}")
40
- if MASK_ID is None:
41
- raise ValueError("Could not determine MASK_ID from model config or tokenizer.")
42
- if EOS_ID is None:
43
- raise ValueError("Could not determine EOS_ID from model config or tokenizer.")
44
- if PAD_ID is None:
45
- raise ValueError("Could not determine PAD_ID from model config or tokenizer.")
46
-
47
 
48
  # --- Helper Functions ---
49
 
50
- def parse_constraints(constraints_text, tokenizer):
51
- """Parse constraints in format: 'position:word, position:word, ...'"""
 
 
 
 
52
  constraints = {}
53
- processed_constraints_tokens = {}
54
  if not constraints_text:
55
- return constraints, processed_constraints_tokens
56
 
57
  parts = constraints_text.split(',')
58
  for part in parts:
@@ -60,270 +65,292 @@ def parse_constraints(constraints_text, tokenizer):
60
  continue
61
  pos_str, word = part.split(':', 1)
62
  try:
 
63
  pos = int(pos_str.strip())
64
  word = word.strip()
65
- if word and pos >= 0:
66
- # Store original word constraint for display/debugging if needed
67
- constraints[pos] = word
68
- # Tokenize the word (add space for consistency if not BOS)
69
- # Note: Dream tokenizer might handle spaces differently, adjust if needed
70
- prefix = " " if pos > 0 else ""
71
- tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
72
- for i, token_id in enumerate(tokens):
73
- # Prevent overwriting multi-token constraints partially
74
- if pos + i not in processed_constraints_tokens:
75
- processed_constraints_tokens[pos + i] = token_id
76
  except ValueError:
77
- continue
78
  except Exception as e:
79
- print(f"Error tokenizing constraint word '{word}': {e}")
80
- continue
 
 
81
 
82
- # Sort by position for consistent application
83
- processed_constraints_tokens = dict(sorted(processed_constraints_tokens.items()))
84
- print(f"Parsed Constraints (Word): {constraints}")
85
- print(f"Parsed Constraints (Tokens): {processed_constraints_tokens}")
86
- return constraints, processed_constraints_tokens
87
 
88
- def format_chat_history(history):
89
  """
90
- Format chat history for the Dream model using its chat template convention.
91
 
92
  Args:
93
- history: List of [user_message, assistant_message] pairs
 
94
 
95
  Returns:
96
- Formatted list of message dictionaries for the model
97
  """
98
  messages = []
99
- # Add system prompt if not present (standard practice)
100
- if not history or history[0][0] is None or history[0][0].lower() != "system":
101
- messages.append({"role": "system", "content": "You are a helpful assistant."})
 
102
 
103
  for user_msg, assistant_msg in history:
104
- if user_msg is not None: # Handle potential system message case
105
  messages.append({"role": "user", "content": user_msg})
106
- if assistant_msg: # Skip if None (for the latest user message)
 
107
  messages.append({"role": "assistant", "content": assistant_msg})
108
 
109
  return messages
110
 
111
- # --- Core Generation Logic with Visualization Hook ---
112
-
113
- @spaces.GPU
114
- def generate_response_with_visualization(
115
- messages, # List of message dictionaries
116
- gen_length=64,
117
- steps=64,
118
- constraints_text="", # Raw constraint text
119
- temperature=0.2,
120
- top_p=0.95,
121
- top_k=None, # Added for Dream
122
- alg="entropy", # Changed from remasking
123
- alg_temp=0.0, # Added for Dream
124
- visualization_delay=0.05,
125
- tokenizer=tokenizer,
126
- model=model,
127
- device=device,
128
- MASK_ID=MASK_ID,
129
- EOS_ID=EOS_ID,
130
- PAD_ID=PAD_ID
131
- ):
132
  """
133
- Generate text with Dream model with real-time visualization using a hook.
 
134
  """
135
- visualization_states = []
136
- final_text = ""
137
- # Use a list to hold previous_x, allowing nonlocal modification
138
- # Initialize with None, it will be set after the first hook call
139
- shared_state = {'previous_x': None}
140
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
142
  try:
143
- # --- 1. Prepare Inputs ---
144
- _, parsed_constraints_tokens = parse_constraints(constraints_text, tokenizer)
145
-
146
- # Apply chat template
147
- # Important: Keep tokenize=False initially to get prompt length correctly
148
- # The template adds roles and special tokens like <|im_start|> etc.
149
- chat_input_text = tokenizer.apply_chat_template(
150
- messages,
151
- add_generation_prompt=True, # Adds the prompt for the assistant's turn
152
- tokenize=False
153
  )
154
-
155
- # Tokenize the full templated chat string
156
- inputs = tokenizer(chat_input_text, return_tensors="pt", return_dict=True)
157
  input_ids = inputs.input_ids.to(device)
158
- attention_mask = inputs.attention_mask.to(device) # Use mask from tokenizer
159
-
160
  prompt_length = input_ids.shape[1]
161
- total_length = prompt_length + gen_length
162
-
163
- # --- 2. Initialize Generation Sequence ---
164
- # Start with the prompt, pad the rest with MASK_ID
165
- x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
166
- x[:, :prompt_length] = input_ids.clone()
167
- attention_mask = F.pad(attention_mask, (0, gen_length), value=1) # Extend attention mask
168
-
169
- # Apply initial constraints to the masked sequence `x`
170
- for pos, token_id in parsed_constraints_tokens.items():
171
- absolute_pos = prompt_length + pos
172
- if absolute_pos < total_length:
173
- print(f"Applying initial constraint at pos {absolute_pos}: token {token_id}")
174
- x[:, absolute_pos] = token_id
175
-
176
- # Store initial state (prompt + all masked) for visualization
177
- initial_state_vis = []
178
- # Add prompt tokens (optional visualization, could be grayed out or skipped)
179
- # for i in range(prompt_length):
180
- # token_str = tokenizer.decode([x[0, i].item()], skip_special_tokens=True)
181
- # initial_state_vis.append((token_str if token_str else " ", "#AAAAAA")) # Gray for prompt
182
-
183
- # Add masked tokens for the generation part
184
- for _ in range(gen_length):
185
- initial_state_vis.append((MASK_TOKEN, "#444444")) # Dark gray for masks
186
- visualization_states.append(initial_state_vis)
187
- shared_state['previous_x'] = x.clone() # Initialize previous_x
188
-
189
-
190
- # --- 3. Define the Visualization Hook ---
191
- def generation_tokens_hook_func(step, current_x_hook, logits):
192
- # nonlocal previous_x # Allow modification of the outer scope variable
193
- current_x_hook = current_x_hook.clone() # Work on a copy
194
-
195
- # --- Apply constraints within the hook ---
196
- # This ensures constraints are respected even if the model tries to overwrite them
197
- for pos, token_id in parsed_constraints_tokens.items():
198
- absolute_pos = prompt_length + pos
199
- if absolute_pos < total_length:
200
- current_x_hook[:, absolute_pos] = token_id
201
- # --- End Constraint Application ---
202
-
203
- if shared_state['previous_x'] is None: # First call
204
- shared_state['previous_x'] = current_x_hook.clone()
205
- return current_x_hook # Must return the (potentially modified) sequence
206
-
207
- # Generate visualization state for this step
208
- current_state_vis = []
209
- prev_x_step = shared_state['previous_x']
210
-
211
- for i in range(gen_length):
212
- pos = prompt_length + i # Absolute position in the sequence
213
- current_token_id = current_x_hook[0, pos].item()
214
- prev_token_id = prev_x_step[0, pos].item()
215
-
216
- # Decode token, handling special tokens we want to hide
217
- token_str = ""
218
- color = "#444444" # Default: Dark Gray (Mask)
219
- token_str_raw = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for ID check
220
-
221
- if current_token_id == MASK_ID:
222
- token_str = MASK_TOKEN
223
- color = "#444444" # Dark gray
224
- elif current_token_id == EOS_ID or current_token_id == PAD_ID:
225
- token_str = "" # Hide EOS/PAD visually
226
- color = "#DDDDDD" # Use a light gray or make transparent if possible
227
- else:
228
- # Decode without special tokens for display if it's not MASK/EOS/PAD
229
- token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
230
- if not token_str: token_str = token_str_raw # Fallback if strip removes everything (e.g., space)
231
-
232
- if prev_token_id == MASK_ID:
233
- # Newly revealed in this step
234
- color = "#66CC66" # Light green (Simplified from confidence levels)
235
- else:
236
- # Previously revealed
237
- color = "#6699CC" # Light blue
238
-
239
- current_state_vis.append((token_str if token_str else " ", color)) # Ensure non-empty tuple element
240
-
241
- visualization_states.append(current_state_vis)
242
- shared_state['previous_x'] = current_x_hook.clone() # Update previous_x for the next step
243
 
244
- return current_x_hook # Return the sequence (constraints applied)
 
 
 
 
 
245
 
246
- # --- 4. Run Diffusion Generation ---
247
- print("Starting diffusion generation...")
248
- start_time = time.time()
249
  output = model.diffusion_generate(
250
- input_ids=x[:, :prompt_length], # Pass only the initial prompt to diffusion_generate
251
- # as it handles the masking internally based on MASK_ID
252
- attention_mask=attention_mask, # Provide the full attention mask
253
  max_new_tokens=gen_length,
254
  output_history=False, # We capture history via the hook
255
  return_dict_in_generate=True,
256
  steps=steps,
257
  temperature=temperature,
258
- top_p=top_p,
259
- top_k=top_k,
260
  alg=alg,
261
- alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence-based
262
- # Pass the hook function
263
- generation_tokens_hook_func=generation_tokens_hook_func,
264
- # Ensure the initial masked sequence `x` is used correctly if needed by internal logic
265
- # Depending on the exact implementation of diffusion_generate, passing x directly might be needed
266
- # Check Dream's generation_utils if issues arise. For now, assume it uses input_ids + max_new_tokens
267
  )
268
  end_time = time.time()
269
- print(f"Diffusion generation finished in {end_time - start_time:.2f} seconds.")
270
 
271
- # --- 5. Process Final Output ---
272
- # The hook has already built visualization_states
273
  final_sequence = output.sequences[0]
274
-
275
- # Decode the generated part, skipping special tokens for the final text output
276
  response_tokens = final_sequence[prompt_length:]
277
- # Filter out PAD tokens before final decode, keep EOS if needed conceptually, but skip for clean text
278
- response_tokens_cleaned = [tok for tok in response_tokens if tok != PAD_ID] # Keep EOS initially if needed
279
 
280
- final_text = tokenizer.decode(
281
- response_tokens_cleaned,
282
- skip_special_tokens=True, # Skip EOS, BOS, etc.
283
- clean_up_tokenization_spaces=True # Recommended for cleaner output
 
284
  ).strip()
285
 
286
- # Ensure the last state in visualization matches the final text (debug check)
287
- # print(f"Last Vis State Tokens: {''.join([t[0] for t in visualization_states[-1]]).strip()}")
288
- # print(f"Final Decoded Text: {final_text}")
289
 
290
  except Exception as e:
291
- print(f"Error during generation: {e}")
292
  import traceback
293
  traceback.print_exc()
294
- # Add error message to visualization
295
- error_msg = f"Error: {str(e)}"
296
- visualization_states.append([(error_msg, "red")])
297
- final_text = error_msg # Display error in the chatbot too
 
 
298
 
299
- # Make sure at least the initial state is present
300
- if not visualization_states:
301
- visualization_states.append([("Error: No states generated.", "red")])
 
 
 
 
 
 
302
 
 
 
303
 
304
- return visualization_states, final_text
 
305
 
306
- # --- Gradio UI Definition ---
 
307
 
 
 
 
 
308
  css = '''
309
  .category-legend{display:none}
310
- button{height: 60px}
311
- .token-text { white-space: pre; } /* Preserve spaces in tokens */
312
- footer { display: none !important; visibility: hidden !important; }
313
  '''
314
  def create_chatbot_demo():
315
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
316
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
317
  gr.Markdown(
318
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
319
- "[[Blog Post](https://hkunlp.github.io/blog/2025/dream/)] "
320
- "(Note: Visualization shows token reveal steps, colors indicate status: Gray=Masked, Green=Newly Revealed, Blue=Previously Revealed)"
321
  )
322
 
323
  # STATE MANAGEMENT
324
  chat_history = gr.State([])
325
- # Store constraints parsed into token IDs
326
- parsed_constraints_state = gr.State({})
327
 
328
  # UI COMPONENTS
329
  with gr.Row():
@@ -331,8 +358,9 @@ def create_chatbot_demo():
331
  chatbot_ui = gr.Chatbot(
332
  label="Conversation",
333
  height=500,
334
- bubble_full_width=False # Makes bubbles wrap content
335
- )
 
336
 
337
  # Message input
338
  with gr.Group():
@@ -340,217 +368,129 @@ def create_chatbot_demo():
340
  user_input = gr.Textbox(
341
  label="Your Message",
342
  placeholder="Type your message here...",
 
 
343
  show_label=False,
344
- scale=7
345
  )
346
- send_btn = gr.Button("Send", scale=1)
 
347
 
348
  constraints_input = gr.Textbox(
349
- label="Word Constraints (Experimental)",
350
- info="Place specific words at positions (0-indexed). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'. Multi-token words supported.",
351
- placeholder="0:The, 10:story",
352
  value=""
353
  )
354
  with gr.Column(scale=2):
355
  output_vis = gr.HighlightedText(
356
  label="Denoising Process Visualization",
357
  combine_adjacent=False,
358
- show_legend=False, # Legend not very informative here
 
359
  )
360
 
361
  # Advanced generation settings
362
  with gr.Accordion("Generation Settings", open=False):
363
  with gr.Row():
364
  gen_length = gr.Slider(
365
- minimum=16, maximum=512, value=128, step=8,
366
  label="Max New Tokens"
367
  )
368
  steps = gr.Slider(
369
- minimum=8, maximum=512, value=128, step=4,
370
- label="Denoising Steps"
371
  )
372
  with gr.Row():
373
  temperature = gr.Slider(
374
- minimum=0.0, maximum=1.0, value=0.2, step=0.05,
375
  label="Temperature"
376
  )
 
 
 
 
 
 
377
  top_p = gr.Slider(
378
- minimum=0.1, maximum=1.0, value=0.95, step=0.05,
379
- label="Top-P"
380
  )
381
  top_k = gr.Slider(
382
  minimum=0, maximum=200, value=0, step=5,
383
  label="Top-K (0=disabled)"
384
  )
 
385
  with gr.Row():
386
- alg = gr.Radio(
387
  choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
388
- value='entropy',
389
- label="Sampling Algorithm (`alg`)"
390
- )
391
- alg_temp = gr.Slider(
392
- minimum=0.0, maximum=1.0, value=0.0, step=0.05,
393
- label="Algorithm Temp (`alg_temp`, adds randomness to confidence-based `alg`)"
394
- )
395
 
396
  with gr.Row():
397
  visualization_delay = gr.Slider(
398
- minimum=0.0, maximum=0.5, value=0.02, step=0.01,
399
  label="Visualization Delay (seconds)"
400
  )
401
 
402
  # Clear button
403
  clear_btn = gr.Button("Clear Conversation")
404
 
405
- # --- Event Handlers ---
406
- def add_message(history, message, response):
407
- """Add a message pair to the history and return the updated history"""
408
- # Ensure history is a list
409
- if not isinstance(history, list):
410
- history = []
411
- history.append([message, response])
412
- return history
413
-
414
- def user_message_submitted(message, history):
415
- """Process a submitted user message"""
416
- if not message.strip():
417
- return history, history, "", [] # No change if empty
418
-
419
- # Add user message (response is None for now)
420
- history = add_message(history, message, None)
421
-
422
- # Return updated history for display, clear input box
423
- return history, history, "", [] # history, chatbot_ui, user_input, output_vis
424
-
425
-
426
- def bot_response_stream(
427
- history, # Current chat history (list of lists)
428
- gen_length, steps, constraints, # Generation settings
429
- temperature, top_p, top_k, alg, alg_temp, # Sampling settings
430
- delay # Visualization delay
431
- ):
432
- """Generate bot response and stream visualization states"""
433
- if not history or history[-1][1] is not None: # Check if history is present and last response isn't already set
434
- print("Skipping bot response generation: No new user message.")
435
- # Yield empty state if needed to prevent errors downstream
436
- # Ensure history is returned correctly if nothing happens
437
- yield history, [], "Internal Error: No user message found."
438
- return
439
-
440
- # Format messages for the model
441
- # Exclude the last entry as it only contains the user message
442
- messages_for_model = format_chat_history(history) # Already includes system prompt
443
-
444
- print("\n--- Generating Bot Response ---")
445
- print(f"History: {history}")
446
- print(f"Messages for model: {messages_for_model}")
447
- print(f"Constraints text: '{constraints}'")
448
- print(f"Gen length: {gen_length}, Steps: {steps}, Temp: {temperature}, Top-P: {top_p}, Top-K: {top_k}, Alg: {alg}, Alg Temp: {alg_temp}")
449
-
450
- # Call the generation function
451
- vis_states, response_text = generate_response_with_visualization(
452
- messages_for_model,
453
- gen_length=gen_length,
454
- steps=steps,
455
- constraints_text=constraints,
456
- temperature=temperature,
457
- top_p=top_p if top_p < 1.0 else None, # None disables top-p
458
- top_k=top_k if top_k > 0 else None, # None disables top-k
459
- alg=alg,
460
- alg_temp=alg_temp,
461
- visualization_delay=delay,
462
- # Pass other necessary args like tokenizer, model if not global
463
- )
464
-
465
- print(f"Generated response text: '{response_text}'")
466
- print(f"Number of visualization states: {len(vis_states)}")
467
-
468
-
469
- # Update the history with the final response
470
- # Make sure history is mutable if needed or reassign
471
- if history:
472
- history[-1][1] = response_text
473
- else:
474
- print("Warning: History was empty when trying to update response.")
475
-
476
-
477
- # Stream the visualization states
478
- if not vis_states:
479
- print("Warning: No visualization states were generated.")
480
- # Yield something to prevent downstream errors
481
- yield history, [("Error: No visualization.", "red")], response_text
482
- return
483
 
484
- # Yield initial state immediately if desired, or just start loop
485
- # yield history, vis_states[0], response_text
486
 
487
- for state in vis_states:
488
- yield history, state, response_text # Yield updated history, current vis state, final text
489
- time.sleep(delay) # Pause between steps
 
 
490
 
491
- # Final yield to ensure the last state is displayed
492
- yield history, vis_states[-1], response_text
 
 
493
 
494
 
495
  def clear_conversation():
496
- """Clear the conversation history and visualization"""
497
- return [], [], "", [] # history, chatbot, user_input, output_vis
498
-
499
- # --- Event Wiring ---
 
 
 
500
 
501
- # Clear button
502
  clear_btn.click(
503
- fn=clear_conversation,
504
  inputs=[],
505
  outputs=[chat_history, chatbot_ui, user_input, output_vis]
506
  )
507
 
508
- # User message submission flow (2-step using .then)
509
- # 1. User submits message -> Update history and chatbot UI immediately
510
- submit_action = user_input.submit(
511
- fn=user_message_submitted,
512
- inputs=[user_input, chat_history],
513
- outputs=[chat_history, chatbot_ui, user_input, output_vis] # Update chatbot, clear input
514
- )
515
-
516
- # Connect send button to the same function
517
- send_action = send_btn.click(
518
- fn=user_message_submitted,
519
- inputs=[user_input, chat_history],
520
- outputs=[chat_history, chatbot_ui, user_input, output_vis]
521
- )
522
-
523
- # 2. After UI update -> Trigger bot response generation and streaming
524
- # Use the updated chat_history from the first step
525
- submit_action.then(
526
- fn=bot_response_stream,
527
- inputs=[
528
- chat_history, gen_length, steps, constraints_input,
529
- temperature, top_p, top_k, alg, alg_temp,
530
- visualization_delay
531
- ],
532
- outputs=[chatbot_ui, output_vis, user_input] # Update chatbot, visualization. Keep user_input as output to potentially display final text/error? (Check Gradio docs for Textbox output binding on yield) Let's remove user_input from outputs here.
533
- )
534
-
535
- send_action.then(
536
- fn=bot_response_stream,
537
- inputs=[
538
- chat_history, gen_length, steps, constraints_input,
539
- temperature, top_p, top_k, alg, alg_temp,
540
- visualization_delay
541
- ],
542
- outputs=[chatbot_ui, output_vis] # Update chatbot and visualization
543
- )
544
-
545
- # Clear input after send/submit (already handled in user_message_submitted)
546
- # submit_action.then(lambda: "", outputs=user_input)
547
- # send_action.then(lambda: "", outputs=user_input)
548
-
549
-
550
  return demo
551
 
552
- # --- Launch the Gradio App ---
553
  if __name__ == "__main__":
554
  demo = create_chatbot_demo()
555
- # Using queue for streaming and handling multiple users
556
- demo.queue().launch(debug=True, share=True)
 
2
  import torch
3
  import numpy as np
4
  import gradio as gr
5
+ import spaces # Ensure spaces is installed if needed for GPU decorator
6
  import torch.nn.functional as F
7
  from transformers import AutoTokenizer, AutoModel, AutoConfig
8
  import time
9
+ import re
10
+ from typing import List, Dict, Tuple, Optional
11
+
12
+ # Load model configuration to get special token IDs
13
+ config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
14
+ # Use AutoModel for the base model loading, relying on trust_remote_code=True
15
+ # for the custom DreamModel class and generation mixin.
16
+ model_path = "Dream-org/Dream-v0-Instruct-7B"
17
 
18
  # Determine device
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  print(f"Using device: {device}")
21
 
22
+ # Load model and tokenizer
23
+ print("Loading tokenizer...")
 
 
 
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
25
+ print("Loading model...")
26
+ # Ensure torch_dtype is set appropriately for your hardware if needed
27
  model = AutoModel.from_pretrained(
28
  model_path,
29
+ torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32, # Use bfloat16 only on CUDA
30
  trust_remote_code=True
31
+ )
32
+ model = model.to(device).eval()
33
+ print("Model loaded.")
34
 
35
+ # Constants from Dream's config/tokenizer
36
+ # Use attributes from loaded config/tokenizer objects
37
  MASK_TOKEN = tokenizer.mask_token
38
+ MASK_ID = config.mask_token_id
39
+ PAD_ID = config.pad_token_id
40
+ EOS_ID = config.eos_token_id
41
+ # Make sure EOS_ID and PAD_ID are handled correctly; Dream uses the same ID for both
42
+ SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
43
+ # Add other special tokens defined in tokenizer_config.json if needed for hiding
44
+ # Get IDs for im_start, im_end etc. if they should also be hidden/handled specially
45
+ IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
46
+ IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
47
+ SPECIAL_TOKEN_IDS.add(IM_START_ID)
48
+ SPECIAL_TOKEN_IDS.add(IM_END_ID)
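# Optional sanity check (a sketch, reproducing the explicit validation the previous
# revision performed): the constraint and visualization logic below assumes these IDs
# resolved to real integers rather than None.
for _name, _tok_id in (("MASK_ID", MASK_ID), ("EOS_ID", EOS_ID), ("PAD_ID", PAD_ID)):
    if _tok_id is None:
        raise ValueError(f"Could not determine {_name} from model config or tokenizer.")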
 
 
49
 
50
  # --- Helper Functions ---
51
 
52
+ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
53
+ """
54
+ Parse constraints in format: 'position:word, position:word, ...'
55
+ Returns a dictionary mapping the starting position (0-indexed from the start
56
+ of the *generated* sequence) to a list of token IDs for the constraint word.
57
+ """
58
  constraints = {}
 
59
  if not constraints_text:
60
+ return constraints
61
 
62
  parts = constraints_text.split(',')
63
  for part in parts:
 
65
  continue
66
  pos_str, word = part.split(':', 1)
67
  try:
68
+ # Position relative to the start of the *generation*
69
  pos = int(pos_str.strip())
70
  word = word.strip()
71
+ # Prefix a leading space when the constraint does not start the generation, so the word
72
+ # tokenizes as it would mid-sentence (this assumes standard leading-space handling by the Dream tokenizer):
73
+ token_ids = tokenizer.encode(" " + word if pos > 0 else word, add_special_tokens=False)
74
+
75
+ if token_ids and pos >= 0:
76
+ constraints[pos] = token_ids
 
 
 
 
 
77
  except ValueError:
78
+ continue # Ignore malformed constraint parts
79
  except Exception as e:
80
+ print(f"Warning: Error processing constraint '{part}': {e}")
81
+ continue
82
+
83
+ return constraints
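# Illustrative usage (a sketch): positions are relative to the start of the generated
# text, and each maps to the token IDs the generation hook will force onto those slots.
# Exact IDs depend on the Dream tokenizer, hence the print rather than hard-coded values.
_example_constraints = parse_constraints("0:Once, 5:upon")
print(f"Example constraint parse: {_example_constraints}")  # e.g. {0: [...], 5: [...]}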
84
 
 
 
 
 
 
85
 
86
+ def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
87
  """
88
+ Format chat history for the Dream model's chat template.
89
 
90
  Args:
91
+ history: List of [user_message, assistant_message] pairs.
92
+ The last assistant_message might be None.
93
 
94
  Returns:
95
+ Formatted list of message dictionaries for tokenizer.apply_chat_template.
96
  """
97
  messages = []
98
+ # Check if the first message is a system prompt, handle accordingly if needed
99
+ # Based on Dream's examples, the template adds a default system prompt if none exists.
100
+ # If history starts with System, it should be handled by the template.
101
+ # Let's assume the template handles the system prompt correctly.
102
 
103
  for user_msg, assistant_msg in history:
104
+ if user_msg: # Defensive check
105
  messages.append({"role": "user", "content": user_msg})
106
+ # Add assistant message only if it exists (it won't for the last turn before generation)
107
+ if assistant_msg:
108
  messages.append({"role": "assistant", "content": assistant_msg})
109
 
110
  return messages
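# Illustrative sketch: for a two-turn history in which the latest assistant reply is
# still pending (None), the list handed to tokenizer.apply_chat_template looks like this.
_example_history = [["Hi", "Hello!"], ["Tell me a story", None]]
assert format_chat_history(_example_history) == [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Tell me a story"},
]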
111
 
112
+ # --- Core Generation Logic with Live Visualization ---
113
+
114
+ @spaces.GPU # Decorator for Hugging Face Spaces GPU usage
115
+ def generate_dream_response(
116
+ history: List[List[Optional[str]]],
117
+ gen_length: int,
118
+ steps: int,
119
+ constraints_text: str,
120
+ temperature: float,
121
+ top_p: Optional[float],
122
+ top_k: Optional[int],
123
+ alg: str,
124
+ alg_temp: Optional[float],
125
+ visualization_delay: float
126
+ ):  # yields (updated history, visualization data, final response text) tuples
 
 
 
 
 
 
127
  """
128
+ Generates text using the Dream model and yields visualization states live.
129
+
130
+ Args:
131
+ history: Chat history.
132
+ gen_length: Max new tokens to generate.
133
+ steps: Number of diffusion steps.
134
+ constraints_text: User-provided constraints string.
135
+ temperature: Sampling temperature.
136
+ top_p: Top-p sampling nucleus.
137
+ top_k: Top-k sampling.
138
+ alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
139
+ alg_temp: Temperature for confidence-based algorithms.
140
+ visualization_delay: Delay between visualization steps.
141
+
142
+ Yields:
143
+ Tuple[List[List[Optional[str]]], List[Tuple[str, Optional[str]]], str]:
144
+ - Updated history
145
+ - Visualization data for HighlightedText
146
+ - Final response text (repeated in each yield)
147
  """
 
 
 
 
 
148
 
149
+ if not history or not history[-1][0]:
150
+ # No user message to respond to
151
+ yield history, [("No input message found.", "red")], ""
152
+ return
153
+
154
+ # --- 1. Preparation ---
155
+ last_user_message = history[-1][0]
156
+ messages_for_template = format_chat_history(history) # Includes the latest user message
157
+
158
+ # Parse constraints relative to the *generated* sequence
159
+ parsed_constraints = parse_constraints(constraints_text) # Dict[rel_pos, List[token_id]]
160
 
161
+ # Prepare inputs using the chat template
162
  try:
163
+ inputs = tokenizer.apply_chat_template(
164
+ messages_for_template,
165
+ return_tensors="pt",
166
+ return_dict=True,
167
+ add_generation_prompt=True # Important for instruct models
 
 
 
 
 
168
  )
 
 
 
169
  input_ids = inputs.input_ids.to(device)
170
+ attention_mask = inputs.attention_mask.to(device)
 
171
  prompt_length = input_ids.shape[1]
172
+ except Exception as e:
173
+ print(f"Error applying chat template: {e}")
174
+ yield history, [("Error preparing input.", "red")], ""
175
+ return
176
+
177
+ # Calculate total sequence length for the model
178
+ # Max length constraint from model config (e.g., 2048 for original Dream?)
179
+ # Let's use a reasonable default or allow configuration if needed.
180
+ # The provided code uses max_position_embeddings=131072, let's stick to user input + gen_length.
181
+ total_length = prompt_length + gen_length
182
+
183
+ # --- 2. Visualization Setup ---
184
+ # This list will store the token sequence (just the generated part) at each step
185
+ step_sequence_history: List[torch.Tensor] = []
186
+ previous_step_tokens = None # Keep track of the previous step's state
187
+
188
+ # Define the hook function *inside* this function to capture state
189
+ def live_visualization_hook(step: Optional[int], x: torch.Tensor, logits: Optional[torch.Tensor]) -> torch.Tensor:
190
+ nonlocal step_sequence_history, parsed_constraints, prompt_length
191
+
192
+ # --- Apply Constraints ---
193
+ # Constraints are applied *after* the model proposes tokens but *before* they are finalized for the step
194
+ # Note: The hook receives the state *before* the next model call in the next step,
195
+ # or the final state after the last step. Let's apply constraints consistently.
196
+ # The `diffusion_generate` calls the hook *after* updating x based on sampling.
197
+ current_x = x.clone() # Work on a copy
198
+
199
+ for rel_pos, word_token_ids in parsed_constraints.items():
200
+ abs_start_pos = prompt_length + rel_pos
201
+ abs_end_pos = abs_start_pos + len(word_token_ids)
202
+
203
+ # Ensure the constraint fits within the generation length
204
+ if abs_start_pos < total_length and abs_end_pos <= total_length:
205
+ try:
206
+ constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=current_x.device)
207
+ # Force the constraint tokens onto the sequence
208
+ current_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
209
+ except IndexError:
210
+ print(f"Warning: Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
211
+ except Exception as e:
212
+ print(f"Warning: Failed to apply constraint at {rel_pos}: {e}")
213
+
214
+ # Store the state *after* constraints for visualization
215
+ # We only need the generated part
216
+ generated_part = current_x[0, prompt_length:].clone().cpu() # Move to CPU to save GPU memory
217
+ step_sequence_history.append(generated_part)
218
+
219
+ # Return the (potentially modified by constraints) tensor x
220
+ return current_x # Pass the constrained version to the next step
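# Contract sketch for generation_tokens_hook_func, inferred from how it is used in this
# file rather than from official documentation: diffusion_generate calls it with the
# step index (None for the pre-generation state), the full (1, total_length) token
# tensor, and the step logits; whatever tensor the hook returns becomes the sequence
# used for the next step. A minimal no-op hook would therefore be:
#   def identity_hook(step, x, logits):
#       return x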
221
+
222
+ # --- 3. Run Generation ---
223
+ final_response_text = ""
224
+ try:
225
+ print(f"Starting Dream generation: prompt_len={prompt_length}, gen_len={gen_length}, steps={steps}")
226
+ start_time = time.time()
 
 
227
 
228
+ # Initial masked state for visualization
229
+ initial_generated_state = torch.full((gen_length,), MASK_ID, dtype=torch.long)
230
+ # Apply constraints to the *initial* visual state if they start at pos 0
231
+ temp_initial_x = torch.cat((input_ids[0], initial_generated_state.to(device)), dim=0).unsqueeze(0)
232
+ initial_vis_x = live_visualization_hook(None, temp_initial_x, None) # Apply constraints via hook logic
233
+ step_sequence_history.insert(0, initial_vis_x[0, prompt_length:].cpu()) # Prepend initial state
234
 
 
 
 
235
  output = model.diffusion_generate(
236
+ input_ids,
237
+ attention_mask=attention_mask,
 
238
  max_new_tokens=gen_length,
239
  output_history=False, # We capture history via the hook
240
  return_dict_in_generate=True,
241
  steps=steps,
242
  temperature=temperature,
243
+ top_p=top_p if top_p is not None and top_p < 1.0 else None, # Ensure top_p < 1 or None
244
+ top_k=top_k if top_k is not None and top_k > 0 else None, # Ensure top_k > 0 or None
245
  alg=alg,
246
+ alg_temp=alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] else None, # Only relevant for some algs
247
+ generation_tokens_hook_func=live_visualization_hook
 
 
 
 
248
  )
249
  end_time = time.time()
250
+ print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")
251
 
252
+ # --- 4. Process Final Output ---
 
253
  final_sequence = output.sequences[0]
 
 
254
  response_tokens = final_sequence[prompt_length:]
 
 
255
 
256
+ # Decode the final response text
257
+ final_response_text = tokenizer.decode(
258
+ response_tokens,
259
+ skip_special_tokens=True, # Skip EOS, PAD, MASK etc. in the final text
260
+ clean_up_tokenization_spaces=True
261
  ).strip()
262
 
263
+ # Update history with the final response
264
+ history[-1][1] = final_response_text
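# Optional refinement (a sketch, not part of this commit): if skip_special_tokens ever
# lets padding or EOS runs through, response_tokens could be cut at the first EOS
# before decoding, e.g.:
#   eos_hits = (response_tokens == EOS_ID).nonzero()
#   if len(eos_hits) > 0:
#       response_tokens = response_tokens[: eos_hits[0, 0]]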
 
265
 
266
  except Exception as e:
267
+ print(f"Error during generation or processing: {e}")
268
  import traceback
269
  traceback.print_exc()
270
+ yield history, [("Error during generation.", "red")], ""
271
+ return
272
+
273
+ # --- 5. Stream Visualization ---
274
+ print(f"Streaming {len(step_sequence_history)} visualization steps...")
275
+ previous_tokens_vis = None
276
+ for i, current_tokens_vis in enumerate(step_sequence_history):
277
+ # print(f" Step {i}: {current_tokens_vis.tolist()}") # Debug
278
+ vis_data = []
279
+ current_decoded_tokens = []
280
+
281
+ # Compare current step tokens with previous step tokens
282
+ for j in range(gen_length):
283
+ current_tok_id = current_tokens_vis[j].item()
284
+ previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
285
+
286
+ # Decode token - handle potential errors for single IDs if needed
287
+ try:
288
+ # Use skip_special_tokens=False here to see the actual tokens
289
+ decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
290
+ # Explicitly handle mask token display
291
+ if current_tok_id == MASK_ID:
292
+ display_token = MASK_TOKEN
293
+ else:
294
+ display_token = decoded_token
295
+
296
+ except Exception:
297
+ display_token = f"[ID:{current_tok_id}]" # Fallback
298
+
299
+ # Determine color and handle hiding of special tokens (like LLaDA demo)
300
+ color = None
301
+ token_to_display = display_token
302
+
303
+ if current_tok_id == MASK_ID:
304
+ color = "#444444" # Dark Gray for masks
305
+ elif previous_tok_id == MASK_ID: # Token was just revealed
306
+ # Simple green for newly revealed, no confidence score available from hook
307
+ color = "#66CC66" # Light Green
308
+ else: # Token was already revealed
309
+ color = "#6699CC" # Light Blue
310
+
311
+ # LLaDA hiding effect: If it's a special token (EOS/PAD) *and* it was revealed before this step, hide it.
312
+ if current_tok_id in {PAD_ID, EOS_ID} and previous_tok_id == current_tok_id:
313
+ # Hide by making it empty or using a background color - empty string is simpler
314
+ token_to_display = ""
315
+ color = "#FFFFFF" # Or just make it blend in
316
 
317
+ # Add token and color to visualization data
318
+ if token_to_display: # Avoid adding empty strings if hiding
319
+ vis_data.append((token_to_display, color))
320
+ elif len(vis_data) > 0 and isinstance(vis_data[-1], tuple):
321
+ # If hidden, and previous was text, add a space for visual separation?
322
+ # This might complicate things, let's omit for now.
323
+ pass
324
+ # elif len(vis_data) == 0: # If first token is hidden
325
+ # vis_data.append(("", None)) # Placeholder?
326
 
327
+ # Update previous state for next iteration
328
+ previous_tokens_vis = current_tokens_vis
329
 
330
+ # Yield the current visualization state
331
+ yield history, vis_data, final_response_text
332
 
333
+ # Pause for the specified delay
334
+ time.sleep(visualization_delay)
335
 
336
+ print("Visualization streaming complete.")
337
+
338
+
339
+ # --- Gradio UI ---
340
  css = '''
341
  .category-legend{display:none}
342
+ button{min-height: 60px}
 
 
343
  '''
344
  def create_chatbot_demo():
345
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
346
  gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
347
  gr.Markdown(
348
  "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
349
+ "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
 
350
  )
351
 
352
  # STATE MANAGEMENT
353
  chat_history = gr.State([])
 
 
354
 
355
  # UI COMPONENTS
356
  with gr.Row():
 
358
  chatbot_ui = gr.Chatbot(
359
  label="Conversation",
360
  height=500,
361
+ show_copy_button=True,
362
+ bubble_full_width=False
363
+ )
364
 
365
  # Message input
366
  with gr.Group():
 
368
  user_input = gr.Textbox(
369
  label="Your Message",
370
  placeholder="Type your message here...",
371
+ scale=7,
372
+ autofocus=True,
373
  show_label=False,
374
+ container=False # Remove container for tighter packing
375
  )
376
+ send_btn = gr.Button("Send", scale=1, variant="primary")
377
+
378
 
379
  constraints_input = gr.Textbox(
380
+ label="Word Constraints (Optional)",
381
+ info="Place words at specific positions (0-indexed from start of generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
382
+ placeholder="0:Hello, 10:world",
383
  value=""
384
  )
385
  with gr.Column(scale=2):
386
  output_vis = gr.HighlightedText(
387
  label="Denoising Process Visualization",
388
  combine_adjacent=False,
389
+ show_legend=False, # Legend isn't very informative here
390
+ interactive=False # Not interactive
391
  )
392
 
393
  # Advanced generation settings
394
  with gr.Accordion("Generation Settings", open=False):
395
  with gr.Row():
396
  gen_length = gr.Slider(
397
+ minimum=16, maximum=512, value=128, step=8, # Increased max length
398
  label="Max New Tokens"
399
  )
400
  steps = gr.Slider(
401
+ minimum=8, maximum=512, value=128, step=8, # Increased max steps
402
+ label="Diffusion Steps"
403
  )
404
  with gr.Row():
405
  temperature = gr.Slider(
406
+ minimum=0.0, maximum=1.0, value=0.4, step=0.05,
407
  label="Temperature"
408
  )
409
+ alg_temp = gr.Slider(
410
+ minimum=0.0, maximum=1.0, value=0.1, step=0.05,
411
+ label="Remasking Temp (for confidence algs)"
412
+ )
413
+
414
+ with gr.Row():
415
  top_p = gr.Slider(
416
+ minimum=0.0, maximum=1.0, value=0.95, step=0.05,
417
+ label="Top-P (0=disabled)"
418
  )
419
  top_k = gr.Slider(
420
  minimum=0, maximum=200, value=0, step=5,
421
  label="Top-K (0=disabled)"
422
  )
423
+
424
  with gr.Row():
425
+ remasking_strategy = gr.Radio(
426
  choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
427
+ value='entropy', # Default to entropy as in example
428
+ label="Remasking Strategy (Algorithm)"
429
+ )
 
 
 
 
430
 
431
  with gr.Row():
432
  visualization_delay = gr.Slider(
433
+ minimum=0.0, maximum=0.5, value=0.02, step=0.01, # Faster default
434
  label="Visualization Delay (seconds)"
435
  )
436
 
437
  # Clear button
438
  clear_btn = gr.Button("Clear Conversation")
439
 
440
+ # Current response text box (hidden, maybe useful for debugging)
441
+ # current_response = gr.Textbox(visible=False)
 
 
442
 
443
+ # --- Event Handlers ---
 
444
 
445
+ def add_user_message_to_history(message: str, history: List[List[Optional[str]]]):
446
+ """Adds user message, clears input, prepares for bot response."""
447
+ if not message.strip():
448
+ gr.Warning("Please enter a message.")
449
+ return history, history, "", [("Enter a message", "grey")] # Keep vis empty or show prompt
450
 
451
+ # Add user message with placeholder for bot response
452
+ history.append([message, None])
453
+ # Return updated history for chatbot, empty input box, empty visualization
454
+ return history, history, "", []
455
 
456
 
457
  def clear_conversation():
458
+ """Clears the chat history and visualization."""
459
+ return [], [], "", []
460
+
461
+ # --- Connect UI elements ---
462
+
463
+ # User Input Submission (Textbox Enter or Send Button Click)
464
+ submit_triggers = [user_input.submit, send_btn.click]
465
+
466
+ # 1. Add user message to UI immediately
467
+ for trigger in submit_triggers:
468
+ trigger.then(
469
+ add_user_message_to_history,
470
+ inputs=[user_input, chat_history],
471
+ outputs=[chat_history, chatbot_ui, user_input, output_vis] # Update chat, clear input, clear vis
472
+ ).then( # 2. Trigger bot response generation (as a generator)
473
+ generate_dream_response,
474
+ inputs=[
475
+ chat_history, gen_length, steps, constraints_input,
476
+ temperature, top_p, top_k, remasking_strategy, alg_temp,
477
+ visualization_delay
478
+ ],
479
+ outputs=[chatbot_ui, output_vis] # Stream updates to chatbot and visualization
480
+ # Note: The final text response is implicitly handled by updating chatbot_ui
481
+ )
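# Alternative wiring sketch (assumes a Gradio version that provides gr.on; not verified
# against this Space's pinned version): the per-trigger loop above could be collapsed
# into a single multi-trigger listener:
#   gr.on(
#       triggers=[user_input.submit, send_btn.click],
#       fn=add_user_message_to_history,
#       inputs=[user_input, chat_history],
#       outputs=[chat_history, chatbot_ui, user_input, output_vis],
#   ).then(generate_dream_response, inputs=[...], outputs=[chatbot_ui, output_vis])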
482
 
483
+ # Clear Button Action
484
  clear_btn.click(
485
+ clear_conversation,
486
  inputs=[],
487
  outputs=[chat_history, chatbot_ui, user_input, output_vis]
488
  )
489
 
 
490
  return demo
491
 
492
+ # --- Launch ---
493
  if __name__ == "__main__":
494
  demo = create_chatbot_demo()
495
+ # Use queue for handling multiple users and streaming
496
+ demo.queue().launch(debug=True, share=True) # Add share=True for public link if needed