# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces  # Ensure `spaces` is installed if the GPU decorator is needed
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
from typing import List, Dict, Tuple, Optional
# Use AutoModel for the base model loading, relying on trust_remote_code=True
# for the custom DreamModel class and generation mixin.
model_path = "Dream-org/Dream-v0-Instruct-7B"

# Load model configuration to get special token IDs
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Load model and tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Loading model...")
# Adjust torch_dtype for your hardware if needed
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,  # Use bfloat16 only on CUDA
    trust_remote_code=True
)
model = model.to(device).eval()
print("Model loaded.")
# Constants from Dream's config/tokenizer
MASK_TOKEN = tokenizer.mask_token
MASK_ID = config.mask_token_id
PAD_ID = config.pad_token_id
EOS_ID = config.eos_token_id
# Dream uses the same token ID for EOS and PAD; keep both in the special-token set.
SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
# Also track the chat-template delimiters so they can be hidden/handled specially.
IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
SPECIAL_TOKEN_IDS.add(IM_START_ID)
SPECIAL_TOKEN_IDS.add(IM_END_ID)
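# Note: the visualization loop further below checks PAD_ID/EOS_ID directly;
# SPECIAL_TOKEN_IDS is kept as a single place to collect IDs that may need
# extra filtering or hiding.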
# --- Helper Functions ---

def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    """
    Parse constraints in the format: 'position:word, position:word, ...'

    Returns a dictionary mapping the starting position (0-indexed from the start
    of the *generated* sequence) to the list of token IDs for the constraint word.
    """
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            # Position relative to the start of the *generation*
            pos = int(pos_str.strip())
            word = word.strip()
            # Prepend a space for non-initial positions so the word gets
            # word-boundary tokenization (assumes standard behavior of the Dream tokenizer).
            token_ids = tokenizer.encode(" " + word if pos > 0 else word, add_special_tokens=False)
            if token_ids and pos >= 0:
                constraints[pos] = token_ids
        except ValueError:
            continue  # Ignore malformed constraint parts
        except Exception as e:
            print(f"Warning: Error processing constraint '{part}': {e}")
            continue
    return constraints
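# Example (illustrative; the actual token IDs depend on the Dream tokenizer):
#   parse_constraints("0:Once, 5:upon")
#   -> {0: tokenizer.encode("Once", add_special_tokens=False),
#       5: tokenizer.encode(" upon", add_special_tokens=False)}
# Malformed parts (e.g. a missing or non-integer position) are skipped silently.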
def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
    """
    Format chat history for the Dream model's chat template.

    Args:
        history: List of [user_message, assistant_message] pairs.
                 The last assistant_message may be None.

    Returns:
        Formatted list of message dictionaries for tokenizer.apply_chat_template.
    """
    messages = []
    # Dream's chat template inserts a default system prompt when none is provided,
    # so only the user/assistant turns need to be forwarded here.
    for user_msg, assistant_msg in history:
        if user_msg:  # Defensive check
            messages.append({"role": "user", "content": user_msg})
        # Add the assistant message only if it exists (it won't for the last turn before generation)
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
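# Example (illustrative):
#   format_chat_history([["Hi", "Hello!"], ["Tell me a story", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Tell me a story"}]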
# --- Core Generation Logic with Live Visualization ---

# Decorator for Hugging Face Spaces GPU usage (ZeroGPU)
@spaces.GPU
def generate_dream_response(
    history: List[List[Optional[str]]],
    gen_length: int,
    steps: int,
    constraints_text: str,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float
):
""" | |
Generates text using the Dream model and yields visualization states live. | |
Args: | |
history: Chat history. | |
gen_length: Max new tokens to generate. | |
steps: Number of diffusion steps. | |
constraints_text: User-provided constraints string. | |
temperature: Sampling temperature. | |
top_p: Top-p sampling nucleus. | |
top_k: Top-k sampling. | |
alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy'). | |
alg_temp: Temperature for confidence-based algorithms. | |
visualization_delay: Delay between visualization steps. | |
Yields: | |
Tuple[List[List[Optional[str]]], List[Tuple[str, Optional[str]]], str]: | |
- Updated history | |
- Visualization data for HighlightedText | |
- Final response text (repeated in each yield) | |
""" | |
    if not history or not history[-1][0]:
        # No user message to respond to
        yield history, [("No input message found.", "red")], ""
        return

    # --- 1. Preparation ---
    last_user_message = history[-1][0]
    messages_for_template = format_chat_history(history)  # Includes the latest user message

    # Parse constraints relative to the *generated* sequence
    parsed_constraints = parse_constraints(constraints_text)  # Dict[rel_pos, List[token_id]]
    # Prepare inputs using the chat template
    try:
        inputs = tokenizer.apply_chat_template(
            messages_for_template,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True  # Important for instruct models
        )
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        prompt_length = input_ids.shape[1]
    except Exception as e:
        print(f"Error applying chat template: {e}")
        yield history, [("Error preparing input.", "red")], ""
        return
    # Total sequence length seen by the model (prompt plus generated tokens).
    # Dream's config allows a very long context (max_position_embeddings=131072),
    # so prompt_length + gen_length stays well within bounds here.
    total_length = prompt_length + gen_length
    # --- 2. Visualization Setup ---
    # This list stores the token sequence (just the generated part) at each step.
    step_sequence_history: List[torch.Tensor] = []

    # Define the hook *inside* this function so it can capture local state.
    def live_visualization_hook(step: Optional[int], x: torch.Tensor, logits: Optional[torch.Tensor]) -> torch.Tensor:
        nonlocal step_sequence_history, parsed_constraints, prompt_length
        # --- Apply Constraints ---
        # `diffusion_generate` calls this hook after each sampling step with the
        # updated sequence `x`; forcing the constraint tokens here makes them part
        # of the state that the next step sees.
        current_x = x.clone()  # Work on a copy
        for rel_pos, word_token_ids in parsed_constraints.items():
            abs_start_pos = prompt_length + rel_pos
            abs_end_pos = abs_start_pos + len(word_token_ids)
            # Ensure the constraint fits within the total sequence length
            if abs_start_pos < total_length and abs_end_pos <= total_length:
                try:
                    constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=current_x.device)
                    # Force the constraint tokens onto the sequence
                    current_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
                except IndexError:
                    print(f"Warning: Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")
                except Exception as e:
                    print(f"Warning: Failed to apply constraint at {rel_pos}: {e}")

        # Store the post-constraint state for visualization.
        # Only the generated part is needed; move it to CPU to save GPU memory.
        generated_part = current_x[0, prompt_length:].clone().cpu()
        step_sequence_history.append(generated_part)

        # Return the (possibly constraint-modified) tensor so the next step uses it
        return current_x
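    # Expected call pattern (an assumption based on Dream's remote generation code):
    # the hook fires roughly once per diffusion step with the partially denoised
    # sequence, so step_sequence_history collects one snapshot per step plus the
    # initial fully masked state added below.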
    # --- 3. Run Generation ---
    final_response_text = ""
    try:
        print(f"Starting Dream generation: prompt_len={prompt_length}, gen_len={gen_length}, steps={steps}")
        start_time = time.time()

        # Build the fully masked initial state and run it through the hook once so
        # that any word constraints are applied to it; the hook also records it as
        # the first entry of the visualization history.
        initial_generated_state = torch.full((gen_length,), MASK_ID, dtype=torch.long)
        temp_initial_x = torch.cat((input_ids[0], initial_generated_state.to(device)), dim=0).unsqueeze(0)
        live_visualization_hook(None, temp_initial_x, None)
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=gen_length,
            output_history=False,  # History is captured via the hook instead
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p if top_p is not None and top_p < 1.0 else None,  # None disables top-p (>= 1.0)
            top_k=top_k if top_k is not None and top_k > 0 else None,    # None disables top-k (0)
            alg=alg,
            alg_temp=alg_temp if alg in ['maskgit_plus', 'topk_margin', 'entropy'] else None,  # Only used by confidence-based algs
            generation_tokens_hook_func=live_visualization_hook
        )
        end_time = time.time()
        print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")
        # --- 4. Process Final Output ---
        final_sequence = output.sequences[0]
        response_tokens = final_sequence[prompt_length:]

        # Decode the final response text
        final_response_text = tokenizer.decode(
            response_tokens,
            skip_special_tokens=True,  # Skip EOS, PAD, MASK etc. in the final text
            clean_up_tokenization_spaces=True
        ).strip()

        # Update history with the final response
        history[-1][1] = final_response_text

    except Exception as e:
        print(f"Error during generation or processing: {e}")
        import traceback
        traceback.print_exc()
        yield history, [("Error during generation.", "red")], ""
        return
    # --- 5. Stream Visualization ---
    print(f"Streaming {len(step_sequence_history)} visualization steps...")
    previous_tokens_vis = None
    for i, current_tokens_vis in enumerate(step_sequence_history):
        # print(f"  Step {i}: {current_tokens_vis.tolist()}")  # Debug
        vis_data = []
        current_decoded_tokens = []

        # Compare the current step's tokens with the previous step's
        for j in range(gen_length):
            current_tok_id = current_tokens_vis[j].item()
            previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None else MASK_ID
            # Decode the token; fall back to showing the raw ID on error
            try:
                # skip_special_tokens=False so special tokens remain visible here
                decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False)
                # Explicitly handle mask token display
                if current_tok_id == MASK_ID:
                    display_token = MASK_TOKEN
                else:
                    display_token = decoded_token
            except Exception:
                display_token = f"[ID:{current_tok_id}]"  # Fallback
            # Determine the color and handle hiding of special tokens (as in the LLaDA demo)
            color = None
            token_to_display = display_token
            if current_tok_id == MASK_ID:
                color = "#444444"  # Dark gray for masks
            elif previous_tok_id == MASK_ID:  # Token was revealed at this step
                # Green for newly revealed tokens (no confidence score is available from the hook)
                color = "#66CC66"
            else:  # Token was already revealed in an earlier step
                color = "#6699CC"  # Light blue

            # LLaDA-style hiding: if the token is EOS/PAD *and* it was already revealed
            # before this step, hide it instead of cluttering the visualization.
            if current_tok_id in {PAD_ID, EOS_ID} and previous_tok_id == current_tok_id:
                token_to_display = ""
                color = "#FFFFFF"

            # Add the token and its color to the visualization data;
            # hidden (empty-string) tokens are simply omitted.
            if token_to_display:
                vis_data.append((token_to_display, color))
        # Update the previous state for the next iteration
        previous_tokens_vis = current_tokens_vis

        # Yield the current visualization state
        yield history, vis_data, final_response_text

        # Pause for the specified delay
        time.sleep(visualization_delay)

    print("Visualization streaming complete.")
# --- Gradio UI ---
css = '''
.category-legend{display:none}
button{min-height: 60px}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
            "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    bubble_full_width=False
                )

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            scale=7,
                            autofocus=True,
                            show_label=False,
                            container=False  # Remove container for tighter packing
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from the start of the generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Hello, 10:world",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=True,
                    show_legend=False,  # The legend isn't informative here
                    interactive=False
                )
        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,  # Increased max length
                    label="Max New Tokens"
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=8,  # Increased max steps
                    label="Diffusion Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.4, step=0.05,
                    label="Temperature"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Remasking Temp (for confidence algs)"
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P (1.0 = disabled)"
                )
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0 = disabled)"
                )
            with gr.Row():
                remasking_strategy = gr.Radio(
                    choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',  # Default to entropy as in Dream's examples
                    label="Remasking Strategy (Algorithm)"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.02, step=0.01,  # Faster default
                    label="Visualization Delay (seconds)"
                )
        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # Hidden textbox that receives the final response text (the third value
        # yielded by generate_dream_response); kept invisible, handy for debugging.
        current_response = gr.Textbox(visible=False)
        # --- Event Handlers ---
        def add_user_message_to_history(message: str, history: List[List[Optional[str]]]):
            """Adds the user message, clears the input box, and prepares for the bot response."""
            if not message.strip():
                gr.Warning("Please enter a message.")
                return history, history, "", [("Enter a message", "grey")]
            # Add the user message with a placeholder for the bot response
            history.append([message, None])
            # Return updated history for the chatbot, an empty input box, and an empty visualization
            return history, history, "", []

        def clear_conversation():
            """Clears the chat history and visualization."""
            return [], [], "", []
        # --- Connect UI elements ---

        # Define the inputs for the generation function once
        generation_inputs = [
            chat_history, gen_length, steps, constraints_input,
            temperature, top_p, top_k, remasking_strategy, alg_temp,
            visualization_delay
        ]
        # Define the outputs for the generation function; current_response receives
        # the final response text yielded alongside the history and visualization.
        generation_outputs = [chatbot_ui, output_vis, current_response]
        # Handle Textbox submission (Enter key)
        submit_listener = user_input.submit(
            fn=add_user_message_to_history,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]  # Step 1: add the user message
        )
        # Chain the bot response generation after the user message is added
        submit_listener.then(
            fn=generate_dream_response,
            inputs=generation_inputs,
            outputs=generation_outputs  # Step 2: generate the response and stream the visualization
        )

        # Handle Send button click
        click_listener = send_btn.click(
            fn=add_user_message_to_history,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]  # Step 1: add the user message
        )
        # Chain the bot response generation after the user message is added
        click_listener.then(
            fn=generate_dream_response,
            inputs=generation_inputs,
            outputs=generation_outputs  # Step 2: generate the response and stream the visualization
        )

        # Clear button action
        clear_btn.click(
            clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )
    return demo

# --- Launch ---
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # Use a queue to handle multiple users and streaming
    demo.queue().launch(debug=True, share=True)  # share=True exposes a public link if needed
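# Local usage (sketch; the exact dependency set is an assumption based on the imports above):
#   pip install torch transformers gradio numpy spaces
#   python dream_app.py
# On a Hugging Face Space (ZeroGPU), the @spaces.GPU decorator requests a GPU per call.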