multimodalart (HF Staff) committed verified commit 7c32497 · 1 parent: 3d09f97

Update app.py

Files changed (1): app.py (+241, -147)
app.py CHANGED
@@ -7,50 +7,90 @@ import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
- from typing import List, Dict, Tuple, Optional
import torch.distributions as dists # Added import

# --- START: Copied Helper functions from generation_utils.py ---
# [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
def top_p_logits(logits, top_p=None):
-     if top_p is None or top_p >= 1.0: return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
-     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone(); sorted_indices_to_remove[..., 0] = 0
-     mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device).scatter_(-1, sorted_indices, sorted_indices_to_remove)
-     return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

def top_k_logits(logits, top_k=None):
-     if top_k is None or top_k <= 0: return logits
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-     return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)

def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
-     if temperature > 0: safe_temp = max(temperature, 1e-6); logits = logits / safe_temp
-     if top_p is not None and 0.0 < top_p < 1.0: logits = top_p_logits(logits, top_p)
-     if top_k is not None and top_k > 0: logits = top_k_logits(logits, top_k)
-     is_all_neg_inf = torch.all(logits == torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
-     if torch.any(is_all_neg_inf): uniform_logits = torch.zeros_like(logits); logits = torch.where(is_all_neg_inf, uniform_logits, logits)
    probs = torch.softmax(logits, dim=-1)
-     probs = torch.clamp(probs, min=0.0); probs = probs / probs.sum(dim=-1, keepdim=True); probs = torch.nan_to_num(probs, nan=0.0)
    if temperature > 0:
-         try: x0 = dists.Categorical(probs=probs).sample(); confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
-         except Exception as e: print(f"Warning: Sampling failed: {e}. Argmax fallback."); confidence, x0 = probs.max(dim=-1)
-     else: confidence, x0 = probs.max(dim=-1)
-     if margin_confidence: sorted_probs, _ = torch.sort(probs, dim=-1, descending=True); top1_probs = sorted_probs[..., 0]; top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs; confidence = top1_probs - top2_probs
-     if neg_entropy: epsilon = 1e-10; log_probs = torch.log(probs + epsilon); confidence = torch.sum(probs * log_probs, dim=-1)
    confidence = torch.nan_to_num(confidence, nan=0.0)
    return confidence, x0
# --- END: Copied Helper functions ---

- # [Keep model loading, constants as before]
# Load model configuration to get special token IDs
config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
model_path = "Dream-org/Dream-v0-Instruct-7B"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Loading model...")
@@ -58,25 +98,34 @@ model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
-     attn_implementation="sdpa"
)
model = model.to(device).eval()
print("Model loaded.")
MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
PAD_ID = tokenizer.pad_token_id
EOS_ID = tokenizer.eos_token_id
- if MASK_ID is None: raise ValueError("Cannot determine MASK_ID.")
SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
try:
    IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
    IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
    SPECIAL_TOKEN_IDS.add(IM_START_ID)
    SPECIAL_TOKEN_IDS.add(IM_END_ID)
- except KeyError: IM_START_ID, IM_END_ID = None, None

# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    constraints = {}
    if not constraints_text: return constraints
    parts = constraints_text.split(',')
@@ -88,26 +137,35 @@ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
            pos = int(pos_str.strip())
            word = word.strip()
            token_ids = []
-             if word: text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word; token_ids = tokenizer.encode(text_to_encode, add_special_tokens=False)
            if token_ids and pos >= 0: constraints[pos] = token_ids
            elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
        except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
        except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
    return constraints

- # Removed format_chat_history as history will be in the correct format

def apply_constraints_to_state(
-     x: torch.Tensor, prompt_length: int, total_length: int,
-     parsed_constraints: Dict[int, List[int]], current_step: Optional[int] = None
) -> torch.Tensor:
    modified_x = x.clone()
    for rel_pos, word_token_ids in parsed_constraints.items():
-         abs_start_pos = prompt_length + rel_pos; abs_end_pos = abs_start_pos + len(word_token_ids)
        if abs_start_pos < total_length and abs_end_pos <= total_length:
-             try: constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device); modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
-             except IndexError: print(f"Warning (Step {current_step}): Constraint idx error at {rel_pos}")
-             except Exception as e: print(f"Warning (Step {current_step}): Constraint apply error at {rel_pos}: {e}")
    return modified_x

@@ -116,7 +174,7 @@ def apply_constraints_to_state(
@spaces.GPU
@torch.no_grad()
def generate_dream_response(
-     history: List[Dict[str, str]], # MODIFIED: Expect List[Dict]
    gen_length: int,
    steps: int,
    constraints_text: str,
@@ -126,32 +184,35 @@ def generate_dream_response(
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float
- ): # Removed -> type hint for cleaner yield handling
    """ Generates text step-by-step and yields visualization states live. """

-     if not history or history[-1]["role"] != "user": # Check last message is from user
-         yield history, [("No user message found to respond to.", "red")]
        return

    # --- 1. Preparation ---
-     # History is already formatted for the template
    parsed_constraints = parse_constraints(constraints_text)

    try:
-         # apply_chat_template expects List[Dict[str, str]]
        inputs = tokenizer.apply_chat_template(
-             history, # Use history directly
            return_tensors="pt",
            return_dict=True,
-             add_generation_prompt=True # Crucial: Adds the "<|im_start|>assistant\n" prompt
        )
        input_ids = inputs.input_ids.to(device)
        prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
-         prompt_length = input_ids.shape[1] # Length *after* adding the generation prompt
    except Exception as e:
        print(f"Error applying chat template: {e}")
-         # Yield current history and error message for visualization
-         yield history, [("Error preparing input.", "red")]
        return

    eps = 1e-3
@@ -162,12 +223,9 @@ def generate_dream_response(
    # --- 2. Initialize Generation State ---
    total_length = prompt_length + gen_length
    initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
-     # input_ids already includes the assistant prompt, so just append masks
    x = torch.cat((input_ids, initial_generation_part), dim=1)

-     # --- Prepare Attention Mask for SDPA ---
    generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
-     # prompt_attention_mask corresponds to input_ids (which includes assistant prompt)
    full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)

    attention_mask_for_model = full_attention_mask_long.to(model.dtype)
@@ -175,27 +233,27 @@ def generate_dream_response(
    attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
    attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2) # [B, 1, 1, N]

-     # --- Timesteps ---
    timesteps = torch.linspace(1, eps, steps + 1, device=device)
-
-     # Apply initial constraints (relative to start of generation = prompt_length)
    x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1)

    # --- 3. Visualization & History Setup ---
    previous_tokens_vis = None
-     # MODIFIED: Append placeholder assistant message to the history state *before* looping
-     history.append({"role": "assistant", "content": ""})

    # --- 4. Initial Yield (Masked State) ---
    initial_generated_tokens = x[0, prompt_length:].cpu()
    vis_data_initial = []
    for tok_id in initial_generated_tokens.tolist():
-         display_token = MASK_TOKEN; color = "#444444"
        vis_data_initial.append((display_token, color))

    previous_tokens_vis = initial_generated_tokens
-     # Yield the history (which now includes the empty assistant message) and initial vis
-     yield history, vis_data_initial
    time.sleep(visualization_delay)

    # --- 5. Step-by-Step Diffusion Loop ---
@@ -203,70 +261,106 @@ def generate_dream_response(
        start_time = time.time()
        for i in range(steps):
            mask_index = (x == MASK_ID)
-             if not mask_index.any(): break # Stop early
-
-             outputs = model(input_ids=x, attention_mask=attention_mask_for_model, return_dict=True)
            logits = outputs.logits
-             logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) # Align logits

            mask_logits = logits[mask_index]
-             if mask_logits.numel() == 0: break # Stop early

-             t = timesteps[i]; s = timesteps[i + 1]
            x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)

-             # [Keep sampling/remasking logic ('origin' and confidence-based) exactly the same]
            if alg == 'origin':
                p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
                num_masked = mask_logits.shape[0]
                transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
                logits_to_sample = mask_logits[transfer_indices_relative]
-                 if logits_to_sample.numel() > 0: _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val); x_new_masked_part[transfer_indices_relative] = sampled_tokens
-             else:
-                 use_margin=(alg == 'topk_margin'); use_entropy=(alg == 'entropy')
                confidence, x0_candidates = sample_tokens(mask_logits, temperature=temperature, top_p=top_p_val, top_k=top_k_val, margin_confidence=use_margin, neg_entropy=use_entropy)
                num_mask_token = mask_logits.shape[0]
                target_num_revealed_float = num_mask_token * (1.0 - s / t)
                number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
                if number_transfer_tokens > 0:
                    num_samples = min(number_transfer_tokens, num_mask_token)
                    if num_samples > 0:
-                         transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
-                         if alg_temp_val is None or alg_temp_val <= 0: # Top-k confidence
                            sort_metric = confidence if alg != 'entropy' else -confidence
                            k_topk = min(num_samples, sort_metric.numel())
                            if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
-                         else: # Sample based on confidence temperature
                            if confidence.numel() > 0:
-                                 conf_probs = confidence / alg_temp_val; conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9); conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30); conf_probs = F.softmax(conf_probs, dim=-1); conf_probs = torch.clamp(conf_probs, min=0.0); conf_probs = torch.nan_to_num(conf_probs, nan=0.0)
-                                 prob_sum = conf_probs.sum(); target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
-                                 if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0: safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype)); conf_probs = conf_probs / safe_prob_sum
                                final_prob_sum_check = conf_probs.sum()
                                if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
                                    try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
-                                     except RuntimeError as e: print(f"Warning step {i}: Multinomial failed ('{e}'). Fallback."); sort_metric = confidence if alg != 'entropy' else -confidence; k_fallback = min(num_samples, sort_metric.numel()); if k_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_fallback)
-                                 else: sort_metric = confidence if alg != 'entropy' else -confidence; k_fallback = min(num_samples, sort_metric.numel()); if k_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_fallback)
                # Apply transfer
                if transfer_indices_relative.numel() > 0:
-                     valid_indices = transfer_indices_relative < x0_candidates.shape[0]; valid_transfer_indices = transfer_indices_relative[valid_indices]
-                     if valid_transfer_indices.numel() > 0:
-                         if valid_transfer_indices.max() < x_new_masked_part.shape[0]: x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()
-                         else: print(f"Warning step {i}: transfer_indices OOB for x_new_masked_part.")

-             x[mask_index] = x_new_masked_part # Update state

-             # --- Apply Constraints ---
-             # Remember prompt_length now includes the assistant prompt turn
            x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)

-             # --- Yield Visualization ---
            current_generated_tokens = x[0, prompt_length:].cpu()
            vis_data = []
-             # [Keep visualization formatting logic the same]
            for j in range(gen_length):
                current_tok_id = current_generated_tokens[j].item()
                previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
-                 try: decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False); display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
                except Exception: display_token = f"[ID:{current_tok_id}]"
                color = None; token_to_display = display_token
                if current_tok_id == MASK_ID: color = "#444444"
@@ -278,13 +372,16 @@ def generate_dream_response(

            previous_tokens_vis = current_generated_tokens

-             # MODIFIED: Update the *content* of the last history item
            intermediate_response_tokens = x[0, prompt_length:]
-             intermediate_response_text = tokenizer.decode(intermediate_response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
-             history[-1]["content"] = intermediate_response_text # Update last dict entry

-             # Yield the updated history list and current vis data
-             yield history, vis_data
            time.sleep(visualization_delay)

        end_time = time.time()
@@ -293,17 +390,22 @@ def generate_dream_response(
        # --- 6. Final Processing & Yield ---
        final_sequence = x[0]
        response_tokens = final_sequence[prompt_length:]
-         final_response_text = tokenizer.decode(response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
-         # Update the final content in the history object
-         history[-1]["content"] = final_response_text

        final_generated_tokens = x[0, prompt_length:].cpu()
        vis_data_final = []
-         # [Keep final visualization formatting logic the same]
        for j in range(gen_length):
            current_tok_id = final_generated_tokens[j].item()
            previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
-             try: decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False); display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
            except Exception: display_token = f"[ID:{current_tok_id}]"
            color = None; token_to_display = display_token
            if current_tok_id == MASK_ID: color = "#444444"
@@ -313,18 +415,16 @@ def generate_dream_response(
            if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
            if token_to_display: vis_data_final.append((token_to_display, color))

-         # Yield final history and visualization
-         yield history, vis_data_final
        print("Visualization streaming complete.")

    except Exception as e:
        print(f"Error during generation or processing: {e}")
-         import traceback
        traceback.print_exc()
-         # Set error message in the last history item? Or yield separate error?
-         # Let's just yield the current history and error vis
-         history[-1]["content"] = f"Error: {e}" # Put error in assistant message
-         yield history, [("Error during generation.", "red")]
        return

@@ -341,18 +441,19 @@ def create_chatbot_demo():
            "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
        )

-         # STATE: No explicit state needed if chatbot manages it via input/output

        with gr.Row():
            with gr.Column(scale=3):
-                 # MODIFIED: Use type="messages"
-                 chatbot_ui = gr.Chatbot(
-                     label="Conversation",
-                     type="messages", # Use dictionary format
-                     height=500,
-                     show_copy_button=True,
-                     bubble_full_width=False,
-                 )
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
@@ -362,97 +463,90 @@ def create_chatbot_demo():
                        send_btn = gr.Button("Send", scale=1, variant="primary")
                    constraints_input = gr.Textbox(
                        label="Word Constraints (Optional)",
-                         info="Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
                        placeholder="0:Hello, 10:world", value=""
                    )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
-                     label="Denoising Process Visualization",
-                     combine_adjacent=True, show_legend=False, interactive=False
                )
-                 # REMOVED: Separate response text display

        with gr.Accordion("Generation Settings", open=False):
-             # [Settings sliders remain the same]
            with gr.Row():
                gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
                steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
            with gr.Row():
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Temperature (0 = greedy)")
-                 alg_temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Remasking Temp (Confidence Algs)")
            with gr.Row():
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (0 disables)")
                top_k = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (0 disables)")
            with gr.Row():
-                 remasking_strategy = gr.Radio(choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Remasking Strategy (Algorithm)")
            with gr.Row():
-                 visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (seconds)")

        clear_btn = gr.Button("Clear Conversation")

        # --- Event Handlers ---

-         # MODIFIED: add_user_message uses dictionary format
        def add_user_message(message: str, history: List[Dict[str, str]]):
-             """Adds user message in dictionary format, clears input."""
            if not message.strip():
                gr.Warning("Please enter a message.")
-                 return history, "" # Return unchanged history, don't clear input here
-             # Append user message as a dictionary
            history.append({"role": "user", "content": message})
-             # Return updated history, clear input box
            return history, ""

-         def clear_all():
-             """Clears chatbot, visualization, and input."""
-             return [], [], "" # Chatbot, Vis, Input
-
-         # --- Connect UI elements ---
-
-         # Define the inputs for the generation function
-         # MODIFIED: Input is chatbot_ui (provides List[Dict])
        generation_inputs = [
-             chatbot_ui, # Get history directly from chatbot component
            gen_length, steps, constraints_input,
            temperature, top_p, top_k, remasking_strategy, alg_temp,
            visualization_delay
        ]
-         # Define the outputs for the generation function
-         # MODIFIED: Output history (List[Dict]) to chatbot_ui, vis_data to output_vis
-         generation_outputs = [chatbot_ui, output_vis]

-         # Handle Textbox Submission (Enter key)
        submit_listener = user_input.submit(
-             fn=add_user_message, # Use modified function
-             inputs=[user_input, chatbot_ui], # Pass chatbot state
-             outputs=[chatbot_ui, user_input], # Update chatbot state, clear input
-             queue=False # User message add should be quick
        ).then(
            fn=generate_dream_response,
            inputs=generation_inputs,
-             outputs=generation_outputs, # Stream history to chatbot, vis to output_vis
-             show_progress="hidden"
        )

-         # Handle Send Button Click
        click_listener = send_btn.click(
-             fn=add_user_message, # Use modified function
-             inputs=[user_input, chatbot_ui], # Pass chatbot state
-             outputs=[chatbot_ui, user_input], # Update chatbot state, clear input
-             queue=False # User message add should be quick
        ).then(
            fn=generate_dream_response,
            inputs=generation_inputs,
-             outputs=generation_outputs, # Stream history to chatbot, vis to output_vis
            show_progress="hidden"
        )

        # Clear Button Action
        clear_btn.click(
-             clear_all, # Use modified clear function
            inputs=[],
-             outputs=[chatbot_ui, output_vis, user_input], # Clear chatbot, vis, input
-             queue=False
        )

    return demo

from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
+ from typing import List, Dict, Tuple, Optional, Any # Added Any
import torch.distributions as dists # Added import
+ import traceback # For better error printing

# --- START: Copied Helper functions from generation_utils.py ---
# [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
def top_p_logits(logits, top_p=None):
+     """ Applies top-p filtering to logits. """
+     if top_p is None or top_p >= 1.0:
+         return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = 0
+     mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
+     mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+     logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
+     return logits

def top_k_logits(logits, top_k=None):
+     """ Applies top-k filtering to logits. """
+     if top_k is None or top_k <= 0:
+         return logits
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+     logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
+     return logits

def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
+     """ Samples tokens based on logits and calculates confidence. """
+     if temperature > 0:
+         safe_temp = max(temperature, 1e-6)
+         logits = logits / safe_temp
+     if top_p is not None and 0.0 < top_p < 1.0:
+         logits = top_p_logits(logits, top_p)
+     if top_k is not None and top_k > 0:
+         logits = top_k_logits(logits, top_k)
+
+     is_all_neg_inf = torch.all(logits <= torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
+     if torch.any(is_all_neg_inf):
+         uniform_logits = torch.zeros_like(logits)
+         logits = torch.where(is_all_neg_inf, uniform_logits, logits)
+
    probs = torch.softmax(logits, dim=-1)
+     probs = torch.clamp(probs, min=0.0)
+     prob_sum = probs.sum(dim=-1, keepdim=True)
+     safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=probs.device, dtype=probs.dtype))
+     probs = probs / safe_prob_sum
+     probs = torch.nan_to_num(probs, nan=0.0)
+
    if temperature > 0:
+         try:
+             x0 = dists.Categorical(probs=probs).sample()
+             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
+         except Exception as e:
+             print(f"Warning: Error during Categorical sampling: {e}. Falling back to argmax.")
+             confidence, x0 = probs.max(dim=-1)
+     else:
+         confidence, x0 = probs.max(dim=-1)
+
+     if margin_confidence:
+         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
+         top1_probs = sorted_probs[..., 0]
+         top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs
+         confidence = top1_probs - top2_probs
+     elif neg_entropy: # Use elif to avoid calculating entropy if margin_confidence was True
+         epsilon = 1e-10
+         log_probs = torch.log(probs + epsilon)
+         confidence = torch.sum(probs * log_probs, dim=-1) # Negative entropy
+     # Else: confidence is just the probability of the sampled token if temperature > 0, or max prob otherwise
+
    confidence = torch.nan_to_num(confidence, nan=0.0)
    return confidence, x0
# --- END: Copied Helper functions ---

+
+ # --- Model Loading and Constants ---
# Load model configuration to get special token IDs
config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
model_path = "Dream-org/Dream-v0-Instruct-7B"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
+
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Loading model...")

    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
+     attn_implementation="sdpa" # Explicitly request SDPA
)
model = model.to(device).eval()
print("Model loaded.")
+
MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
PAD_ID = tokenizer.pad_token_id
EOS_ID = tokenizer.eos_token_id
+
+ if MASK_ID is None:
+     raise ValueError("Cannot determine MASK_ID. Check model's tokenizer configuration.")
+
SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
try:
    IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
    IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
    SPECIAL_TOKEN_IDS.add(IM_START_ID)
    SPECIAL_TOKEN_IDS.add(IM_END_ID)
+ except KeyError:
+     print("Warning: <|im_start|> or <|im_end|> not found in tokenizer vocab.")
+     IM_START_ID = None
+     IM_END_ID = None
+

# --- Helper Functions ---
def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
+     """ Parses word constraints. """
    constraints = {}
    if not constraints_text: return constraints
    parts = constraints_text.split(',')

            pos = int(pos_str.strip())
            word = word.strip()
            token_ids = []
+             if word:
+                 text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
+                 token_ids = tokenizer.encode(text_to_encode, add_special_tokens=False)
            if token_ids and pos >= 0: constraints[pos] = token_ids
            elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
        except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
        except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
    return constraints

+ # Removed format_chat_history as the state will now be in the correct format

def apply_constraints_to_state(
+     x: torch.Tensor,
+     prompt_length: int,
+     total_length: int,
+     parsed_constraints: Dict[int, List[int]],
+     current_step: Optional[int] = None
) -> torch.Tensor:
+     """ Applies constraints directly to the state tensor `x`. """
    modified_x = x.clone()
    for rel_pos, word_token_ids in parsed_constraints.items():
+         abs_start_pos = prompt_length + rel_pos
+         abs_end_pos = abs_start_pos + len(word_token_ids)
        if abs_start_pos < total_length and abs_end_pos <= total_length:
+             try:
+                 constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
+                 modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
+             except IndexError: print(f"Warning (Step {current_step}): Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
+             except Exception as e: print(f"Warning (Step {current_step}): Failed to apply constraint at {rel_pos}: {e}")
    return modified_x

@spaces.GPU
@torch.no_grad()
def generate_dream_response(
+     history_dict_list: List[Dict[str, str]], # Now expects list of dicts
    gen_length: int,
    steps: int,
    constraints_text: str,

    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float
+ ) -> List[Tuple[str, str]]:
    """ Generates text step-by-step and yields visualization states live. """

+     if not history_dict_list or history_dict_list[-1]['role'] != 'user':
+         # Handle cases where history is empty or doesn't end with user message
+         # This check might be redundant if add_user_message handles it, but good for safety.
+         yield history_dict_list, [("No user message found.", "red")], ""
        return

    # --- 1. Preparation ---
    parsed_constraints = parse_constraints(constraints_text)

+     # Prepare history for the model template (don't include the empty assistant msg yet)
+     history_for_template = history_dict_list # Already in list-of-dicts format
+
    try:
        inputs = tokenizer.apply_chat_template(
+             history_for_template, # Pass the list of dicts directly
            return_tensors="pt",
            return_dict=True,
+             add_generation_prompt=True # Crucial: Adds the '<|im_start|>assistant\n' turn
        )
        input_ids = inputs.input_ids.to(device)
        prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
+         prompt_length = input_ids.shape[1]
    except Exception as e:
        print(f"Error applying chat template: {e}")
+         traceback.print_exc()
+         yield history_dict_list, [("Error preparing input.", "red")], ""
        return

    eps = 1e-3

    # --- 2. Initialize Generation State ---
    total_length = prompt_length + gen_length
    initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
    x = torch.cat((input_ids, initial_generation_part), dim=1)

    generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
    full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)

    attention_mask_for_model = full_attention_mask_long.to(model.dtype)

    attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
    attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2) # [B, 1, 1, N]

    timesteps = torch.linspace(1, eps, steps + 1, device=device)
    x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1)

    # --- 3. Visualization & History Setup ---
    previous_tokens_vis = None
+     final_response_text = ""
+     # The history_dict_list is the state we update and yield for the chatbot UI
+     # Add the empty assistant message placeholder *to the history state* now
+     history_dict_list.append({"role": "assistant", "content": ""})

    # --- 4. Initial Yield (Masked State) ---
    initial_generated_tokens = x[0, prompt_length:].cpu()
    vis_data_initial = []
    for tok_id in initial_generated_tokens.tolist():
+         display_token = MASK_TOKEN
+         color = "#444444"
        vis_data_initial.append((display_token, color))

    previous_tokens_vis = initial_generated_tokens
+     # Yield the history (which now includes the empty assistant turn)
+     yield history_dict_list, vis_data_initial, ""
    time.sleep(visualization_delay)

    # --- 5. Step-by-Step Diffusion Loop ---

        start_time = time.time()
        for i in range(steps):
            mask_index = (x == MASK_ID)
+             if not mask_index.any():
+                 print(f"No mask tokens left at step {i}. Stopping early.")
+                 break
+
+             outputs = model(
+                 input_ids=x,
+                 attention_mask=attention_mask_for_model,
+                 position_ids=None, use_cache=False, return_dict=True
+             )
            logits = outputs.logits
+             logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)

            mask_logits = logits[mask_index]
+             if mask_logits.numel() == 0:
+                 print(f"No masked tokens found for logit selection at step {i}. Stopping.")
+                 break

+             t = timesteps[i]
+             s = timesteps[i + 1]
            x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)

+             # [Keep sampling logic the same - 'origin' and confidence-based]
            if alg == 'origin':
                p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
                num_masked = mask_logits.shape[0]
                transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
                logits_to_sample = mask_logits[transfer_indices_relative]
+                 if logits_to_sample.numel() > 0:
+                     _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
+                     if transfer_indices_relative.sum() == sampled_tokens.numel(): # Basic check
+                         x_new_masked_part[transfer_indices_relative] = sampled_tokens
+                     else: print(f"Warning step {i} (origin): Mismatch transfer indices and sampled tokens.")
+
+             else: # Confidence-based
+                 use_margin = (alg == 'topk_margin')
+                 use_entropy = (alg == 'entropy')
                confidence, x0_candidates = sample_tokens(mask_logits, temperature=temperature, top_p=top_p_val, top_k=top_k_val, margin_confidence=use_margin, neg_entropy=use_entropy)
+
                num_mask_token = mask_logits.shape[0]
                target_num_revealed_float = num_mask_token * (1.0 - s / t)
                number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
+
                if number_transfer_tokens > 0:
                    num_samples = min(number_transfer_tokens, num_mask_token)
                    if num_samples > 0:
+                         transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
+                         if alg_temp_val is None or alg_temp_val <= 0: # Top-k
                            sort_metric = confidence if alg != 'entropy' else -confidence
                            k_topk = min(num_samples, sort_metric.numel())
                            if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
+                         else: # Sample based on temp
                            if confidence.numel() > 0:
+                                 conf_probs = confidence / alg_temp_val
+                                 conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9)
+                                 conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30)
+                                 conf_probs = F.softmax(conf_probs, dim=-1)
+                                 conf_probs = torch.clamp(conf_probs, min=0.0)
+                                 conf_probs = torch.nan_to_num(conf_probs, nan=0.0)
+                                 prob_sum = conf_probs.sum()
+                                 target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
+                                 if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0:
+                                     safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
+                                     conf_probs = conf_probs / safe_prob_sum
                                final_prob_sum_check = conf_probs.sum()
                                if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
                                    try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
+                                     except RuntimeError as e:
+                                         print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
+                                         sort_metric = confidence if alg != 'entropy' else -confidence
+                                         k_multinomial_fallback = min(num_samples, sort_metric.numel())
+                                         if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
+                                 else: # Fallback if probs invalid for multinomial
+                                     # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
+                                     sort_metric = confidence if alg != 'entropy' else -confidence
+                                     k_multinomial_fallback = min(num_samples, sort_metric.numel())
+                                     if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
+
                # Apply transfer
                if transfer_indices_relative.numel() > 0:
+                     if x0_candidates.numel() > 0 and transfer_indices_relative.max() < x0_candidates.shape[0]:
+                         if transfer_indices_relative.max() < x_new_masked_part.shape[0]:
+                             x_new_masked_part[transfer_indices_relative] = x0_candidates[transfer_indices_relative].clone()
+                         else: print(f"Warning step {i}: transfer_indices out of bounds for x_new_masked_part.")
+                     else: print(f"Warning step {i}: transfer_indices out of bounds for x0_candidates or x0_candidates empty.")

+             x[mask_index] = x_new_masked_part

            x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)

+             # --- Yield Visualization & Update History ---
            current_generated_tokens = x[0, prompt_length:].cpu()
            vis_data = []
+             # [Visualization formatting logic remains the same]
            for j in range(gen_length):
                current_tok_id = current_generated_tokens[j].item()
                previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
+                 try:
+                     decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
+                     display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
                except Exception: display_token = f"[ID:{current_tok_id}]"
                color = None; token_to_display = display_token
                if current_tok_id == MASK_ID: color = "#444444"

            previous_tokens_vis = current_generated_tokens

            intermediate_response_tokens = x[0, prompt_length:]
+             intermediate_response_text = tokenizer.decode(
+                 intermediate_response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
+             ).strip()
+
+             # --- Update the *last* message in history_dict_list ---
+             history_dict_list[-1]['content'] = intermediate_response_text

+             # Yield the updated history list (for chatbot UI), vis data, and response text
+             yield history_dict_list, vis_data, intermediate_response_text
            time.sleep(visualization_delay)

        end_time = time.time()

        # --- 6. Final Processing & Yield ---
        final_sequence = x[0]
        response_tokens = final_sequence[prompt_length:]
+         final_response_text = tokenizer.decode(
+             response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
+         ).strip()
+
+         # Ensure the final text is in the history object before the last yield
+         history_dict_list[-1]['content'] = final_response_text

        final_generated_tokens = x[0, prompt_length:].cpu()
        vis_data_final = []
+         # [Final visualization formatting logic remains the same]
        for j in range(gen_length):
            current_tok_id = final_generated_tokens[j].item()
            previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
+             try:
+                 decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
+                 display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
            except Exception: display_token = f"[ID:{current_tok_id}]"
            color = None; token_to_display = display_token
            if current_tok_id == MASK_ID: color = "#444444"

            if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
            if token_to_display: vis_data_final.append((token_to_display, color))

+         yield history_dict_list, vis_data_final, final_response_text
        print("Visualization streaming complete.")

    except Exception as e:
        print(f"Error during generation or processing: {e}")
        traceback.print_exc()
+         # Attempt to add error message to history if possible
+         if history_dict_list and history_dict_list[-1]['role'] == 'assistant':
+             history_dict_list[-1]['content'] = f"Error: {e}"
+         yield history_dict_list, [("Error during generation.", "red")], f"Error: {e}" # Also show error in text box
        return

            "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
        )

+         # Use Chatbot directly for state, matching the expected format
+         chatbot_ui = gr.Chatbot(
+             label="Conversation",
+             height=500,
+             show_copy_button=True,
+             bubble_full_width=False,
+             value=[], # Initialize empty
+             type="messages" # Crucial: Use the messages format
+         )

        with gr.Row():
            with gr.Column(scale=3):
+                 # Chatbot moved above this row
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(

                        send_btn = gr.Button("Send", scale=1, variant="primary")
                    constraints_input = gr.Textbox(
                        label="Word Constraints (Optional)",
+                         info="Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon'",
                        placeholder="0:Hello, 10:world", value=""
                    )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
+                     label="Denoising Process Visualization", combine_adjacent=True,
+                     show_legend=False, interactive=False
+                 )
+                 response_text_display = gr.Textbox(
+                     label="Current/Final Response", interactive=False, lines=5
                )

+         # [Keep Accordion with Generation Settings the same]
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
                steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
            with gr.Row():
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Temperature (0 = greedy)")
+                 alg_temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Remasking Temp (Conf Algs)")
            with gr.Row():
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (0 disables)")
                top_k = gr.Slider(minimum=0, maximum=200, value=0, step=5, label="Top-K (0 disables)")
            with gr.Row():
+                 remasking_strategy = gr.Radio(choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'], value='entropy', label="Remasking Strategy")
            with gr.Row():
+                 visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (s)")

        clear_btn = gr.Button("Clear Conversation")

        # --- Event Handlers ---

+         # User function: Appends user message to the history (list of dicts)
        def add_user_message(message: str, history: List[Dict[str, str]]):
            if not message.strip():
                gr.Warning("Please enter a message.")
+                 return history, "" # Return unchanged history, empty input
            history.append({"role": "user", "content": message})
+             # Return updated history for chatbot UI, and clear input box
            return history, ""

+         # Bot function (now the generator)
+         # Inputs: Chatbot history (list of dicts), generation params
+         # Outputs: Chatbot history (updated list of dicts), visualization, response text
        generation_inputs = [
+             chatbot_ui, # Pass chatbot state directly (list of dicts)
            gen_length, steps, constraints_input,
            temperature, top_p, top_k, remasking_strategy, alg_temp,
            visualization_delay
        ]
+         generation_outputs = [chatbot_ui, output_vis, response_text_display]

+         # --- Connect UI elements ---
+
+         # Textbox Submission (Enter key)
        submit_listener = user_input.submit(
+             fn=add_user_message,
+             inputs=[user_input, chatbot_ui],
+             outputs=[chatbot_ui, user_input] # Update chatbot UI and clear input
        ).then(
            fn=generate_dream_response,
            inputs=generation_inputs,
+             outputs=generation_outputs,
+             show_progress="hidden" # Hide default progress bar
        )

+         # Send Button Click
        click_listener = send_btn.click(
+             fn=add_user_message,
+             inputs=[user_input, chatbot_ui],
+             outputs=[chatbot_ui, user_input] # Update chatbot UI and clear input
        ).then(
            fn=generate_dream_response,
            inputs=generation_inputs,
+             outputs=generation_outputs,
            show_progress="hidden"
        )

        # Clear Button Action
        clear_btn.click(
+             lambda: ([], [], ""), # Function to return empty values
            inputs=[],
+             outputs=[chatbot_ui, output_vis, response_text_display], # Clear chatbot, vis, text
+             queue=False # No need to queue clearing usually
        )

    return demo
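
For context on the central change in this commit (the chatbot history now kept as a list of role/content dicts and passed straight to the template), the following is a minimal, illustrative sketch of how such a messages-format history becomes model inputs. It is not part of the commit; the user message is made up, and the apply_chat_template arguments simply mirror those used in app.py above.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)

    # History in the list-of-dicts ("messages") format that gr.Chatbot(type="messages") maintains
    # and generate_dream_response now receives. The content string is hypothetical.
    history = [{"role": "user", "content": "Write a haiku about diffusion."}]

    inputs = tokenizer.apply_chat_template(
        history,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,  # appends the assistant turn that the diffusion loop then fills in
    )
    print(inputs.input_ids.shape)  # (1, prompt_length); gen_length MASK_ID tokens are concatenated after this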