# llada_app.py -> dream_app.py
import torch
import numpy as np
import gradio as gr
# import torch.nn.functional as F # Not needed for DREAM's basic visualization
from transformers import AutoTokenizer, AutoModel
import time
import re  # Currently unused; constraint parsing below uses str.split

# Use try-except for Spaces deployment vs local execution
try:
    # Used for Spaces deployment with GPU; falls back when the `spaces`
    # package or its GPU decorator is unavailable (e.g. plain local runs)
    import spaces
    gpu_check = spaces.GPU
    print("Running in Gradio Spaces with GPU environment.")
except (ImportError, AttributeError):
    # Fallback for local execution or environments without spaces.GPU
    print("Running in local environment or without spaces.GPU.")
    # Define a dummy decorator if spaces.GPU is not available
    def gpu_check(func):
        return func
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Load DREAM Model and Tokenizer ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading model: {model_path}")
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Model and tokenizer loaded.")
# --- Constants for DREAM ---
# Find the mask token and ID from the DREAM tokenizer
if tokenizer.mask_token is None:
    # Handle cases where a mask token might not be explicitly set.
    # You might need to choose a suitable placeholder or investigate further.
    # For now, try adding one if it's missing and check its id.
    # This is speculative and might depend on the specific tokenizer setup.
    print("Warning: Mask token not found in tokenizer. Attempting to add.")
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})
    model.resize_token_embeddings(len(tokenizer)) # Important if vocab size changed
    if tokenizer.mask_token is None:
        raise ValueError("Could not set a mask token for the tokenizer.")
MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
# --- Helper Functions (Constraint Parsing, History Formatting) ---
def parse_constraints(constraints_text):
"""Parse constraints in format: 'position:word, position:word, ...'"""
constraints = {}
if not constraints_text:
return constraints
parts = constraints_text.split(',')
for part in parts:
part = part.strip() # Trim whitespace
if ':' not in part:
continue
try:
pos_str, word = part.split(':', 1)
pos = int(pos_str.strip())
word = word.strip()
# Allow empty words if needed, but usually we want a word
if word and pos >= 0:
constraints[pos] = word
except ValueError:
print(f"Warning: Could not parse constraint part: '{part}'")
continue
return constraints
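
# Illustrative example of the constraint format parsed above (mirrors the hint shown in the UI):
#   parse_constraints("0:Once, 5:upon, 10:time") -> {0: "Once", 5: "upon", 10: "time"}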
def format_chat_history(history):
"""
Format chat history for the DREAM model (standard messages format)
Args:
history: List of [user_message, assistant_message] pairs
Returns:
Formatted conversation for the model (list of dictionaries)
"""
messages = []
# Add system prompt if desired (check DREAM examples/recommendations)
# messages.append({"role": "system", "content": "You are a helpful assistant."}) # Optional
for user_msg, assistant_msg in history:
if user_msg: # Handle potential None message if clearing failed
messages.append({"role": "user", "content": user_msg})
if assistant_msg: # Skip if None (for the latest user message awaiting response)
messages.append({"role": "assistant", "content": assistant_msg})
return messages
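
# Illustrative example (hypothetical history) of the conversion above; a trailing None response is skipped:
#   format_chat_history([["Hi there!", "Hello! How can I help?"], ["Tell me a joke", None]])
#   -> [{"role": "user", "content": "Hi there!"},
#       {"role": "assistant", "content": "Hello! How can I help?"},
#       {"role": "user", "content": "Tell me a joke"}]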
# --- Core Generation Logic for DREAM with Visualization ---
@gpu_check # Use the potentially dummy decorator
def dream_generate_response_with_visualization(
    messages,
    gen_length=64,
    steps=64, # Default based on DREAM examples
    constraints=None,
    temperature=0.6, # Default based on DREAM examples
    top_p=0.95, # Default based on DREAM examples
    alg="entropy", # Default based on DREAM examples
    alg_temp=0.0, # Default based on DREAM examples
):
    """
    Generate text with the DREAM model, with visualization via the generation hook.

    Args:
        messages: List of message dictionaries with 'role' and 'content'
        gen_length: Length of text to generate (max_new_tokens)
        steps: Number of diffusion steps
        constraints: Dictionary mapping positions (relative to response start) to words
        temperature: Sampling temperature
        top_p: Nucleus sampling p
        alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
        alg_temp: Temperature for confidence-based algorithms

    Returns:
        Tuple: (list of visualization states, final generated text string)
    """
    print("--- Starting DREAM Generation ---")
    print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
    print(f"Constraints: {constraints}")

    # --- Input Preparation ---
    if constraints is None:
        constraints = {}

    # Convert word constraints to token IDs (handle multi-token words)
    processed_constraints = {}
    print("Processing constraints:")
    for pos, word in constraints.items():
        # Prepend a space for consistent tokenization, similar to the LLaDA example
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        if not tokens:
            print(f"  Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
            continue
        print(f"  Pos {pos}, Word '{word}' -> Tokens {tokens}")
        for i, token_id in enumerate(tokens):
            # Ensure we don't overwrite parts of multi-token constraints accidentally
            if pos + i not in processed_constraints:
                processed_constraints[pos + i] = token_id
            else:
                print(f"  Warning: Overlapping constraint at position {pos + i}. Keeping first.")
    # Prepare the prompt using the chat template
    # Note: DREAM examples use add_generation_prompt=True
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True # Crucial for instruction-tuned models like Dream-Instruct
        )
        input_ids = inputs.input_ids.to(device=device)
        attention_mask = inputs.attention_mask.to(device=device) # Get attention mask
        prompt_length = input_ids.shape[1]
        print(f"Input prompt length: {prompt_length}")
        print(f"Input IDs: {input_ids}")
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Provide a fallback or raise the error.
        # Fallback: simple concatenation (less ideal for instruction models)
        # chat_input = "".join([f"{msg['role']}: {msg['content']}\n" for msg in messages]) + "assistant:"
        # input_ids = tokenizer(chat_input, return_tensors="pt").input_ids.to(device)
        # attention_mask = torch.ones_like(input_ids)
        # prompt_length = input_ids.shape[1]
        # print(f"Warning: Using basic concatenation due to template error. Prompt length: {prompt_length}")
        return [[("Error applying chat template.", "red")]], f"Error: {e}"

    if prompt_length + gen_length > 2048: # Check context length (DREAM uses 2048)
        print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
        gen_length = 2048 - prompt_length
        if gen_length <= 0:
            print("Error: Prompt is already too long.")
            return [[("Prompt too long.", "red")]], "Error: Prompt too long."

    # --- State for Visualization Hook ---
    visualization_states = []
    last_x = None # Store the sequence from the previous step

    # Initial state: prompt + all masks
    initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
    # Apply initial constraints to the masked part *before* showing the first state
    for pos, token_id in processed_constraints.items():
        absolute_pos = pos # Position relative to the start of the generation
        if 0 <= absolute_pos < gen_length:
            initial_x_part[0, absolute_pos] = token_id

    initial_state_vis = []
    for i in range(gen_length):
        token_id = initial_x_part[0, i].item()
        if token_id == MASK_ID:
            initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
        else:
            # This must be a constraint applied initially
            token_str = tokenizer.decode([token_id], skip_special_tokens=True)
            initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
    visualization_states.append(initial_state_vis)

    # --- Define the Hook Function ---
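    # Hook contract (assumption, based on Dream's remote generation code loaded via trust_remote_code):
    # the hook is expected to be called at each diffusion step with (step, x, logits), where x is the
    # full [prompt + generation] token sequence, and the sequence it returns is what the model uses for
    # the next step. It is used here both to record a visualization state and to re-apply constraints.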
    def generation_tokens_hook_func(step, x, logits):
        nonlocal last_x, visualization_states # Allow modification of outer scope variables
        print(f"Hook called for step {step}")

        current_x = x.clone() # Work on a copy for comparison

        # 1. Apply constraints *before* generating the visualization.
        #    Constraints are relative to the start of the *generated* part.
        constrained_x = current_x.clone()
        prompt_len = current_x.shape[1] - gen_length # Recalculate just in case
        if prompt_len < 0:
            print("Warning: prompt_len negative in hook, skipping constraints/vis.")
            return current_x # Return unmodified if something is wrong

        constraints_applied_this_step = False
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_len + pos
            if prompt_len <= absolute_pos < current_x.shape[1]:
                if constrained_x[0, absolute_pos] != token_id:
                    constrained_x[0, absolute_pos] = token_id
                    constraints_applied_this_step = True
                    # print(f"  Constraint applied at pos {pos} ({absolute_pos}) -> token {token_id}")

        # 2. Generate the visualization state for *this* step.
        current_state_vis = []
        # Compare current_x (before explicit constraint application in *this* hook call)
        # with last_x (state from the *previous* hook call / initial state).
        # Generate based on the state *before* reapplying constraints here,
        # but *after* the model's diffusion step determined current_x.
        gen_part_current = current_x[0, prompt_len:]
        gen_part_last = last_x[0, prompt_len:] if last_x is not None else None

        for i in range(gen_length):
            current_token_id = gen_part_current[i].item()
            token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
            # Use a placeholder if decoding results in an empty string
            display_token = token_str if token_str else (MASK_TOKEN if current_token_id == MASK_ID else "?")

            # Check if this position is constrained
            is_constrained = i in processed_constraints

            if current_token_id == MASK_ID:
                color = "#444444" # Dark gray for masks
            elif is_constrained and processed_constraints[i] == current_token_id:
                color = "#800080" # Purple for correctly constrained tokens
            elif gen_part_last is None or gen_part_last[i].item() == MASK_ID:
                # Newly revealed (was a mask in the previous step or initial state)
                color = "#66CC66" # Light green
            else:
                # Previously revealed and not masked
                color = "#6699CC" # Light blue

            current_state_vis.append((display_token, color))

        visualization_states.append(current_state_vis)

        # 3. Update last_x for the *next* step's comparison.
        #    Store the state *after* applying constraints for accurate comparison next time.
        last_x = constrained_x.clone()

        # 4. Return the sequence with constraints applied for the model's next step.
        # print(f"Hook returning constrained_x: {constrained_x[:, prompt_len:]}")
        return constrained_x # Return the sequence with constraints enforced
    # --- Run DREAM Generation ---
    try:
        print("Calling model.diffusion_generate...")
        # Make sure last_x is initialized correctly before the first hook call.
        # It should represent the state *before* the first diffusion step.
        initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
        last_x = initial_full_x.clone()

        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=gen_length,
            output_history=False, # We build history in the hook
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != "origin" else 0.0, # alg_temp only for confidence algs
            generation_tokens_hook_func=generation_tokens_hook_func
        )
        print("model.diffusion_generate finished.")

        # Extract the final generated sequence (response part only).
        # The hook ensures the returned sequence has constraints applied.
        final_sequence = output.sequences[0]
        response_token_ids = final_sequence[prompt_length:]

        # Decode the final response
        final_text = tokenizer.decode(
            response_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True # Recommended for cleaner output
        ).strip()
        print(f"Final generated text: {final_text}")

        # Add the very final state to the visualization if the hook didn't capture it
        # (should be captured, but as a safeguard)
        if len(visualization_states) <= steps: # Hook might run 'steps' times
            final_state_vis = []
            final_gen_part = final_sequence[prompt_length:]
            for i in range(gen_length):
                token_id = final_gen_part[i].item()
                token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
                display_token = token_str if token_str else (MASK_TOKEN if token_id == MASK_ID else "?")
                is_constrained = i in processed_constraints
                if token_id == MASK_ID:
                    color = "#444444"
                elif is_constrained and processed_constraints[i] == token_id:
                    color = "#800080"
                else:
                    color = "#6699CC" # Default to blue for final state tokens
                final_state_vis.append((display_token, color))
            visualization_states.append(final_state_vis)

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Add the error message to the visualization
        error_msg = f"Error during generation: {str(e)}"
        visualization_states.append([(error_msg, "red")])
        final_text = f"Generation failed: {e}"

    print("--- DREAM Generation Finished ---")
    return visualization_states, final_text
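
# Illustrative standalone call (assumes the model and tokenizer above loaded successfully):
#   vis_states, text = dream_generate_response_with_visualization(
#       [{"role": "user", "content": "Write a haiku about spring."}],
#       gen_length=64, steps=64,
#   )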
# --- Gradio UI Setup ---
css = '''
.category-legend{display:none}
/* button{height: 60px} */ /* Optional: Adjust button height */
.small_btn {
    max-width: 100px; /* Adjust as needed */
    height: 40px; /* Adjust as needed */
    flex-grow: 0; /* Prevent button from growing */
    margin-left: 5px; /* Add some space */
}
.chat-input-row {
    display: flex;
    align-items: center; /* Vertically align items */
}
.chat-input-row > * {
    margin-right: 5px; /* Space between textbox and button */
}
.chat-input-row > *:last-child {
    margin-right: 0;
}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown("A demonstration of the Dream 7B diffusion-based language model. Watch the text generate step-by-step.")
        gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    bubble_full_width=False # Improves layout for shorter messages
                )
                # Message input row
                with gr.Row(elem_classes="chat-input-row"):
                    user_input = gr.Textbox(
                        label="Your Message",
                        placeholder="Type your message here and press Enter...",
                        scale=4, # Give the textbox more space
                        container=False, # Remove container background/padding
                        show_label=False
                    )
                    send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Force specific words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
                    placeholder="e.g., 0:Hello, 6:world",
                    value="" # Default empty
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=False, # Keep legend off as requested
                    # Color map for legend (though hidden)
                    # color_map={
                    #     "Mask": "#444444",
                    #     "New": "#66CC66",
                    #     "Old": "#6699CC",
                    #     "Constraint": "#800080",
                    #     "Error": "red"
                    # }
                )
                gr.Markdown(
                    "**Color Legend:** <span style='color:#444444'>■ Mask</span> | <span style='color:#66CC66'>■ Newly Generated</span> | <span style='color:#6699CC'>■ Previously Generated</span> | <span style='color:#800080'>■ Constraint</span>"
                )
        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8, # Increased max length
                    label="Max New Tokens"
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=8, # Increased max steps
                    label="Diffusion Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.6, step=0.05, # Wider range for temperature
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P (Nucleus Sampling)"
                )
            with gr.Row():
                # Map UI choices to DREAM's alg parameters
                remasking_strategy = gr.Radio(
                    choices=[
                        ("Random", "origin"), # User-friendly name -> actual param
                        ("Entropy", "entropy"),
                        ("MaskGit+", "maskgit_plus"),
                        ("TopK Margin", "topk_margin"),
                    ],
                    value="entropy", # Default
                    label="Generation Order Strategy (alg)"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Order Randomness (alg_temp)",
                    info="Adds randomness to non-Random strategies. Ignored for Random."
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.05, step=0.01,
                    label="Visualization Delay (seconds)"
                )
        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # Hidden textbox to potentially store intermediate response (might not be needed)
        # current_response = gr.Textbox(visible=False)

        # --- Event Handlers ---

        # Helper to add a message to the history state
        def add_message_to_history(history, message, response):
            history = history.copy() # Modify a copy
            history.append([message, response])
            return history

        # Function called when the user submits a message (Enter or Send button)
        def user_message_submitted(message, history):
            print(f"User submitted: '{message}'")
            if not message or not message.strip():
                print("Empty message submitted, doing nothing.")
                # Return unchanged state if the message is empty.
                # Values must be returned for all outputs of the .submit/.click.
                return history, history, "", [] # history, chatbot_ui, user_input, output_vis
            # Add the user message to history (with None for the bot response initially)
            history = add_message_to_history(history, message, None)
            # Prepare updated history for display in the Chatbot UI
            history_for_display = history.copy()
            # Clear the input textbox
            message_out = ""
            # Clear the visualization
            vis_clear = []
            # Return updated history state, chatbot display, cleared input, cleared visualization
            return history, history_for_display, message_out, vis_clear
        # Function to generate the bot response (triggered after the user message is processed)
        def bot_response_generator(
            history, gen_length, steps, constraints_text, delay,
            temperature, top_p, alg, alg_temp
        ):
            print("--- Generating Bot Response ---")
            if not history or history[-1][1] is not None:
                print("History empty or last message already has a response. Skipping generation.")
                # Yield the current state if called unnecessarily (the outputs are chatbot_ui and output_vis)
                yield history, []
                return

            # Get the conversation history in the format the model expects
            messages = format_chat_history(history) # Includes the latest user query
            # Parse constraints from the textbox
            parsed_constraints = parse_constraints(constraints_text)

            try:
                # Generate the response with visualization
                vis_states, response_text = dream_generate_response_with_visualization(
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temperature,
                    top_p=top_p,
                    alg=alg,
                    alg_temp=alg_temp
                )
                # Update the history state with the final bot response
                history[-1][1] = response_text.strip()

                # Yield the initial visualization state immediately
                if vis_states:
                    yield history, vis_states[0] # Update chatbot, update visualization
                else:
                    # Handle the case where generation failed before the first state
                    yield history, [("Generation failed.", "red")]

                # Then animate through the rest of the visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state # Update chatbot (implicitly via history), update visualization
            except Exception as e:
                print(f"Error in bot_response_generator: {e}")
                import traceback
                traceback.print_exc()
                error_msg = f"Error: {str(e)}"
                # Show the error in the visualization
                error_vis = [(error_msg, "red")]
                # Update history with error message? Optional.
                # history[-1][1] = error_msg
                yield history, error_vis

        # Function to clear everything
        def clear_conversation():
            print("Clearing conversation.")
            return [], [], "", [] # chat_history, chatbot_ui, user_input, output_vis
        # --- Wire UI elements to functions ---

        # Typing in the textbox and pressing Enter
        user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis], # Update history state, chatbot display, clear input, clear vis
            queue=False # Process immediately
        ).then(
            fn=bot_response_generator,
            inputs=[
                chat_history, gen_length, steps, constraints_input, visualization_delay,
                temperature, top_p, remasking_strategy, alg_temp
            ],
            outputs=[chatbot_ui, output_vis] # Update chatbot display (with new response), update visualization
            # Note: history state is updated implicitly by bot_response_generator modifying its input
        )

        # Clicking the Send button
        send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis],
            queue=False
        ).then(
            fn=bot_response_generator,
            inputs=[
                chat_history, gen_length, steps, constraints_input, visualization_delay,
                temperature, top_p, remasking_strategy, alg_temp
            ],
            outputs=[chatbot_ui, output_vis]
        )

        # Clicking the Clear button
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis],
            queue=False
        )

    return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
print("Creating Gradio demo...")
demo = create_chatbot_demo()
print("Launching Gradio demo...")
# Use queue for potentially long generation times
# share=True generates a public link (useful for Colab/Spaces)
demo.queue().launch(share=True, debug=True) # Add debug=True for more logs |