# dream_app.py (Updated)
import torch
import numpy as np
import gradio as gr
try:
    import spaces  # only available when running on Hugging Face Spaces
except ImportError:
    spaces = None  # allow the script to run locally without the `spaces` package
# import torch.nn.functional as F # Not needed for DREAM's basic visualization
from transformers import AutoTokenizer, AutoModel
import time
import re # Keep for parsing constraints
# Use try-except for space deployment vs local
try:
# Used for spaces deployment with GPU
gpu_check = spaces.GPU
print("Running in Gradio Spaces with GPU environment.")
except AttributeError:
# Fallback for local execution or environments without spaces.GPU
print("Running in local environment or without spaces.GPU.")
# Define a dummy decorator if spaces.GPU is not available
def gpu_check(func):
return func
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Load DREAM Model and Tokenizer ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading model: {model_path}")
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
print("Model and tokenizer loaded.")
# --- Constants for DREAM ---
if tokenizer.mask_token is None:
print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
tokenizer.add_special_tokens({'mask_token': '[MASK]'})
model.resize_token_embeddings(len(tokenizer)) # Important if vocab size changed
if tokenizer.mask_token is None or tokenizer.mask_token_id is None:
raise ValueError("Could not set or find ID for a mask token for the tokenizer.")
MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
EOS_TOKEN = tokenizer.eos_token # Get EOS token string
EOS_ID = tokenizer.eos_token_id # Get EOS token ID
# Add other special tokens if needed for visualization
SPECIAL_TOKENS_MAP = {
tokenizer.eos_token_id: "[EOS]",
tokenizer.bos_token_id: "[BOS]",
tokenizer.pad_token_id: "[PAD]",
tokenizer.unk_token_id: "[UNK]",
MASK_ID: MASK_TOKEN # Map mask ID back to its string representation
}
# If any special token above is undefined, its ID is None; map that key to a placeholder
SPECIAL_TOKENS_MAP[None] = "[NONE]"
print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
print(f"Using EOS_TOKEN='{EOS_TOKEN}' with ID={EOS_ID}")
# --- Helper Functions (Constraint Parsing, History Formatting) ---
def parse_constraints(constraints_text):
"""Parse constraints in format: 'position:word, position:word, ...'"""
constraints = {}
if not constraints_text:
return constraints
parts = constraints_text.split(',')
for part in parts:
part = part.strip() # Trim whitespace
if ':' not in part:
continue
try:
pos_str, word = part.split(':', 1)
pos = int(pos_str.strip())
word = word.strip()
# Allow empty words if needed, but usually we want a word
if word and pos >= 0:
constraints[pos] = word
except ValueError:
print(f"Warning: Could not parse constraint part: '{part}'")
continue
return constraints
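# Example (illustrative): parse_constraints("0:Once, 5:upon, 10:time")
# returns {0: 'Once', 5: 'upon', 10: 'time'}; malformed parts are skipped with a warning.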
def format_chat_history(history):
"""
Format chat history for the DREAM model (standard messages format)
Args:
history: List of [user_message, assistant_message] pairs
Returns:
Formatted conversation for the model (list of dictionaries)
"""
messages = []
# Add system prompt if desired (check DREAM examples/recommendations)
# messages.append({"role": "system", "content": "You are a helpful assistant."}) # Optional
for user_msg, assistant_msg in history:
if user_msg: # Handle potential None message if clearing failed
messages.append({"role": "user", "content": user_msg})
if assistant_msg: # Skip if None (for the latest user message awaiting response)
messages.append({"role": "assistant", "content": assistant_msg})
return messages
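# Example (illustrative): format_chat_history([["Hi", "Hello!"], ["Tell me a joke", None]])
# returns [{"role": "user", "content": "Hi"},
#          {"role": "assistant", "content": "Hello!"},
#          {"role": "user", "content": "Tell me a joke"}]
# (the trailing None, i.e. the message still awaiting a response, is skipped).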
# --- Core Generation Logic for DREAM with Visualization ---
@gpu_check # Use the potentially dummy decorator
def dream_generate_response_with_visualization(
messages,
gen_length=64,
steps=64, # Default based on DREAM examples
constraints=None,
temperature=0.6, # Default based on DREAM examples
top_p=0.95, # Default based on DREAM examples
alg="entropy", # Default based on DREAM examples
alg_temp=0.0, # Default based on DREAM examples
):
"""
Generate text with DREAM model with visualization using the generation hook.
Args:
messages: List of message dictionaries with 'role' and 'content'
gen_length: Length of text to generate (max_new_tokens)
steps: Number of diffusion steps
constraints: Dictionary mapping positions (relative to response start) to words
temperature: Sampling temperature
top_p: Nucleus sampling p
alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
alg_temp: Temperature for confidence-based algorithms
Returns:
Tuple: (List of visualization states, final generated text string)
"""
print("--- Starting DREAM Generation ---")
print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
print(f"Constraints: {constraints}")
# --- Input Preparation ---
if constraints is None:
constraints = {}
# Convert word constraints to token IDs (handle multi-token words)
processed_constraints = {}
print("Processing constraints:")
for pos, word in constraints.items():
# Prepend space for consistent tokenization, similar to LLaDA example
# Important: use add_special_tokens=False for constraints
tokens = tokenizer.encode(" " + word, add_special_tokens=False)
if not tokens:
print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
continue
print(f" Pos {pos}, Word '{word}' -> Tokens {tokens}")
for i, token_id in enumerate(tokens):
# Ensure we don't overwrite parts of multi-token constraints accidentally
if pos + i not in processed_constraints:
processed_constraints[pos + i] = token_id
else:
print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")
# Prepare the prompt using chat template
try:
inputs = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
return_dict=True,
add_generation_prompt=True # Crucial for instruction-tuned models like Dream-Instruct
)
input_ids = inputs.input_ids.to(device=device)
attention_mask = inputs.attention_mask.to(device=device) # Get attention mask
prompt_length = input_ids.shape[1]
print(f"Input prompt length: {prompt_length}")
# print(f"Input IDs: {input_ids}") # Keep commented unless debugging
except Exception as e:
print(f"Error applying chat template: {e}")
return [([("Error applying chat template.", "Error")],)], f"Error: {e}"
if prompt_length + gen_length > 2048: # Check context length (DREAM uses 2048)
print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
gen_length = 2048 - prompt_length
if gen_length <= 0:
print("Error: Prompt is already too long.")
return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."
# --- State for Visualization Hook ---
visualization_states = []
last_x = None # Store the sequence from the previous step
# Initial state: Prompt + all masks
initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
# Apply initial constraints to the masked part *before* showing the first state
for pos, token_id in processed_constraints.items():
absolute_pos = pos # Position relative to start of generation
if 0 <= absolute_pos < gen_length:
# Check if the constraint token itself is special
if token_id in SPECIAL_TOKENS_MAP:
print(f" Note: Constraint at pos {pos} is a special token: {SPECIAL_TOKENS_MAP[token_id]}")
initial_x_part[0, absolute_pos] = token_id
# --- Define the Hook Function ---
# This function will be called at each diffusion step
def generation_tokens_hook_func(step, x, logits):
nonlocal last_x, visualization_states # Allow modification of outer scope variables
# print(f"Hook called for step {step}") # Keep commented unless debugging
current_x = x.clone() # Work on a copy for comparison/modification
# 1. Apply Constraints *before* generating visualization for this step
# Constraints are relative to the start of the *generated* part
constrained_x = current_x.clone()
current_prompt_len = current_x.shape[1] - gen_length # Recalculate actual prompt length
if current_prompt_len < 0:
print("Warning: prompt_len negative in hook, skipping constraints/vis.")
return current_x # Return unmodified if something is wrong
for pos, token_id in processed_constraints.items():
absolute_pos = current_prompt_len + pos
if current_prompt_len <= absolute_pos < current_x.shape[1]:
# Apply constraint if the current token doesn't match
if constrained_x[0, absolute_pos] != token_id:
constrained_x[0, absolute_pos] = token_id
# print(f" Constraint applied at pos {pos} ({absolute_pos}) -> token {token_id}")
# 2. Generate Visualization State for *this* step
# Compare current_x (output of diffusion for this step, before constraints applied *in this call*)
# with last_x (state from *previous* hook call / initial state, *after* constraints were applied then)
current_state_vis = []
gen_part_current = current_x[0, current_prompt_len:]
gen_part_last = last_x[0, current_prompt_len:] if last_x is not None else None
for i in range(gen_length):
current_token_id = gen_part_current[i].item()
last_token_id = gen_part_last[i].item() if gen_part_last is not None else MASK_ID # Assume mask initially
# Determine display string - Handle special tokens explicitly
if current_token_id in SPECIAL_TOKENS_MAP:
display_token = SPECIAL_TOKENS_MAP[current_token_id]
else:
# Decode non-special tokens, skipping special tokens in the *output string*
# and stripping whitespace
display_token = tokenizer.decode([current_token_id],
skip_special_tokens=True,
clean_up_tokenization_spaces=True).strip()
# If decoding results in empty string for a non-special token, use a space perhaps
if not display_token:
display_token = " " # Use a single space as placeholder
# Determine category (label) for color mapping
category = "Old" # Default assume it was revealed before
is_constrained = i in processed_constraints
if current_token_id == MASK_ID:
category = "Mask"
elif is_constrained and processed_constraints[i] == current_token_id:
# Check if it was *just* constrained or already was correct
# We mark as 'Constraint' if it matches the required token, regardless of when it appeared
category = "Constraint"
elif last_token_id == MASK_ID and current_token_id != MASK_ID:
# It was a mask before, now it's not -> Newly revealed
# (Unless it's a constraint, handled above)
category = "New"
# else: category remains "Old"
current_state_vis.append((display_token, category))
visualization_states.append(current_state_vis)
# 3. Update last_x for the *next* step's comparison
# Store the state *after* applying constraints for accurate comparison next time
last_x = constrained_x.clone()
# 4. Return the sequence with constraints applied for the model's next step
return constrained_x # Return the sequence with constraints enforced
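    # Hook contract, as used above: model.diffusion_generate calls
    #   generation_tokens_hook_func(step, x, logits)
    # once per diffusion step, where x is the full sequence (prompt + generation).
    # The tensor returned by the hook becomes the sequence for the next step,
    # which is how the constraints are re-applied at every step.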
# --- Run DREAM Generation ---
try:
print("Calling model.diffusion_generate...")
# Make sure last_x is initialized correctly before the first hook call
# It should represent the state *before* the first diffusion step.
# Create the initial full sequence (prompt + initial masked/constrained part)
initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
last_x = initial_full_x.clone() # Initialize last_x with the state before step 0
# Add the very first visualization state (prompt + initial masks/constraints)
# This state corresponds to the `last_x` *before* the first hook call.
initial_state_vis = []
initial_gen_part = initial_full_x[0, prompt_length:]
for i in range(gen_length):
token_id = initial_gen_part[i].item()
category = "Mask"
display_token = MASK_TOKEN
if token_id != MASK_ID:
# This must be an initial constraint
category = "Constraint"
if token_id in SPECIAL_TOKENS_MAP:
display_token = SPECIAL_TOKENS_MAP[token_id]
else:
display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
if not display_token: display_token = " " # Placeholder
initial_state_vis.append((display_token, category))
visualization_states.append(initial_state_vis)
output = model.diffusion_generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=gen_length,
output_history=False, # We build history in the hook
return_dict_in_generate=True,
steps=steps,
temperature=temperature,
top_p=top_p,
alg=alg,
alg_temp=alg_temp if alg != "origin" else 0.0, # alg_temp only for confidence algs
generation_tokens_hook_func=generation_tokens_hook_func
)
print("model.diffusion_generate finished.")
# Extract final generated sequence (response part only)
final_sequence = output.sequences[0]
response_token_ids = final_sequence[prompt_length:]
# Decode the final response, skipping special tokens for the final output text
final_text = tokenizer.decode(
response_token_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
).strip()
print(f"Final generated text: {final_text}")
        # The hook has already appended the final state, so no extra safeguard is needed here.
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Add error message to visualization using the "Error" category
error_msg = f"Error during generation: {str(e)}"
visualization_states.append([("Error", "Error")]) # Use 'Error' category
final_text = f"Generation failed: {e}"
print("--- DREAM Generation Finished ---")
# Return states list (already built by hook) and final text
return visualization_states, final_text
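# Example usage (illustrative only, not executed here):
#   vis_states, text = dream_generate_response_with_visualization(
#       [{"role": "user", "content": "Write a haiku about rain."}],
#       gen_length=64, steps=64, constraints={0: "Rain"},
#   )
# vis_states is a list of per-step [(token_str, category), ...] lists suitable for gr.HighlightedText.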
# --- Gradio UI Setup ---
css = '''
/* Hide the default legend */
.gradio-container .output-markdown table { display: none !important; }
.small_btn {
max-width: 100px; /* Adjust as needed */
min-width: 60px; /* Ensure button doesn't collapse */
height: 40px; /* Adjust as needed */
flex-grow: 0 !important; /* Prevent button from growing */
margin-left: 5px !important; /* Add some space */
margin-top: auto; /* Align button bottom with textbox */
margin-bottom: auto; /* Align button bottom with textbox */
line-height: 1; /* Adjust line height if text vertical align is off */
padding: 0 10px; /* Adjust padding */
}
.chat-input-row {
display: flex;
align-items: center; /* Vertically align items */
margin-bottom: 10px; /* Add space below input row */
}
.chat-input-row > * {
margin-right: 5px; /* Space between textbox and button */
}
.chat-input-row > *:last-child {
margin-right: 0;
}
/* Style HighlightedText elements */
.token-hl span {
padding: 2px 1px; /* Minimal padding */
margin: 0 1px; /* Minimal margin */
border-radius: 3px;
display: inline-block; /* Ensure background covers token */
line-height: 1.2; /* Adjust for better vertical spacing */
}
/* Custom legend styling */
.custom-legend span {
display: inline-block;
margin-right: 15px;
font-size: 0.9em;
}
.custom-legend span::before {
content: "■";
margin-right: 4px;
font-size: 1.1em; /* Make square slightly larger */
vertical-align: middle; /* Align square with text */
}
'''
# Define color map mapping CATEGORY names to colors
color_map = {
"Mask": "#A0A0A0", # Darker Gray for masks
"New": "#77DD77", # Light Green for new tokens
"Old": "#AEC6CF", # Light Blue/Gray for old tokens
"Constraint": "#C3A0E0", # Purple for constraints
"Error": "#FF6961" # Light Red for errors
}
# Create the custom legend HTML string
legend_html = "<div class='custom-legend'>"
for category, color in color_map.items():
legend_html += f"<span style='color:{color};'>{category}</span>"
legend_html += "</div>"
def create_chatbot_demo():
with gr.Blocks(css=css) as demo:
gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
gr.Markdown("A demonstration of the Dream 7B diffusion-based language model. Watch the text generate step-by-step.")
gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")
# STATE MANAGEMENT
chat_history = gr.State([])
# UI COMPONENTS
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(
label="Conversation",
height=500,
bubble_full_width=False
)
# Message input Row
with gr.Row(elem_classes="chat-input-row"):
user_input = gr.Textbox(
label="Your Message",
placeholder="Type your message here and press Enter...",
scale=4,
container=False,
show_label=False
)
send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
constraints_input = gr.Textbox(
label="Word Constraints (Optional)",
info="Force specific words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
placeholder="e.g., 0:Hello, 6:world",
value=""
)
with gr.Column(scale=2):
output_vis = gr.HighlightedText(
label="Denoising Process Visualization",
combine_adjacent=False, # Keep tokens separate
                    show_legend=False,       # Hide the built-in legend; a custom one is rendered below
                    color_map=color_map,     # Map categories (Mask/New/Old/Constraint/Error) to colors
                    elem_classes="token-hl"  # Class for per-token styling via the CSS above
)
# Use Markdown to display the custom legend
gr.Markdown(legend_html)
# Advanced generation settings
with gr.Accordion("Generation Settings", open=False):
with gr.Row():
gen_length = gr.Slider(
minimum=16, maximum=512, value=128, step=8,
label="Max New Tokens"
)
steps = gr.Slider(
minimum=8, maximum=512, value=128, step=8,
label="Diffusion Steps"
)
with gr.Row():
temperature = gr.Slider(
minimum=0.0, maximum=1.5, value=0.6, step=0.05,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.0, maximum=1.0, value=0.95, step=0.05,
label="Top-P (Nucleus Sampling)"
)
with gr.Row():
remasking_strategy = gr.Radio(
choices=[
("Random", "origin"),
("Entropy", "entropy"),
("MaskGit+", "maskgit_plus"),
("TopK Margin", "topk_margin"),
],
value="entropy",
label="Generation Order Strategy (alg)"
)
alg_temp = gr.Slider(
minimum=0.0, maximum=1.0, value=0.1, step=0.05,
label="Order Randomness (alg_temp)" ,
info="Adds randomness to non-Random strategies. Ignored for Random."
)
with gr.Row():
visualization_delay = gr.Slider(
minimum=0.0, maximum=0.5, value=0.05, step=0.01,
label="Visualization Delay (seconds)"
)
# Clear button
clear_btn = gr.Button("Clear Conversation")
# --- Event Handlers ---
# Helper to add message to history state
def add_message_to_history(history, message, response):
history = history.copy() # Modify copy
history.append([message, response])
return history
# Function when user submits message (Enter or Send button)
def user_message_submitted(message, history):
print(f"User submitted: '{message}'")
if not message or not message.strip():
print("Empty message submitted, doing nothing.")
return history, history, "", [] # history, chatbot_ui, user_input, output_vis
history = add_message_to_history(history, message, None)
history_for_display = history.copy()
message_out = ""
vis_clear = [] # Clear visualization when new message submitted
return history, history_for_display, message_out, vis_clear
# Function to generate bot response (triggered after user message is processed)
def bot_response_generator(
history, gen_length, steps, constraints_text, delay,
temperature, top_p, alg, alg_temp
):
print("--- Generating Bot Response ---")
if not history or history[-1][1] is not None:
print("History empty or last message already has response. Skipping generation.")
                yield history, []  # Yield current state (chatbot, visualization) if called unnecessarily
return
messages = format_chat_history(history)
parsed_constraints = parse_constraints(constraints_text)
try:
vis_states, response_text = dream_generate_response_with_visualization(
messages,
gen_length=gen_length,
steps=steps,
constraints=parsed_constraints,
temperature=temperature,
top_p=top_p,
alg=alg,
alg_temp=alg_temp
)
# Update the history state only ONCE with the final bot response
final_history = history.copy() # Create copy to modify
final_history[-1][1] = response_text.strip() # Update the last element
# Yield visualization states one by one
# Important: Yield the *original* history for all intermediate steps,
# only yield the final_history with the *last* visualization state.
num_states = len(vis_states)
for i, state in enumerate(vis_states):
current_chatbot_state = history if i < num_states - 1 else final_history
yield current_chatbot_state, state
if delay > 0 and i < num_states - 1: # Don't sleep after last state
time.sleep(delay)
except Exception as e:
print(f"Error in bot_response_generator: {e}")
import traceback
traceback.print_exc()
error_msg = f"Error: {str(e)}"
error_vis = [(error_msg, "Error")] # Use Error category
                # Also surface the error message in the chat history
final_history_error = history.copy()
final_history_error[-1][1] = error_msg # Add error to chatbot too
yield final_history_error, error_vis
# Function to clear everything
def clear_conversation():
print("Clearing conversation.")
return [], [], "", [] # chat_history, chatbot_ui, user_input, output_vis
# --- Wire UI elements to functions ---
# Typing in Textbox and pressing Enter
submit_event = user_input.submit(
fn=user_message_submitted,
inputs=[user_input, chat_history],
outputs=[chat_history, chatbot_ui, user_input, output_vis],
queue=False # Show user message immediately
)
# Clicking the Send button
click_event = send_btn.click(
fn=user_message_submitted,
inputs=[user_input, chat_history],
outputs=[chat_history, chatbot_ui, user_input, output_vis],
queue=False
)
# Chain the generation after user message is processed (for both submit and click)
# Use .then() to trigger the generator
generation_inputs = [
chat_history, gen_length, steps, constraints_input, visualization_delay,
temperature, top_p, remasking_strategy, alg_temp
]
generation_outputs = [chatbot_ui, output_vis]
submit_event.then(
fn=bot_response_generator,
inputs=generation_inputs,
outputs=generation_outputs
)
click_event.then(
fn=bot_response_generator,
inputs=generation_inputs,
outputs=generation_outputs
)
# Clicking the Clear button
clear_btn.click(
fn=clear_conversation,
inputs=[],
outputs=[chat_history, chatbot_ui, user_input, output_vis],
queue=False
)
return demo
# --- Launch the Gradio App ---
if __name__ == "__main__":
print("Creating Gradio demo...")
demo = create_chatbot_demo()
print("Launching Gradio demo...")
# Use queue for potentially long generation times
# share=True generates a public link (useful for Colab/Spaces)
demo.queue().launch(share=True, debug=True) # Add debug=True for more logs |