# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import copy

# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Model and Tokenizer Loading ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading tokenizer from {model_path}...")

# Load the configuration first to get special token IDs
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print(f"Loading model from {model_path}...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device).eval()
print("Model loaded successfully.")
# --- Constants from Dream Model ---
# Prefer IDs from the config, falling back to the tokenizer. A plain hasattr()
# check is not enough here: the attribute may exist on the config but be None.
MASK_TOKEN = tokenizer.mask_token
MASK_ID = config.mask_token_id if getattr(config, 'mask_token_id', None) is not None else tokenizer.mask_token_id
EOS_ID = config.eos_token_id if getattr(config, 'eos_token_id', None) is not None else tokenizer.eos_token_id
PAD_ID = config.pad_token_id if getattr(config, 'pad_token_id', None) is not None else tokenizer.pad_token_id  # Often the same as EOS

print(f"MASK_TOKEN: '{MASK_TOKEN}', MASK_ID: {MASK_ID}")
print(f"EOS_ID: {EOS_ID}, PAD_ID: {PAD_ID}")

if MASK_ID is None:
    raise ValueError("Could not determine MASK_ID from model config or tokenizer.")
if EOS_ID is None:
    raise ValueError("Could not determine EOS_ID from model config or tokenizer.")
if PAD_ID is None:
    raise ValueError("Could not determine PAD_ID from model config or tokenizer.")
# --- Helper Functions ---
def parse_constraints(constraints_text, tokenizer):
    """Parse constraints in the format: 'position:word, position:word, ...'"""
    constraints = {}
    processed_constraints_tokens = {}
    if not constraints_text:
        return constraints, processed_constraints_tokens

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                # Store the original word constraint for display/debugging
                constraints[pos] = word
                # Tokenize the word (prepend a space unless it starts the generation).
                # Note: the Dream tokenizer might handle spaces differently; adjust if needed.
                prefix = " " if pos > 0 else ""
                tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
                for i, token_id in enumerate(tokens):
                    # Don't partially overwrite an earlier multi-token constraint
                    if pos + i not in processed_constraints_tokens:
                        processed_constraints_tokens[pos + i] = token_id
        except ValueError:
            continue
        except Exception as e:
            print(f"Error tokenizing constraint word '{word}': {e}")
            continue

    # Sort by position for consistent application
    processed_constraints_tokens = dict(sorted(processed_constraints_tokens.items()))
    print(f"Parsed Constraints (Word): {constraints}")
    print(f"Parsed Constraints (Tokens): {processed_constraints_tokens}")
    return constraints, processed_constraints_tokens
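# Illustrative usage (a sketch; the actual token IDs depend on the Dream tokenizer):
#   words, token_map = parse_constraints("0:Once, 5:upon", tokenizer)
#   # words     -> {0: 'Once', 5: 'upon'}
#   # token_map -> {0: <id of "Once">, 5: <first id of " upon">, ...}
#   # A multi-token word occupies consecutive positions starting at the given index.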
def format_chat_history(history):
    """
    Format chat history for the Dream model using its chat template convention.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted list of message dictionaries for the model
    """
    messages = []
    # Prepend a system prompt if the history doesn't already start with one (standard practice)
    if not history or history[0][0] is None or history[0][0].lower() != "system":
        messages.append({"role": "system", "content": "You are a helpful assistant."})

    for user_msg, assistant_msg in history:
        if user_msg is not None:  # Handle a potential system-message entry
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (the latest user message has no reply yet)
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
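# Illustrative usage (a sketch):
#   format_chat_history([["Hi", "Hello! How can I help?"], ["Tell me a joke", None]])
#   -> [{"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"},
#       {"role": "user", "content": "Tell me a joke"}]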
# --- Core Generation Logic with Visualization Hook ---
# Note: on ZeroGPU Spaces the GPU is only attached while a function decorated with
# @spaces.GPU runs; the decorator is assumed to be intended here, since `spaces` is
# imported and the Space runs on Zero.
@spaces.GPU
def generate_response_with_visualization(
    messages,                  # List of message dictionaries
    gen_length=64,
    steps=64,
    constraints_text="",       # Raw constraint text
    temperature=0.2,
    top_p=0.95,
    top_k=None,                # Added for Dream
    alg="entropy",             # Changed from `remasking`
    alg_temp=0.0,              # Added for Dream
    visualization_delay=0.05,
    tokenizer=tokenizer,
    model=model,
    device=device,
    MASK_ID=MASK_ID,
    EOS_ID=EOS_ID,
    PAD_ID=PAD_ID
):
    """
    Generate text with the Dream model, capturing per-step states for visualization via a hook.
    """
    visualization_states = []
    final_text = ""

    # Use a dict to hold previous_x so the hook (a closure) can update it.
    # It starts as None and is set on the first hook call.
    shared_state = {'previous_x': None}
    try:
        # --- 1. Prepare Inputs ---
        _, parsed_constraints_tokens = parse_constraints(constraints_text, tokenizer)

        # Apply the chat template.
        # Important: keep tokenize=False here so the prompt length can be measured after tokenization.
        # The template adds roles and special tokens such as <|im_start|>.
        chat_input_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # Adds the prompt for the assistant's turn
            tokenize=False
        )

        # Tokenize the full templated chat string
        inputs = tokenizer(chat_input_text, return_tensors="pt", return_dict=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)  # Use the mask from the tokenizer

        prompt_length = input_ids.shape[1]
        total_length = prompt_length + gen_length

        # --- 2. Initialize Generation Sequence ---
        # Start with the prompt and fill the rest with MASK_ID
        x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
        x[:, :prompt_length] = input_ids.clone()
        attention_mask = F.pad(attention_mask, (0, gen_length), value=1)  # Extend the attention mask

        # Apply initial constraints to the masked sequence `x`
        for pos, token_id in parsed_constraints_tokens.items():
            absolute_pos = prompt_length + pos
            if absolute_pos < total_length:
                print(f"Applying initial constraint at pos {absolute_pos}: token {token_id}")
                x[:, absolute_pos] = token_id

        # Store the initial state (prompt + all masked) for visualization.
        initial_state_vis = []
        # Prompt tokens could also be shown (e.g., grayed out); skipped here:
        # for i in range(prompt_length):
        #     token_str = tokenizer.decode([x[0, i].item()], skip_special_tokens=True)
        #     initial_state_vis.append((token_str if token_str else " ", "#AAAAAA"))  # Gray for prompt
        # Add masked tokens for the generation part
        for _ in range(gen_length):
            initial_state_vis.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
        visualization_states.append(initial_state_vis)

        shared_state['previous_x'] = x.clone()  # Initialize previous_x
        # --- 3. Define the Visualization Hook ---
        def generation_tokens_hook_func(step, current_x_hook, logits):
            current_x_hook = current_x_hook.clone()  # Work on a copy

            # --- Apply constraints within the hook ---
            # This keeps constraints in place even if the model tries to overwrite them.
            for pos, token_id in parsed_constraints_tokens.items():
                absolute_pos = prompt_length + pos
                if absolute_pos < total_length:
                    current_x_hook[:, absolute_pos] = token_id
            # --- End constraint application ---

            if shared_state['previous_x'] is None:  # First call
                shared_state['previous_x'] = current_x_hook.clone()
                return current_x_hook  # Must return the (potentially modified) sequence

            # Build the visualization state for this step
            current_state_vis = []
            prev_x_step = shared_state['previous_x']

            for i in range(gen_length):
                pos = prompt_length + i  # Absolute position in the sequence
                current_token_id = current_x_hook[0, pos].item()
                prev_token_id = prev_x_step[0, pos].item()

                # Decode the token, hiding special tokens we don't want to show
                token_str = ""
                color = "#444444"  # Default: dark gray (mask)
                token_str_raw = tokenizer.decode([current_token_id], skip_special_tokens=False)  # Keep special tokens for the ID check

                if current_token_id == MASK_ID:
                    token_str = MASK_TOKEN
                    color = "#444444"  # Dark gray
                elif current_token_id == EOS_ID or current_token_id == PAD_ID:
                    token_str = ""     # Hide EOS/PAD visually
                    color = "#DDDDDD"  # Light gray (or transparent if possible)
                else:
                    # Decode without special tokens for display if it's not MASK/EOS/PAD
                    token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
                    if not token_str:
                        token_str = token_str_raw  # Fallback if strip removes everything (e.g., a space)
                    if prev_token_id == MASK_ID:
                        # Newly revealed in this step
                        color = "#66CC66"  # Light green (simplified from confidence levels)
                    else:
                        # Previously revealed
                        color = "#6699CC"  # Light blue

                current_state_vis.append((token_str if token_str else " ", color))  # Ensure a non-empty tuple element

            visualization_states.append(current_state_vis)
            shared_state['previous_x'] = current_x_hook.clone()  # Update previous_x for the next step
            return current_x_hook  # Return the sequence (constraints applied)
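        # Hook contract (as assumed from Dream's generation utilities): the function is
        # called with (step, current full sequence, logits) at each denoising step, and
        # the sequence it returns replaces the model's working sequence. That replacement
        # is what makes the constraint re-application above stick across steps.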
        # --- 4. Run Diffusion Generation ---
        print("Starting diffusion generation...")
        start_time = time.time()

        output = model.diffusion_generate(
            input_ids=x[:, :prompt_length],   # Pass only the initial prompt; diffusion_generate
                                              # handles the masking internally based on MASK_ID
            attention_mask=attention_mask,    # Provide the full (extended) attention mask
            max_new_tokens=gen_length,
            output_history=False,             # History is captured via the hook instead
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alg=alg,
            alg_temp=alg_temp if alg != 'origin' else None,  # alg_temp only applies to confidence-based algorithms
            # Pass the hook function
            generation_tokens_hook_func=generation_tokens_hook_func,
            # If the internal logic needs the pre-masked sequence, passing `x` directly may be
            # required instead; check Dream's generation_utils if issues arise. For now, assume
            # input_ids + max_new_tokens is sufficient.
        )

        end_time = time.time()
        print(f"Diffusion generation finished in {end_time - start_time:.2f} seconds.")
        # --- 5. Process Final Output ---
        # The hook has already built visualization_states
        final_sequence = output.sequences[0]

        # Decode only the generated part
        response_tokens = final_sequence[prompt_length:]

        # Filter out PAD tokens before the final decode; EOS and other specials are
        # dropped by skip_special_tokens below.
        response_tokens_cleaned = [tok for tok in response_tokens if tok != PAD_ID]

        final_text = tokenizer.decode(
            response_tokens_cleaned,
            skip_special_tokens=True,           # Skip EOS, BOS, etc.
            clean_up_tokenization_spaces=True   # Recommended for cleaner output
        ).strip()

        # Debug check: the last visualization state should match the final text
        # print(f"Last Vis State Tokens: {''.join([t[0] for t in visualization_states[-1]]).strip()}")
        # print(f"Final Decoded Text: {final_text}")

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Add the error message to the visualization
        error_msg = f"Error: {str(e)}"
        visualization_states.append([(error_msg, "red")])
        final_text = error_msg  # Display the error in the chatbot too

    # Make sure at least one state is present
    if not visualization_states:
        visualization_states.append([("Error: No states generated.", "red")])

    return visualization_states, final_text
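# Illustrative standalone usage (a sketch; it runs the full diffusion loop, so the
# model and tokenizer loaded above must be available):
#   msgs = [{"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Write a haiku about rain."}]
#   states, text = generate_response_with_visualization(msgs, gen_length=64, steps=64)
#   # `states` is a list of [(token_str, color), ...] snapshots, one per denoising step;
#   # `text` is the final decoded response.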
# --- Gradio UI Definition ---
css = '''
.category-legend{display:none}
button{height: 60px}
.token-text { white-space: pre; } /* Preserve spaces in tokens */
footer { display: none !important; visibility: hidden !important; }
'''
def create_chatbot_demo():
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
            "[[Blog Post](https://hkunlp.github.io/blog/2025/dream/)] "
            "(Note: the visualization shows token-reveal steps; colors indicate status: Gray = masked, Green = newly revealed, Blue = previously revealed.)"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])
        # Store constraints parsed into token IDs
        parsed_constraints_state = gr.State({})

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    bubble_full_width=False  # Makes bubbles wrap their content
                )

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                            scale=7
                        )
                        send_btn = gr.Button("Send", scale=1)

                constraints_input = gr.Textbox(
                    label="Word Constraints (Experimental)",
                    info="Place specific words at positions (0-indexed). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'. Multi-token words are supported.",
                    placeholder="0:The, 10:story",
                    value=""
                )

            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=False,           # The legend is not very informative here
                    height=560,                  # Roughly match chatbot height + input box
                    elem_classes=["token-text"]  # Custom class for styling
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,
                    label="Max New Tokens"
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=4,
                    label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.2, step=0.05,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P"
                )
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0 = disabled)"
                )
            with gr.Row():
                alg = gr.Radio(
                    choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',
                    label="Sampling Algorithm (`alg`)"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.0, step=0.05,
                    label="Algorithm Temp (`alg_temp`, adds randomness to confidence-based `alg`)"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.02, step=0.01,
                    label="Visualization Delay (seconds)"
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")
        # --- Event Handlers ---
        def add_message(history, message, response):
            """Append a [message, response] pair to the history and return it"""
            # Ensure history is a list
            if not isinstance(history, list):
                history = []
            history.append([message, response])
            return history

        def user_message_submitted(message, history):
            """Process a submitted user message"""
            if not message.strip():
                return history, history, "", []  # No change if empty

            # Add the user message (response is None for now)
            history = add_message(history, message, None)

            # Return the updated history for display and clear the input box
            return history, history, "", []  # history, chatbot_ui, user_input, output_vis
        def bot_response_stream(
            history,                                   # Current chat history (list of [user, bot] pairs)
            gen_length, steps, constraints,            # Generation settings
            temperature, top_p, top_k, alg, alg_temp,  # Sampling settings
            delay                                      # Visualization delay
        ):
            """Generate the bot response and stream visualization states"""
            # Only generate if the last turn is a user message without a response yet
            if not history or history[-1][1] is not None:
                print("Skipping bot response generation: No new user message.")
                # Yield an empty state to keep downstream components consistent
                yield history, []
                return

            # Format messages for the model (format_chat_history adds the system prompt)
            messages_for_model = format_chat_history(history)

            print("\n--- Generating Bot Response ---")
            print(f"History: {history}")
            print(f"Messages for model: {messages_for_model}")
            print(f"Constraints text: '{constraints}'")
            print(f"Gen length: {gen_length}, Steps: {steps}, Temp: {temperature}, Top-P: {top_p}, Top-K: {top_k}, Alg: {alg}, Alg Temp: {alg_temp}")

            # Call the generation function
            vis_states, response_text = generate_response_with_visualization(
                messages_for_model,
                gen_length=gen_length,
                steps=steps,
                constraints_text=constraints,
                temperature=temperature,
                top_p=top_p if top_p < 1.0 else None,  # None disables top-p
                top_k=top_k if top_k > 0 else None,    # None disables top-k
                alg=alg,
                alg_temp=alg_temp,
                visualization_delay=delay,
                # tokenizer, model, device, etc. default to the module-level globals
            )

            print(f"Generated response text: '{response_text}'")
            print(f"Number of visualization states: {len(vis_states)}")

            # Update the history with the final response
            # (mutating the list in place also updates the gr.State object)
            if history:
                history[-1][1] = response_text
            else:
                print("Warning: History was empty when trying to update response.")

            # Stream the visualization states
            if not vis_states:
                print("Warning: No visualization states were generated.")
                yield history, [("Error: No visualization.", "red")]
                return

            for state in vis_states:
                yield history, state  # Updated history for the chatbot, current visualization state
                time.sleep(delay)     # Pause between steps

            # Final yield to make sure the last state is displayed
            yield history, vis_states[-1]
        def clear_conversation():
            """Clear the conversation history and visualization"""
            return [], [], "", []  # history, chatbot, user_input, output_vis

        # --- Event Wiring ---

        # Clear button
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )
        # User message submission flow (2-step using .then)
        # 1. User submits a message -> update the history and chatbot UI immediately
        submit_action = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]  # Update chatbot, clear input
        )

        # Connect the send button to the same function
        send_action = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )

        # 2. After the UI update -> trigger bot response generation and streaming,
        #    using the chat_history updated in the first step.
        submit_action.then(
            fn=bot_response_stream,
            inputs=[
                chat_history, gen_length, steps, constraints_input,
                temperature, top_p, top_k, alg, alg_temp,
                visualization_delay
            ],
            outputs=[chatbot_ui, output_vis]  # Update chatbot and visualization
        )
        send_action.then(
            fn=bot_response_stream,
            inputs=[
                chat_history, gen_length, steps, constraints_input,
                temperature, top_p, top_k, alg, alg_temp,
                visualization_delay
            ],
            outputs=[chatbot_ui, output_vis]  # Update chatbot and visualization
        )

        # Clearing the input box after send/submit is already handled in user_message_submitted
    return demo


# --- Launch the Gradio App ---
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # Use a queue for streaming and for handling multiple users
    demo.queue().launch(debug=True, share=True)