# dream_app.py
"""Gradio chat demo for the Dream 7B diffusion language model.

The model and tokenizer are loaded at import time. Each reply is produced with
``model.diffusion_generate``; a token hook records every denoising step (and
enforces optional word constraints), and the recorded frames are then replayed
in a HighlightedText widget next to the chat.
"""
import torch
import numpy as np
import gradio as gr
import spaces  # Ensure spaces is installed if needed for GPU decorator
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time
import re
from typing import Callable, Dict, Iterator, List, Optional, Tuple

# Hugging Face repo id. AutoConfig/AutoModel rely on trust_remote_code=True for
# the custom DreamModel class and its generation mixin.
model_path = "Dream-org/Dream-v0-Instruct-7B"

# Load model configuration to get special token IDs
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load model and tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Loading model...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,  # bfloat16 only on CUDA
    trust_remote_code=True,
)
model = model.to(device).eval()
print("Model loaded.")

# --- Special tokens ---
MASK_TOKEN = tokenizer.mask_token
MASK_ID = config.mask_token_id
PAD_ID = config.pad_token_id
EOS_ID = config.eos_token_id
# PAD and EOS may share one id; the set dedupes them either way.
SPECIAL_TOKEN_IDS = {PAD_ID, EOS_ID, MASK_ID}
# Chat-template markers are treated as special too, so the visualization can hide them.
IM_START_ID = tokenizer.convert_tokens_to_ids("<|im_start|>")
IM_END_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
SPECIAL_TOKEN_IDS.add(IM_START_ID)
SPECIAL_TOKEN_IDS.add(IM_END_ID)
# Special tokens hidden in the visualization once they are unchanged from the
# previous frame. Masks are excluded: they are always drawn as MASK_TOKEN.
HIDDEN_TOKEN_IDS = SPECIAL_TOKEN_IDS - {MASK_ID}

# Remasking algorithms that accept an ``alg_temp`` argument.
CONFIDENCE_ALGS = ('maskgit_plus', 'topk_margin', 'entropy')

# Visualization colors.
COLOR_MASK = "#444444"  # Dark gray: still masked
COLOR_NEW = "#66CC66"   # Light green: revealed in this frame
COLOR_OLD = "#6699CC"   # Light blue: revealed in an earlier frame

# Type aliases used below.
History = List[List[Optional[str]]]
VisData = List[Tuple[str, str]]


# --- Helper Functions ---

def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
    """Parse word constraints of the form ``'position:word, position:word, ...'``.

    Args:
        constraints_text: Raw user text. Parts without ``':'``, with a
            non-integer or negative position, or with an empty word are ignored.

    Returns:
        Mapping from the start position (0-indexed from the start of the
        *generated* sequence) to the token ids of the constraint word. A later
        entry for the same position overwrites an earlier one.
    """
    constraints: Dict[int, List[int]] = {}
    if not constraints_text:
        return constraints

    for part in constraints_text.split(','):
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
        except ValueError:
            continue  # Ignore malformed positions
        word = word.strip()
        if pos < 0 or not word:
            continue
        try:
            # Words after position 0 get a leading space, presumably so they
            # tokenize like mid-sentence words — TODO confirm for Dream's tokenizer.
            token_ids = tokenizer.encode(" " + word if pos > 0 else word, add_special_tokens=False)
        except Exception as e:
            print(f"Warning: Error processing constraint '{part}': {e}")
            continue
        if token_ids:
            constraints[pos] = token_ids
    return constraints


def format_chat_history(history: History) -> List[Dict[str, str]]:
    """Convert ``[user, assistant]`` pairs into chat-template messages.

    Args:
        history: List of ``[user_message, assistant_message]`` pairs; the last
            assistant message is ``None`` while a reply is pending.

    Returns:
        Message dicts for ``tokenizer.apply_chat_template``. Empty user messages
        are skipped. An assistant message is included whenever it is not
        ``None`` (even if empty), so user/assistant turns keep alternating.
        No system message is added here.
    """
    messages: List[Dict[str, str]] = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _apply_constraints(
    x: torch.Tensor, constraints: Dict[int, List[int]], prompt_length: int
) -> torch.Tensor:
    """Return a copy of ``x`` with every constraint's token ids written into row 0.

    Constraint positions are relative to the generated region, i.e. offset by
    ``prompt_length``. Constraints that do not fit inside ``x`` are skipped.
    """
    constrained = x.clone()  # Never modify the caller's tensor in place
    seq_len = constrained.shape[1]
    for rel_pos, word_token_ids in constraints.items():
        start = prompt_length + rel_pos
        end = start + len(word_token_ids)
        if end > seq_len:
            continue
        constrained[0, start:end] = torch.tensor(
            word_token_ids, dtype=torch.long, device=constrained.device
        )
    return constrained


def _make_token_decoder() -> Callable[[int], str]:
    """Return a memoized single-token decoder that keeps special tokens visible."""
    cache: Dict[int, str] = {}

    def decode_token(token_id: int) -> str:
        if token_id not in cache:
            try:
                cache[token_id] = tokenizer.decode([token_id], skip_special_tokens=False)
            except Exception:
                cache[token_id] = f"[ID:{token_id}]"  # Fallback
        return cache[token_id]

    return decode_token


def _render_frame(
    current_tokens: List[int],
    previous_tokens: Optional[List[int]],
    decode_token: Callable[[int], str],
) -> VisData:
    """Build HighlightedText ``(text, color)`` pairs for one denoising frame.

    ``previous_tokens=None`` means the previous frame was fully masked.
    - Masks are drawn as ``MASK_TOKEN`` in dark gray.
    - Tokens that were masked in the previous frame are green; older tokens are blue.
    - Tokens in ``HIDDEN_TOKEN_IDS`` that are unchanged from the previous frame
      are omitted, as are entries whose text is empty.
    """
    if previous_tokens is None:
        previous_tokens = [MASK_ID] * len(current_tokens)

    vis_data: VisData = []
    for current_tok_id, previous_tok_id in zip(current_tokens, previous_tokens):
        if current_tok_id == MASK_ID:
            text, color = MASK_TOKEN, COLOR_MASK
        elif current_tok_id in HIDDEN_TOKEN_IDS and previous_tok_id == current_tok_id:
            continue  # Special token already shown in an earlier frame: hide it
        else:
            text = decode_token(current_tok_id)
            color = COLOR_NEW if previous_tok_id == MASK_ID else COLOR_OLD
        if text:
            vis_data.append((text, color))
    return vis_data


# --- Core Generation Logic with Live Visualization ---

@spaces.GPU  # Decorator for Hugging Face Spaces GPU usage
def generate_dream_response(
    history: History,
    gen_length: int,
    steps: int,
    constraints_text: str,
    temperature: float,
    top_p: Optional[float],
    top_k: Optional[int],
    alg: str,
    alg_temp: Optional[float],
    visualization_delay: float,
) -> Iterator[Tuple[History, History, VisData]]:
    """Generate a reply with Dream and stream the denoising process.

    The reply is generated in full first; the token hook records one frame per
    diffusion step. The recorded frames are then replayed with
    ``visualization_delay`` seconds between them. The finished reply is written
    into ``history[-1][1]`` in place.

    Args:
        history: Chat history as ``[user, assistant]`` pairs. The last pair
            must have a user message and a ``None`` reply.
        gen_length: Max new tokens to generate.
        steps: Number of diffusion steps.
        constraints_text: User constraints; see ``parse_constraints``.
        temperature: Sampling temperature.
        top_p: Nucleus mass. ``None`` or values outside ``(0, 1)`` disable it.
        top_k: Top-k cutoff. ``None`` or values ``<= 0`` disable it.
        alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy').
        alg_temp: Temperature for the confidence-based algorithms (ignored for 'origin').
        visualization_delay: Seconds to pause between streamed frames.

    Yields:
        ``(history, history, vis_data)`` for the ``chat_history`` state, the
        chatbot and the HighlightedText outputs, respectively.
    """
    if not history or not history[-1][0]:
        # No user message to respond to
        yield history, history, [("No input message found.", "red")]
        return
    if history[-1][1] is not None:
        # The last turn is already answered (e.g. an empty message was rejected
        # upstream but this step still ran); do not silently regenerate it.
        yield history, history, [("No new message to respond to.", "grey")]
        return

    # Gradio sliders may deliver floats; torch.full/range and the sampler need ints.
    gen_length = int(gen_length)
    steps = int(steps)

    # --- 1. Preparation ---
    messages_for_template = format_chat_history(history)  # Includes the latest user message

    # Keep only constraints that fit inside the generated region; report the rest once.
    parsed_constraints: Dict[int, List[int]] = {}
    for rel_pos, word_token_ids in parse_constraints(constraints_text).items():
        if rel_pos + len(word_token_ids) <= gen_length:
            parsed_constraints[rel_pos] = word_token_ids
        else:
            print(f"Warning: Constraint at {rel_pos} ('{tokenizer.decode(word_token_ids)}') goes out of bounds.")

    try:
        inputs = tokenizer.apply_chat_template(
            messages_for_template,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,  # Important for instruct models
        )
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        prompt_length = input_ids.shape[1]
    except Exception as e:
        print(f"Error applying chat template: {e}")
        yield history, history, [("Error preparing input.", "red")]
        return

    # --- 2. Visualization Setup ---
    # Generated-region tokens (CPU copies) for every frame, starting with the
    # initial all-mask frame with constraints applied.
    initial_x = torch.cat(
        (input_ids[0], torch.full((gen_length,), MASK_ID, dtype=torch.long, device=input_ids.device))
    ).unsqueeze(0)
    step_sequence_history: List[torch.Tensor] = [
        _apply_constraints(initial_x, parsed_constraints, prompt_length)[0, prompt_length:].clone().cpu()
    ]

    def live_visualization_hook(
        step: Optional[int], x: torch.Tensor, logits: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Force constraint tokens into ``x`` and record the generated region.

        The returned tensor is what the sampler continues with. A call with
        ``step is None`` (the sampler may make one before the first step) is
        not recorded, because the initial frame is already seeded above.
        """
        constrained_x = _apply_constraints(x, parsed_constraints, prompt_length)
        if step is not None:
            # clone() so later in-place updates by the sampler cannot alter the frame
            step_sequence_history.append(constrained_x[0, prompt_length:].clone().cpu())
        return constrained_x

    # --- 3. Run Generation ---
    try:
        print(f"Starting Dream generation: prompt_len={prompt_length}, gen_len={gen_length}, steps={steps}")
        start_time = time.time()

        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=gen_length,
            output_history=False,  # History is captured via the hook
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
            top_k=int(top_k) if top_k is not None and top_k > 0 else None,
            alg=alg,
            alg_temp=alg_temp if alg in CONFIDENCE_ALGS else None,
            generation_tokens_hook_func=live_visualization_hook,
        )

        print(f"Dream generation finished in {time.time() - start_time:.2f} seconds.")

        # --- 4. Process Final Output ---
        response_tokens = output.sequences[0][prompt_length:]
        final_response_text = tokenizer.decode(
            response_tokens,
            skip_special_tokens=True,  # Drop special tokens from the final text
            clean_up_tokenization_spaces=True,
        ).strip()
        history[-1][1] = final_response_text
    except Exception as e:
        print(f"Error during generation or processing: {e}")
        import traceback
        traceback.print_exc()
        yield history, history, [("Error during generation.", "red")]
        return

    # --- 5. Stream Visualization ---
    print(f"Streaming {len(step_sequence_history)} visualization steps...")
    decode_token = _make_token_decoder()
    previous_tokens: Optional[List[int]] = None
    last_index = len(step_sequence_history) - 1
    for i, frame in enumerate(step_sequence_history):
        current_tokens = frame.tolist()
        vis_data = _render_frame(current_tokens, previous_tokens, decode_token)
        previous_tokens = current_tokens
        yield history, history, vis_data
        if i < last_index:
            time.sleep(visualization_delay)
    print("Visualization streaming complete.")


# --- Gradio UI ---
css = '''
.category-legend{display:none}
button{min-height: 60px}
'''


def create_chatbot_demo() -> gr.Blocks:
    """Build the Gradio Blocks app: chat column, denoising visualization and settings."""
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
            "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    bubble_full_width=False,
                )

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            scale=7,
                            autofocus=True,
                            show_label=False,
                            container=False,  # Remove container for tighter packing
                        )
                        send_btn = gr.Button("Send", scale=1, variant="primary")

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from start of generation). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Hello, 10:world",
                    value="",
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=True,
                    show_legend=False,  # Legend isn't very informative here
                    interactive=False,  # Not interactive
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,
                    label="Max New Tokens",
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=8,
                    label="Diffusion Steps",
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.4, step=0.05,
                    label="Temperature",
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Remasking Temp (for confidence algs)",
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P (0=disabled)",
                )
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0=disabled)",
                )
            with gr.Row():
                remasking_strategy = gr.Radio(
                    choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',
                    label="Remasking Strategy (Algorithm)",
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.02, step=0.01,
                    label="Visualization Delay (seconds)",
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # --- Event Handlers ---
        def add_user_message_to_history(message: str, history: History):
            """Append the user's message with a pending (``None``) reply and clear the input box.

            Returns values for ``[chat_history, chatbot_ui, user_input, output_vis]``.
            """
            if not message or not message.strip():
                gr.Warning("Please enter a message.")
                return history, history, "", [("Enter a message", "grey")]
            history.append([message, None])
            return history, history, "", []

        def clear_conversation():
            """Reset the history state, chat display, input box and visualization."""
            return [], [], "", []

        # --- Connect UI elements ---
        generation_inputs = [
            chat_history, gen_length, steps, constraints_input,
            temperature, top_p, top_k, remasking_strategy, alg_temp,
            visualization_delay,
        ]
        # Must match the 3-tuples yielded by generate_dream_response.
        generation_outputs = [chat_history, chatbot_ui, output_vis]

        # Enter key and Send button share the same two-step chain:
        # 1) add the user message, 2) generate the reply and stream the visualization.
        for trigger in (user_input.submit, send_btn.click):
            trigger(
                fn=add_user_message_to_history,
                inputs=[user_input, chat_history],
                outputs=[chat_history, chatbot_ui, user_input, output_vis],
            ).then(
                fn=generate_dream_response,
                inputs=generation_inputs,
                outputs=generation_outputs,
            )

        clear_btn.click(
            clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis],
        )

    return demo


# --- Launch ---
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # Use queue for handling multiple users and streaming
    demo.queue().launch(debug=True, share=True)  # share=True requests a public link