# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import copy
# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Model and Tokenizer Loading ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading tokenizer from {model_path}...")
# Load configuration first to get special token IDs
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print(f"Loading model from {model_path}...")
model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True
).to(device).eval()
print("Model loaded successfully.")
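# Note: the bfloat16 weights are intended for GPU execution; on a CPU-only machine the model
# still loads but runs very slowly (loading in float32 may be the safer choice there).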
# --- Constants from Dream Model ---
# Get IDs directly from config or tokenizer if available
MASK_TOKEN = tokenizer.mask_token
MASK_ID = config.mask_token_id if getattr(config, 'mask_token_id', None) is not None else tokenizer.mask_token_id
EOS_ID = config.eos_token_id if getattr(config, 'eos_token_id', None) is not None else tokenizer.eos_token_id
PAD_ID = config.pad_token_id if getattr(config, 'pad_token_id', None) is not None else tokenizer.pad_token_id # Often the same as EOS
print(f"MASK_TOKEN: '{MASK_TOKEN}', MASK_ID: {MASK_ID}")
print(f"EOS_ID: {EOS_ID}, PAD_ID: {PAD_ID}")
if MASK_ID is None:
raise ValueError("Could not determine MASK_ID from model config or tokenizer.")
if EOS_ID is None:
raise ValueError("Could not determine EOS_ID from model config or tokenizer.")
if PAD_ID is None:
raise ValueError("Could not determine PAD_ID from model config or tokenizer.")
# --- Helper Functions ---
def parse_constraints(constraints_text, tokenizer):
"""Parse constraints in format: 'position:word, position:word, ...'"""
constraints = {}
processed_constraints_tokens = {}
if not constraints_text:
return constraints, processed_constraints_tokens
parts = constraints_text.split(',')
for part in parts:
if ':' not in part:
continue
pos_str, word = part.split(':', 1)
try:
pos = int(pos_str.strip())
word = word.strip()
if word and pos >= 0:
# Store original word constraint for display/debugging if needed
constraints[pos] = word
# Prepend a space for non-initial positions so the word tokenizes as it would mid-sentence
# Note: the Dream tokenizer may handle leading spaces differently; adjust if needed
prefix = " " if pos > 0 else ""
tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
for i, token_id in enumerate(tokens):
# Prevent overwriting multi-token constraints partially
if pos + i not in processed_constraints_tokens:
processed_constraints_tokens[pos + i] = token_id
except ValueError:
continue
except Exception as e:
print(f"Error tokenizing constraint word '{word}': {e}")
continue
# Sort by position for consistent application
processed_constraints_tokens = dict(sorted(processed_constraints_tokens.items()))
print(f"Parsed Constraints (Word): {constraints}")
print(f"Parsed Constraints (Tokens): {processed_constraints_tokens}")
return constraints, processed_constraints_tokens
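# Example (token IDs are tokenizer-dependent and shown only schematically):
#   parse_constraints("0:Once, 5:upon", tokenizer)
#   -> ({0: 'Once', 5: 'upon'}, {0: <id of 'Once'>, 5: <id of ' upon'>, ...})
# Multi-token words occupy consecutive positions starting at the given index.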
def format_chat_history(history):
"""
Format chat history for the Dream model using its chat template convention.
Args:
history: List of [user_message, assistant_message] pairs
Returns:
Formatted list of message dictionaries for the model
"""
messages = []
# Prepend a default system prompt unless the history already begins with one (heuristic check below)
if not history or history[0][0] is None or history[0][0].lower() != "system":
messages.append({"role": "system", "content": "You are a helpful assistant."})
for user_msg, assistant_msg in history:
if user_msg is not None: # Handle potential system message case
messages.append({"role": "user", "content": user_msg})
if assistant_msg: # Skip if None (for the latest user message)
messages.append({"role": "assistant", "content": assistant_msg})
return messages
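# Example:
#   format_chat_history([["Hi", "Hello!"], ["How are you?", None]])
#   -> [{"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"}]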
# --- Core Generation Logic with Visualization Hook ---
@spaces.GPU
def generate_response_with_visualization(
messages, # List of message dictionaries
gen_length=64,
steps=64,
constraints_text="", # Raw constraint text
temperature=0.2,
top_p=0.95,
top_k=None, # Added for Dream
alg="entropy", # Changed from remasking
alg_temp=0.0, # Added for Dream
visualization_delay=0.05,
tokenizer=tokenizer,
model=model,
device=device,
MASK_ID=MASK_ID,
EOS_ID=EOS_ID,
PAD_ID=PAD_ID
):
"""
Generate text with Dream model with real-time visualization using a hook.
"""
visualization_states = []
final_text = ""
# Use a list to hold previous_x, allowing nonlocal modification
# Initialize with None, it will be set after the first hook call
shared_state = {'previous_x': None}
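# Each visualization state is a list of (token_text, color) tuples covering the generated
# positions in order; gr.HighlightedText treats the second element of each tuple as the span's label.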
try:
# --- 1. Prepare Inputs ---
_, parsed_constraints_tokens = parse_constraints(constraints_text, tokenizer)
# Apply chat template
# Important: Keep tokenize=False initially to get prompt length correctly
# The template adds roles and special tokens like <|im_start|> etc.
chat_input_text = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True, # Adds the prompt for the assistant's turn
tokenize=False
)
# Tokenize the full templated chat string
inputs = tokenizer(chat_input_text, return_tensors="pt", return_dict=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device) # Use mask from tokenizer
prompt_length = input_ids.shape[1]
total_length = prompt_length + gen_length
# --- 2. Initialize Generation Sequence ---
# Start with the prompt, pad the rest with MASK_ID
x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
x[:, :prompt_length] = input_ids.clone()
attention_mask = F.pad(attention_mask, (0, gen_length), value=1) # Extend attention mask
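# The padded region of the mask is set to 1 so the model attends to the gen_length
# positions that still hold MASK_ID.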
# Apply initial constraints to the masked sequence `x`
for pos, token_id in parsed_constraints_tokens.items():
absolute_pos = prompt_length + pos
if absolute_pos < total_length:
print(f"Applying initial constraint at pos {absolute_pos}: token {token_id}")
x[:, absolute_pos] = token_id
# Store initial state (prompt + all masked) for visualization
initial_state_vis = []
# Add prompt tokens (optional visualization, could be grayed out or skipped)
# for i in range(prompt_length):
# token_str = tokenizer.decode([x[0, i].item()], skip_special_tokens=True)
# initial_state_vis.append((token_str if token_str else " ", "#AAAAAA")) # Gray for prompt
# Add masked tokens for the generation part
for _ in range(gen_length):
initial_state_vis.append((MASK_TOKEN, "#444444")) # Dark gray for masks
visualization_states.append(initial_state_vis)
shared_state['previous_x'] = x.clone() # Initialize previous_x
# --- 3. Define the Visualization Hook ---
def generation_tokens_hook_func(step, current_x_hook, logits):
# nonlocal previous_x # Allow modification of the outer scope variable
current_x_hook = current_x_hook.clone() # Work on a copy
# --- Apply constraints within the hook ---
# This ensures constraints are respected even if the model tries to overwrite them
for pos, token_id in parsed_constraints_tokens.items():
absolute_pos = prompt_length + pos
if absolute_pos < total_length:
current_x_hook[:, absolute_pos] = token_id
# --- End Constraint Application ---
if shared_state['previous_x'] is None: # First call
shared_state['previous_x'] = current_x_hook.clone()
return current_x_hook # Must return the (potentially modified) sequence
# Generate visualization state for this step
current_state_vis = []
prev_x_step = shared_state['previous_x']
for i in range(gen_length):
pos = prompt_length + i # Absolute position in the sequence
current_token_id = current_x_hook[0, pos].item()
prev_token_id = prev_x_step[0, pos].item()
# Decode token, handling special tokens we want to hide
token_str = ""
color = "#444444" # Default: Dark Gray (Mask)
token_str_raw = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for ID check
if current_token_id == MASK_ID:
token_str = MASK_TOKEN
color = "#444444" # Dark gray
elif current_token_id == EOS_ID or current_token_id == PAD_ID:
token_str = "" # Hide EOS/PAD visually
color = "#DDDDDD" # Use a light gray or make transparent if possible
else:
# Decode without special tokens for display if it's not MASK/EOS/PAD
token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
if not token_str: token_str = token_str_raw # Fallback if strip removes everything (e.g., space)
if prev_token_id == MASK_ID:
# Newly revealed in this step
color = "#66CC66" # Light green (Simplified from confidence levels)
else:
# Previously revealed
color = "#6699CC" # Light blue
current_state_vis.append((token_str if token_str else " ", color)) # Ensure non-empty tuple element
visualization_states.append(current_state_vis)
shared_state['previous_x'] = current_x_hook.clone() # Update previous_x for the next step
return current_x_hook # Return the sequence (constraints applied)
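# Assumption: diffusion_generate calls this hook once per denoising step with the full token
# sequence and continues from the returned tensor; that is what lets the hook both record
# visualization states and re-impose the word constraints at every step.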
# --- 4. Run Diffusion Generation ---
print("Starting diffusion generation...")
start_time = time.time()
output = model.diffusion_generate(
input_ids=x[:, :prompt_length], # Pass only the initial prompt to diffusion_generate
# as it handles the masking internally based on MASK_ID
attention_mask=attention_mask, # Provide the full attention mask
max_new_tokens=gen_length,
output_history=False, # We capture history via the hook
return_dict_in_generate=True,
steps=steps,
temperature=temperature,
top_p=top_p,
top_k=top_k,
alg=alg,
alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence-based
# Pass the hook function
generation_tokens_hook_func=generation_tokens_hook_func,
# Depending on Dream's generation_utils, the pre-masked sequence `x` may need to be passed
# directly instead; for now, assume diffusion_generate rebuilds the masked region from
# input_ids + max_new_tokens (check generation_utils if issues arise)
)
end_time = time.time()
print(f"Diffusion generation finished in {end_time - start_time:.2f} seconds.")
# --- 5. Process Final Output ---
# The hook has already built visualization_states
final_sequence = output.sequences[0]
# Decode the generated part, skipping special tokens for the final text output
response_tokens = final_sequence[prompt_length:]
# Drop PAD tokens before decoding; EOS and other special tokens are removed by skip_special_tokens below
response_tokens_cleaned = [tok for tok in response_tokens if tok != PAD_ID]
final_text = tokenizer.decode(
response_tokens_cleaned,
skip_special_tokens=True, # Skip EOS, BOS, etc.
clean_up_tokenization_spaces=True # Recommended for cleaner output
).strip()
# Ensure the last state in visualization matches the final text (debug check)
# print(f"Last Vis State Tokens: {''.join([t[0] for t in visualization_states[-1]]).strip()}")
# print(f"Final Decoded Text: {final_text}")
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Add error message to visualization
error_msg = f"Error: {str(e)}"
visualization_states.append([(error_msg, "red")])
final_text = error_msg # Display error in the chatbot too
# Make sure at least the initial state is present
if not visualization_states:
visualization_states.append([("Error: No states generated.", "red")])
return visualization_states, final_text
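# Minimal usage sketch (relies on the module-level tokenizer/model defaults above):
#   states, text = generate_response_with_visualization(
#       [{"role": "user", "content": "Tell me a story"}], gen_length=64, steps=64)
#   `states` is the list of visualization frames and `text` is the decoded reply.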
# --- Gradio UI Definition ---
css = '''
.category-legend{display:none}
button{height: 60px}
.token-text { white-space: pre; } /* Preserve spaces in tokens */
footer { display: none !important; visibility: hidden !important; }
'''
def create_chatbot_demo():
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
gr.Markdown(
"[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
"[[Blog Post](https://hkunlp.github.io/blog/2025/dream/)] "
"(Note: Visualization shows token reveal steps, colors indicate status: Gray=Masked, Green=Newly Revealed, Blue=Previously Revealed)"
)
# STATE MANAGEMENT
chat_history = gr.State([])
# Store constraints parsed into token IDs
parsed_constraints_state = gr.State({})
# UI COMPONENTS
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(
label="Conversation",
height=500,
bubble_full_width=False # Makes bubbles wrap content
)
# Message input
with gr.Group():
with gr.Row():
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here...",
show_label=False,
scale=7
)
send_btn = gr.Button("Send", scale=1)
constraints_input = gr.Textbox(
label="Word Constraints (Experimental)",
info="Place specific words at positions (0-indexed). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'. Multi-token words supported.",
placeholder="0:The, 10:story",
value=""
)
with gr.Column(scale=2):
output_vis = gr.HighlightedText(
label="Denoising Process Visualization",
combine_adjacent=False,
show_legend=False, # Legend not very informative here
height=560, # Match chatbot height + input box approx
elem_classes=["token-text"] # Apply custom class for styling if needed
)
# Advanced generation settings
with gr.Accordion("Generation Settings", open=False):
with gr.Row():
gen_length = gr.Slider(
minimum=16, maximum=512, value=128, step=8,
label="Max New Tokens"
)
steps = gr.Slider(
minimum=8, maximum=512, value=128, step=4,
label="Denoising Steps"
)
with gr.Row():
temperature = gr.Slider(
minimum=0.0, maximum=1.0, value=0.2, step=0.05,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05,
label="Top-P"
)
top_k = gr.Slider(
minimum=0, maximum=200, value=0, step=5,
label="Top-K (0=disabled)"
)
with gr.Row():
alg = gr.Radio(
choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
value='entropy',
label="Sampling Algorithm (`alg`)"
)
alg_temp = gr.Slider(
minimum=0.0, maximum=1.0, value=0.0, step=0.05,
label="Algorithm Temp (`alg_temp`, adds randomness to confidence-based `alg`)"
)
with gr.Row():
visualization_delay = gr.Slider(
minimum=0.0, maximum=0.5, value=0.02, step=0.01,
label="Visualization Delay (seconds)"
)
# Clear button
clear_btn = gr.Button("Clear Conversation")
# --- Event Handlers ---
def add_message(history, message, response):
"""Add a message pair to the history and return the updated history"""
# Ensure history is a list
if not isinstance(history, list):
history = []
history.append([message, response])
return history
def user_message_submitted(message, history):
"""Process a submitted user message"""
if not message.strip():
return history, history, "", [] # No change if empty
# Add user message (response is None for now)
history = add_message(history, message, None)
# Return updated history for display, clear input box
return history, history, "", [] # history, chatbot_ui, user_input, output_vis
def bot_response_stream(
history, # Current chat history (list of lists)
gen_length, steps, constraints, # Generation settings
temperature, top_p, top_k, alg, alg_temp, # Sampling settings
delay # Visualization delay
):
"""Generate bot response and stream visualization states"""
if not history or history[-1][1] is not None: # Only generate when the last turn is an unanswered user message
print("Skipping bot response generation: No new user message.")
# Yield the current history and an empty visualization so the two outputs stay valid
yield history, []
return
# Format messages for the model; the pending user turn has no assistant reply yet,
# and format_chat_history simply skips that None response
messages_for_model = format_chat_history(history) # Already includes the system prompt
print("\n--- Generating Bot Response ---")
print(f"History: {history}")
print(f"Messages for model: {messages_for_model}")
print(f"Constraints text: '{constraints}'")
print(f"Gen length: {gen_length}, Steps: {steps}, Temp: {temperature}, Top-P: {top_p}, Top-K: {top_k}, Alg: {alg}, Alg Temp: {alg_temp}")
# Call the generation function
vis_states, response_text = generate_response_with_visualization(
messages_for_model,
gen_length=gen_length,
steps=steps,
constraints_text=constraints,
temperature=temperature,
top_p=top_p if top_p < 1.0 else None, # None disables top-p
top_k=top_k if top_k > 0 else None, # None disables top-k
alg=alg,
alg_temp=alg_temp,
visualization_delay=delay,
# Pass other necessary args like tokenizer, model if not global
)
print(f"Generated response text: '{response_text}'")
print(f"Number of visualization states: {len(vis_states)}")
# Fill in the pending history entry with the final response
if history:
history[-1][1] = response_text
else:
print("Warning: History was empty when trying to update response.")
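# Note: this relies on the gr.State list being the same object mutated above; if the
# installed Gradio version copies State values between events, chat_history would also
# need to be yielded explicitly as an output.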
# Stream the visualization states
if not vis_states:
print("Warning: No visualization states were generated.")
# Yield something to prevent downstream errors
yield history, [("Error: No visualization.", "red")]
return
for state in vis_states:
yield history, state # Yield the updated history and the current visualization state
time.sleep(delay) # Pause between steps
# Final yield to ensure the last state is displayed
yield history, vis_states[-1]
def clear_conversation():
"""Clear the conversation history and visualization"""
return [], [], "", [] # history, chatbot, user_input, output_vis
# --- Event Wiring ---
# Clear button
clear_btn.click(
fn=clear_conversation,
inputs=[],
outputs=[chat_history, chatbot_ui, user_input, output_vis]
)
# User message submission flow (2-step using .then)
# 1. User submits message -> Update history and chatbot UI immediately
submit_action = user_input.submit(
fn=user_message_submitted,
inputs=[user_input, chat_history],
outputs=[chat_history, chatbot_ui, user_input, output_vis] # Update chatbot, clear input
)
# Connect send button to the same function
send_action = send_btn.click(
fn=user_message_submitted,
inputs=[user_input, chat_history],
outputs=[chat_history, chatbot_ui, user_input, output_vis]
)
# 2. After UI update -> Trigger bot response generation and streaming
# Use the updated chat_history from the first step
submit_action.then(
fn=bot_response_stream,
inputs=[
chat_history, gen_length, steps, constraints_input,
temperature, top_p, top_k, alg, alg_temp,
visualization_delay
],
outputs=[chatbot_ui, output_vis] # Update chatbot and visualization (matches the two values yielded)
)
send_action.then(
fn=bot_response_stream,
inputs=[
chat_history, gen_length, steps, constraints_input,
temperature, top_p, top_k, alg, alg_temp,
visualization_delay
],
outputs=[chatbot_ui, output_vis] # Update chatbot and visualization
)
# Clear input after send/submit (already handled in user_message_submitted)
# submit_action.then(lambda: "", outputs=user_input)
# send_action.then(lambda: "", outputs=user_input)
return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
demo = create_chatbot_demo()
# Using queue for streaming and handling multiple users
demo.queue().launch(debug=True, share=True)