Dream / app.py
multimodalart's picture
Update app.py
ecc9ba6 verified
raw
history blame
21.5 kB
# dream_app.py
import re
import time
from typing import Dict, Iterator, List, Optional, Tuple

import gradio as gr
import numpy as np
import spaces  # Ensure spaces is installed if needed for GPU decorator
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
# Hub id of the Dream instruct checkpoint. Defined first so that the config,
# the tokenizer and the weights are all loaded from the same id.
model_path = "Dream-org/Dream-v0-Instruct-7B"
# Load model configuration to get special token IDs.
# trust_remote_code=True is required throughout: the checkpoint ships its own
# modelling code (the custom DreamModel class and its generation mixin, which
# provides the diffusion_generate method used below).
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Load model and tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Loading model...")
model = AutoModel.from_pretrained(
    model_path,
    # Use bfloat16 only on CUDA; CPU inference stays in float32.
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
    trust_remote_code=True,
)
model = model.to(device).eval()  # Inference only
print("Model loaded.")
# Special-token constants, read from the loaded tokenizer / config objects.
MASK_TOKEN = tokenizer.mask_token  # Text shown for still-masked positions
MASK_ID = config.mask_token_id
PAD_ID = config.pad_token_id
EOS_ID = config.eos_token_id
# Chat-template markers. convert_tokens_to_ids is called once per marker, in
# this order.
IM_START_ID, IM_END_ID = map(
    tokenizer.convert_tokens_to_ids, ("<|im_start|>", "<|im_end|>")
)
# Every id that may be treated as non-content (padding, end of sequence, mask
# and chat-template markers). PAD_ID and EOS_ID may coincide (presumably they
# do for Dream — verify against the config); the set simply de-duplicates them.
SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID, IM_START_ID, IM_END_ID}
# --- Helper Functions ---
def parse_constraints(constraints_text: str, tok=None) -> Dict[int, List[int]]:
    """
    Parse constraints in format: 'position:word, position:word, ...'

    Returns a dictionary mapping the starting position (0-indexed from the start
    of the *generated* sequence) to a list of token IDs for the constraint word.

    Args:
        constraints_text: Raw user text such as "0:Once, 5:upon". Empty or None
            yields an empty dict.
        tok: Object providing ``encode(text, add_special_tokens=False)``.
            Defaults to the module-level ``tokenizer``.

    Parts are skipped when they have no ':', a non-integer or negative
    position, an empty word, or a word that encodes to no tokens. A later
    constraint at the same position replaces an earlier one.
    """
    constraints: Dict[int, List[int]] = {}
    if not constraints_text:
        return constraints
    if tok is None:
        tok = tokenizer
    for part in constraints_text.split(','):
        pos_str, sep, word = part.partition(':')
        if not sep:
            continue  # No "pos:word" separator in this part
        try:
            # Position relative to the start of the *generation*
            pos = int(pos_str.strip())
        except ValueError:
            continue  # Ignore malformed positions
        word = word.strip()
        # An empty word would otherwise encode to a bare space token and be
        # forced into the output.
        if pos < 0 or not word:
            continue
        try:
            # Words after position 0 get a leading space so they tokenize as a
            # continuation of the preceding text.
            token_ids = tok.encode(" " + word if pos > 0 else word, add_special_tokens=False)
        except Exception as e:
            print(f"Warning: Error processing constraint '{part}': {e}")
            continue
        if token_ids:
            constraints[pos] = token_ids
    return constraints
def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
    """
    Convert [user, assistant] pairs into role/content dicts for
    ``tokenizer.apply_chat_template``.

    Args:
        history: Chat turns as [user_message, assistant_message] pairs. The
            assistant half of the newest turn may be None while a reply is
            still pending.

    Returns:
        A flat list of {"role": ..., "content": ...} dicts in turn order.
        Empty or None messages are left out. No system message is added here;
        that is left to the chat template.
    """
    roles = ("user", "assistant")
    formatted: List[Dict[str, str]] = []
    for turn in history:
        user_text, assistant_text = turn
        for role, text in zip(roles, (user_text, assistant_text)):
            if text:
                formatted.append({"role": role, "content": text})
    return formatted
# --- Core Generation Logic with Live Visualization ---
def _decode_for_display(token_id: int, cache: Dict[int, str]) -> str:
    """Return the text shown for one token id in the visualization, memoised in ``cache``.

    The mask id is shown as ``MASK_TOKEN``. Other ids are decoded with
    ``skip_special_tokens=False`` so that special tokens stay visible. If
    decoding raises, a ``[ID:<id>]`` placeholder is used instead.
    """
    cached = cache.get(token_id)
    if cached is not None:
        return cached
    if token_id == MASK_ID:
        text = MASK_TOKEN
    else:
        try:
            text = tokenizer.decode([token_id], skip_special_tokens=False)
        except Exception:
            text = f"[ID:{token_id}]"  # Fallback for ids the tokenizer cannot decode
    cache[token_id] = text
    return text


def _render_visualization_frame(
    current_tokens: torch.Tensor,
    previous_tokens: Optional[torch.Tensor],
    decode_cache: Dict[int, str],
) -> List[Tuple[str, str]]:
    """Build the HighlightedText payload for one recorded denoising step.

    Colours:
      - dark gray: still masked;
      - green: revealed in this step;
      - blue: revealed in an earlier step.

    PAD/EOS tokens that were already present in the previous frame are hidden
    (the LLaDA-demo effect), so only their first appearance is shown.
    """
    current_ids = current_tokens.tolist()
    if previous_tokens is None:
        # The first frame is compared against an all-mask sequence.
        previous_ids = [MASK_ID] * len(current_ids)
    else:
        previous_ids = previous_tokens.tolist()

    frame: List[Tuple[str, str]] = []
    for cur_id, prev_id in zip(current_ids, previous_ids):
        if cur_id in (PAD_ID, EOS_ID) and prev_id == cur_id:
            continue  # Already-revealed PAD/EOS: hide it
        if cur_id == MASK_ID:
            color = "#444444"  # Dark Gray for masks
        elif prev_id == MASK_ID:
            # Newly revealed; the hook provides no confidence score to shade by.
            color = "#66CC66"  # Light Green
        else:
            color = "#6699CC"  # Light Blue: revealed in an earlier step
        text = _decode_for_display(cur_id, decode_cache)
        if text:  # Skip tokens that decode to an empty string
            frame.append((text, color))
    return frame


@spaces.GPU  # Decorator for Hugging Face Spaces GPU usage
def generate_dream_response(
    history: List[List[Optional[str]]],
    gen_length: int,
    steps: int,
    constraints_text: str,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float
) -> Iterator[Tuple[List[List[Optional[str]]], List[Tuple[str, str]], str]]:
    """
    Generate a reply with Dream's diffusion sampler, then stream the denoising frames.

    Generation runs to completion first. Every intermediate state is captured by
    a token hook passed to ``model.diffusion_generate``. The recorded states are
    then replayed one frame at a time, so the visualization is a replay rather
    than truly live.

    Args:
        history: Chat history as [user, assistant] pairs. The last pair must
            hold the user message to answer; its assistant slot is overwritten
            with the reply.
        gen_length: Max new tokens to generate.
        steps: Number of diffusion steps.
        constraints_text: User-provided constraints string ("pos:word, ...").
        temperature: Sampling temperature.
        top_p: Top-p nucleus. Values <= 0 or >= 1 disable nucleus filtering
            (the UI labels 0 as "disabled").
        top_k: Top-k sampling. Values <= 0 disable it.
        alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
        alg_temp: Temperature for the confidence-based algorithms. Ignored for 'origin'.
        visualization_delay: Seconds to pause after each streamed frame.

    Yields:
        Tuples of (history, vis_data, final_response_text):
          - history: the updated history;
          - vis_data: HighlightedText data for the current frame;
          - final_response_text: the decoded reply, repeated in every yield.

        On failure, a single tuple is yielded with a red error label and an
        empty response text.
    """
    if not history or not history[-1][0]:
        # No user message to respond to
        yield history, [("No input message found.", "red")], ""
        return

    # Slider values may arrive as floats; tensor sizes and token counts must be ints.
    gen_length = int(gen_length)
    steps = int(steps)

    # --- 1. Preparation ---
    messages_for_template = format_chat_history(history)  # Includes the latest user message
    # Constraints are keyed by position relative to the *generated* sequence.
    parsed_constraints = parse_constraints(constraints_text)
    try:
        inputs = tokenizer.apply_chat_template(
            messages_for_template,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True  # Important for instruct models
        )
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        yield history, [("Error preparing input.", "red")], ""
        return
    prompt_length = input_ids.shape[1]
    total_length = prompt_length + gen_length

    # Pre-build each constraint as (abs_start, abs_end, token tensor). Constraints
    # that would run past the end of the generation window are dropped.
    constraint_spans = []
    for rel_pos, word_token_ids in parsed_constraints.items():
        abs_start = prompt_length + rel_pos
        abs_end = abs_start + len(word_token_ids)
        if abs_end > total_length:
            print(f"Warning: Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
            continue
        constraint_spans.append(
            (abs_start, abs_end, torch.tensor(word_token_ids, dtype=torch.long, device=input_ids.device))
        )

    # --- 2. Visualization Setup ---
    # One CPU tensor of the generated region per hook call.
    step_sequence_history: List[torch.Tensor] = []

    def live_visualization_hook(step: Optional[int], x: torch.Tensor, logits: Optional[torch.Tensor]) -> torch.Tensor:
        """Force constraint tokens into ``x``, record the generated region, and return the result.

        Passed to ``diffusion_generate`` as ``generation_tokens_hook_func``. The
        returned tensor is what the sampler continues from.
        """
        current_x = x.clone()  # Do not modify the sampler's tensor in place
        for abs_start, abs_end, constraint_tensor in constraint_spans:
            current_x[0, abs_start:abs_end] = constraint_tensor
        # Clone before .cpu(): on a CPU device .cpu() returns the same storage.
        # The sampler may keep updating the returned tensor in later steps.
        step_sequence_history.append(current_x[0, prompt_length:].clone().cpu())
        return current_x

    # --- 3. Run Generation ---
    try:
        print(f"Starting Dream generation: prompt_len={prompt_length}, gen_len={gen_length}, steps={steps}")
        start_time = time.time()
        # Record the fully-masked starting state (constraints applied) as the
        # first frame. The hook appends it itself; no extra insert is needed.
        initial_generated_state = torch.full(
            (gen_length,), MASK_ID, dtype=torch.long, device=input_ids.device
        )
        initial_x = torch.cat((input_ids[0], initial_generated_state), dim=0).unsqueeze(0)
        live_visualization_hook(None, initial_x, None)

        confidence_algs = ('maskgit_plus', 'topk_margin', 'entropy')
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=gen_length,
            output_history=False,  # History is captured via the hook instead
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            # 0 is "disabled" in the UI; only a strict (0, 1) value filters anything.
            top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
            top_k=int(top_k) if top_k is not None and top_k > 0 else None,
            alg=alg,
            alg_temp=alg_temp if alg in confidence_algs else None,  # Only relevant for some algs
            generation_tokens_hook_func=live_visualization_hook
        )
        end_time = time.time()
        print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")

        # --- 4. Process Final Output ---
        response_tokens = output.sequences[0][prompt_length:]
        final_response_text = tokenizer.decode(
            response_tokens,
            skip_special_tokens=True,  # Skip EOS, PAD, MASK etc. in the final text
            clean_up_tokenization_spaces=True
        ).strip()
        # Update history with the final response
        history[-1][1] = final_response_text
    except Exception as e:
        print(f"Error during generation or processing: {e}")
        import traceback
        traceback.print_exc()
        yield history, [("Error during generation.", "red")], ""
        return

    # --- 5. Stream Visualization ---
    print(f"Streaming {len(step_sequence_history)} visualization steps...")
    decode_cache: Dict[int, str] = {}  # Token id -> display text, shared across frames
    previous_tokens: Optional[torch.Tensor] = None
    for current_tokens in step_sequence_history:
        vis_data = _render_visualization_frame(current_tokens, previous_tokens, decode_cache)
        previous_tokens = current_tokens
        yield history, vis_data, final_response_text
        time.sleep(max(0.0, visualization_delay))  # time.sleep rejects negative values
    print("Visualization streaming complete.")
# --- Gradio UI ---
# Custom CSS for the Blocks app: hide the category legend and give every
# button a taller minimum height.
css = """
.category-legend{display:none}
button{min-height: 60px}
"""
def create_chatbot_demo():
    """Build and return the Gradio Blocks UI for the Dream chat demo.

    Layout:
      - left column: chat window, message box with a Send button, and an
        optional word-constraints box;
      - right column: the denoising visualization;
      - below: generation settings in an accordion and a clear button.

    Submitting a message (Enter or Send) first appends it to the history.
    Only if that step succeeds does it stream the model's reply and
    denoising frames.
    """
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
            "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
        )
        # STATE MANAGEMENT
        chat_history = gr.State([])
        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    bubble_full_width=False
                )
                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            scale=7,
                            autofocus=True,
                            show_label=False,
                            container=False  # Remove container for tighter packing
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")
                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from start of generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Hello, 10:world",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=True,
                    show_legend=False,  # Legend isn't very informative here
                    interactive=False  # Not interactive
                )
        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,
                    label="Max New Tokens"
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=8,
                    label="Diffusion Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.4, step=0.05,
                    label="Temperature"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Remasking Temp (for confidence algs)"
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P (0=disabled)"
                )
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0=disabled)"
                )
            with gr.Row():
                remasking_strategy = gr.Radio(
                    choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',  # Default to entropy as in example
                    label="Remasking Strategy (Algorithm)"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.02, step=0.01,
                    label="Visualization Delay (seconds)"
                )
        # Clear button
        clear_btn = gr.Button("Clear Conversation")
        # Hidden sink for the third value yielded by generate_dream_response
        # (the final response text). Without it the generator yields more
        # values than the event has outputs.
        current_response = gr.Textbox(visible=False)

        # --- Event Handlers ---
        def add_user_message_to_history(message: str, history: List[List[Optional[str]]]):
            """Append the message as a new pending turn and clear the input box.

            Raises gr.Error on an empty message. The generation step is chained
            with ``.success``, so it then does not run. Previously it did run
            and re-answered (overwrote) the previous turn.
            """
            if not message.strip():
                raise gr.Error("Please enter a message.")
            # Add user message with placeholder for bot response
            history.append([message, None])
            # Updated history for state and chatbot, empty input box, empty visualization
            return history, history, "", []

        def clear_conversation():
            """Clears the chat history and visualization."""
            return [], [], "", []

        # --- Connect UI elements ---
        generation_inputs = [
            chat_history, gen_length, steps, constraints_input,
            temperature, top_p, top_k, remasking_strategy, alg_temp,
            visualization_delay
        ]
        # One output per value yielded by generate_dream_response.
        generation_outputs = [chatbot_ui, output_vis, current_response]
        user_message_outputs = [chat_history, chatbot_ui, user_input, output_vis]

        # Enter in the textbox and the Send button share the same two-step chain:
        # 1) add the user message, 2) on success, generate and stream the reply.
        for trigger in (user_input.submit, send_btn.click):
            trigger(
                fn=add_user_message_to_history,
                inputs=[user_input, chat_history],
                outputs=user_message_outputs
            ).success(
                fn=generate_dream_response,
                inputs=generation_inputs,
                outputs=generation_outputs
            )

        clear_btn.click(
            clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )
    return demo
# --- Launch ---
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # Enable the queue (queue() returns the same Blocks object): it is needed
    # for streaming generator outputs and for serving several users.
    demo.queue()
    # share=True requests a public link when run locally.
    demo.launch(debug=True, share=True)