# llada_app.py -> dream_app.py (v2)
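"""
Gradio demo for the Dream-v0-Instruct-7B diffusion language model.

Streams the denoising process into a gr.HighlightedText panel, supports
positional word constraints in "pos:word" format, and hides special tokens
(EOS, PAD) in the visualization.
"""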

import torch
import numpy as np
import gradio as gr
# import torch.nn.functional as F # Not needed for DREAM's basic visualization
from transformers import AutoTokenizer, AutoModel
import time
import re # Keep for parsing constraints

# Use try-except for Spaces deployment vs local runs
try:
    import spaces  # only available when running on Hugging Face Spaces
    gpu_check = spaces.GPU
    print("Running in Gradio Spaces with GPU environment.")
except (ImportError, AttributeError):
    print("Running in local environment or without spaces.GPU.")
    def gpu_check(func): return func

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Load DREAM Model and Tokenizer ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading model: {model_path}")
try:
    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    print("Model and tokenizer loaded.")
except Exception as e:
    print(f"FATAL: Could not load model/tokenizer. Error: {e}")
    # Optionally exit or raise
    raise SystemExit(f"Failed to load model: {e}")


# --- Constants for DREAM ---
# Find mask token and ID
if tokenizer.mask_token is None:
    print("Warning: Mask token not explicitly set in tokenizer. Trying to add '[MASK]'.")
    # This might require retraining/fine-tuning if the model didn't see it.
    # Check if it exists first before adding
    if '[MASK]' not in tokenizer.get_vocab():
         tokenizer.add_special_tokens({'mask_token': '[MASK]'})
         model.resize_token_embeddings(len(tokenizer)) # Resize model embeddings
         print("Added '[MASK]' and resized embeddings.")
    else:
         tokenizer.mask_token = '[MASK]' # Set it if it exists but wasn't assigned
         print("Found existing '[MASK]', assigned as mask_token.")

MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
if MASK_ID is None:
     raise ValueError("Failed to get MASK_ID after attempting to set mask_token.")
print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")

# Get EOS and PAD token IDs
EOS_TOKEN_ID = tokenizer.eos_token_id
PAD_TOKEN_ID = tokenizer.pad_token_id
print(f"Using EOS_TOKEN_ID={EOS_TOKEN_ID}, PAD_TOKEN_ID={PAD_TOKEN_ID}")
# Handle cases where they might be None (though unlikely for most models)
if EOS_TOKEN_ID is None:
    print("Warning: EOS token ID not found.")
if PAD_TOKEN_ID is None:
    print("Warning: PAD token ID not found. Using EOS ID as fallback for hiding.")
    PAD_TOKEN_ID = EOS_TOKEN_ID # Use EOS as a fallback for hiding logic if PAD is missing


# --- Helper Functions (Constraint Parsing, History Formatting) ---
# (Keep parse_constraints and format_chat_history functions as they were)
def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        part = part.strip() # Trim whitespace
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            # Allow empty words if needed, but usually we want a word
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            print(f"Warning: Could not parse constraint part: '{part}'")
            continue

    return constraints

def format_chat_history(history):
    """
    Format chat history for the DREAM model (standard messages format)

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted conversation for the model (list of dictionaries)
    """
    messages = []
    # Add system prompt if desired (check DREAM examples/recommendations)
    # messages.append({"role": "system", "content": "You are a helpful assistant."}) # Optional
    for user_msg, assistant_msg in history:
        if user_msg: # Handle potential None message if clearing failed
             messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (for the latest user message awaiting response)
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages

# --- Core Generation Logic for DREAM with Visualization ---

@gpu_check
def dream_generate_response_with_visualization(
    messages,
    gen_length=64,
    steps=64,
    constraints=None,
    temperature=0.6,
    top_p=0.95,
    alg="entropy",
    alg_temp=0.0,
):
    """
    Generate text with DREAM model with visualization using the generation hook.
    Hides special tokens (EOS, PAD) and uses labels for coloring.
    """
    print("--- Starting DREAM Generation ---")
    print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
    print(f"Constraints: {constraints}")

    # --- Input Preparation ---
    if constraints is None: constraints = {}

    processed_constraints = {}
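    # Each constraint word is encoded with a leading space (so the BPE tokenizer
    # treats it as a standalone word) and its token ids are mapped to consecutive
    # positions starting at `pos`, relative to the start of the generated region.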
    print("Processing constraints:")
    for pos, word in constraints.items():
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        if not tokens:
            print(f"  Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
            continue
        print(f"  Pos {pos}, Word '{word}' -> Tokens {tokens}")
        for i, token_id in enumerate(tokens):
            if pos + i not in processed_constraints:
                 processed_constraints[pos + i] = token_id
            else:
                 print(f"  Warning: Overlapping constraint at position {pos+i}. Keeping first.")

    try:
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
        )
        input_ids = inputs.input_ids.to(device=device)
        attention_mask = inputs.attention_mask.to(device=device)
        prompt_length = input_ids.shape[1]
        print(f"Input prompt length: {prompt_length}")
    except Exception as e:
        print(f"Error applying chat template: {e}")
        return [([("Error applying chat template.", "Error")],)], f"Error: {e}" # Use 'Error' label

    # Check context length (DREAM uses 2048)
    if prompt_length + gen_length > 2048:
         print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
         gen_length = 2048 - prompt_length
         if gen_length <= 0:
             print("Error: Prompt is already too long.")
             return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."

    # --- State for Visualization Hook ---
    visualization_states = []
    last_x = None

    # Initial state: Prompt + all masks + initial constraints
    initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
    for pos, token_id in processed_constraints.items():
        absolute_pos = pos
        if 0 <= absolute_pos < gen_length:
            initial_x_part[0, absolute_pos] = token_id

    initial_state_vis = []
    for i in range(gen_length):
        token_id = initial_x_part[0, i].item()
        if token_id == MASK_ID:
            initial_state_vis.append((MASK_TOKEN, "Mask"))
        elif token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
            initial_state_vis.append(("", None)) # Hide special tokens
        elif i in processed_constraints and processed_constraints[i] == token_id:
            token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
            display_token = token_str if token_str else "?"
            initial_state_vis.append((display_token, "Constraint"))
        else:
             # Should only be constraints here, but add fallback
             token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
             display_token = token_str if token_str else "?"
             initial_state_vis.append((display_token, "Old")) # Treat unexpected initial non-masks as 'Old'
    visualization_states.append(initial_state_vis)


    # --- Define the Hook Function ---
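    # The hook is passed to model.diffusion_generate below and is expected to be
    # called at each denoising step with (step, current token tensor x, logits);
    # the tensor it returns replaces x, which lets us record a visualization frame
    # and re-inject the constraint tokens at every step.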
    def generation_tokens_hook_func(step, x, logits):
        nonlocal last_x, visualization_states
        # print(f"Hook called for step {step}") # Verbose logging

        current_x = x.clone()
        constrained_x = current_x.clone()
        prompt_len = current_x.shape[1] - gen_length
        if prompt_len < 0:
            print("Warning: prompt_len negative in hook, skipping constraints/vis.")
            return current_x

        # 1. Apply Constraints
        constraints_applied_this_step = False
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_len + pos
            if prompt_len <= absolute_pos < current_x.shape[1]:
                if constrained_x[0, absolute_pos] != token_id:
                    constrained_x[0, absolute_pos] = token_id
                    constraints_applied_this_step = True

        # 2. Generate Visualization State for *this* step
        current_state_vis = []
        gen_part_current = current_x[0, prompt_len:]
        gen_part_last = last_x[0, prompt_len:] if last_x is not None else None

        for i in range(gen_length):
            current_token_id = gen_part_current[i].item()

            # --- Logic to Hide Special Tokens ---
            if current_token_id == EOS_TOKEN_ID or current_token_id == PAD_TOKEN_ID:
                # Maybe show on first appearance? For now, always hide.
                # LLaDA's behavior: "shown once and then disappear"
                # Let's implement the simpler "always hide" first.
                current_state_vis.append(("", None)) # Append empty string, no label -> hidden
                continue # Move to next token

            # --- Decode and Determine Label ---
            token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
            display_token = token_str if token_str else (MASK_TOKEN if current_token_id == MASK_ID else "?")  # Use MASK_TOKEN if decode fails

            label = None # Default label (no color)
            is_constrained = i in processed_constraints

            if current_token_id == MASK_ID:
                label = "Mask"
            elif is_constrained and processed_constraints[i] == current_token_id:
                 label = "Constraint"
            elif gen_part_last is None or gen_part_last[i].item() == MASK_ID or gen_part_last[i].item() == EOS_TOKEN_ID or gen_part_last[i].item() == PAD_TOKEN_ID:
                # Newly revealed (was mask or hidden special token in previous step)
                label = "New"
            else:
                # Previously revealed and not masked/hidden/constrained
                label = "Old"

            current_state_vis.append((display_token, label))

        visualization_states.append(current_state_vis)

        # 3. Update last_x for the *next* step's comparison
        last_x = constrained_x.clone()

        # 4. Return the sequence with constraints applied
        return constrained_x

    # --- Run DREAM Generation ---
    try:
        print("Calling model.diffusion_generate...")
        initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
        last_x = initial_full_x.clone() # Initialize last_x *before* the call

        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=gen_length,
            output_history=False,
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != "origin" else 0.0,
            generation_tokens_hook_func=generation_tokens_hook_func
        )
        print("model.diffusion_generate finished.")

        final_sequence = output.sequences[0]
        response_token_ids = final_sequence[prompt_length:]

        # Decode final text, skipping special tokens
        final_text = tokenizer.decode(
            response_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()
        print(f"Final generated text: {final_text}")

        # Safeguard: Add final state visualization if needed (using the new label logic)
        if len(visualization_states) <= steps:
             final_state_vis = []
             final_gen_part = final_sequence[prompt_length:]
             for i in range(gen_length):
                 token_id = final_gen_part[i].item()
                 if token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
                     final_state_vis.append(("", None))
                     continue

                 token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
                 display_token = token_str if token_str else (MASK_TOKEN if token_id == MASK_ID else "?")
                 label = None
                 is_constrained = i in processed_constraints

                 if token_id == MASK_ID: label = "Mask"
                 elif is_constrained and processed_constraints[i] == token_id: label = "Constraint"
                 else: label = "Old" # Default to 'Old' for final state non-masked tokens
                 final_state_vis.append((display_token, label))
             visualization_states.append(final_state_vis)


    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        error_msg = f"Error during generation: {str(e)}"
        # Use 'Error' label for color mapping
        visualization_states.append([("Error", "Error")])
        final_text = f"Generation failed: {e}"

    print("--- DREAM Generation Finished ---")
    return visualization_states, final_text


# --- Gradio UI Setup ---

css = '''
.category-legend{display:none}
/* button{height: 60px} */
.small_btn {max-width: 100px; height: 40px; flex-grow: 0; margin-left: 5px;}
.chat-input-row {display: flex; align-items: center;}
.chat-input-row > * {margin-right: 5px;}
.chat-input-row > *:last-child {margin-right: 0;}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown("Watch the text generate step-by-step. Special tokens (EOS, PAD) are hidden.")
        gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation", height=500, bubble_full_width=False
                 )
                with gr.Row(elem_classes="chat-input-row"):
                        user_input = gr.Textbox(
                            label="Your Message", placeholder="Type your message...",
                            scale=4, container=False, show_label=False
                        )
                        send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon'",
                    placeholder="e.g., 0:Hello, 6:world", value=""
                )
            with gr.Column(scale=2):
                # --- Updated HighlightedText with color_map ---
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=True, # Combine adjacent tokens with same label
                    show_legend=False, # Keep legend off
                    color_map={ # Map labels to colors
                        "Mask": "#A0A0A0", # Lighter Gray for Mask
                        "New": "#66CC66", # Light Green
                        "Old": "#6699CC", # Light Blue
                        "Constraint": "#B266FF", # Lighter Purple/Violet
                        "Error": "#FF6666" # Light Red
                    }
                )
                gr.Markdown(
                     # Update legend text to match labels
                    "**Color Legend:** <span style='color:#A0A0A0'>■ Mask</span> | <span style='color:#66CC66'>■ New</span> | <span style='color:#6699CC'>■ Old</span> | <span style='color:#B266FF'>■ Constraint</span>"
                )


        # Advanced generation settings (Keep as before)
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
                steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
            with gr.Row():
                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.6, step=0.05, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (Nucleus Sampling)")
            with gr.Row():
                remasking_strategy = gr.Radio(
                    choices=[("Random", "origin"), ("Entropy", "entropy"), ("MaskGit+", "maskgit_plus"), ("TopK Margin", "topk_margin")],
                    value="entropy", label="Generation Order Strategy (alg)"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Order Randomness (alg_temp)",
                    info="Adds randomness to non-Random strategies. Ignored for Random."
                )
            with gr.Row():
                visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Visualization Delay (seconds)")

        clear_btn = gr.Button("Clear Conversation")

        # --- Event Handlers (Keep as before) ---
        def add_message_to_history(history, message, response):
            history = history.copy(); history.append([message, response]); return history

        def user_message_submitted(message, history):
            print(f"User submitted: '{message}'")
            if not message or not message.strip():
                 print("Empty message submitted, doing nothing."); return history, history, "", []
            history = add_message_to_history(history, message, None)
            history_for_display = history.copy()
            message_out = ""; vis_clear = []
            return history, history_for_display, message_out, vis_clear

        def bot_response_generator(
            history, gen_length, steps, constraints_text, delay,
            temperature, top_p, alg, alg_temp
            ):
            print("--- Generating Bot Response ---")
            if not history or history[-1][1] is not None:
                print("History empty or last message already has response. Skipping generation.")
                yield history, [], "No response generated." # Yield current state if called unnecessarily
                return

            messages = format_chat_history(history)
            parsed_constraints = parse_constraints(constraints_text)

            try:
                vis_states, response_text = dream_generate_response_with_visualization(
                    messages, gen_length=gen_length, steps=steps, constraints=parsed_constraints,
                    temperature=temperature, top_p=top_p, alg=alg, alg_temp=alg_temp
                )
                history[-1][1] = response_text.strip() # Update history state

                if vis_states:
                    # Yield initial state first
                    yield history, vis_states[0] # Update chatbot, update visualization
                    # Animate remaining states
                    for state in vis_states[1:]:
                        time.sleep(delay)
                        yield history, state # Update chatbot (implicitly), update visualization
                else:
                    yield history, [("Generation failed.", "Error")] # Use label

            except Exception as e:
                print(f"Error in bot_response_generator: {e}")
                import traceback; traceback.print_exc()
                error_msg = f"Error: {str(e)}"
                error_vis = [(error_msg, "Error")] # Use label
                yield history, error_vis

        def clear_conversation():
            print("Clearing conversation."); return [], [], "", []

        # --- Wire UI elements (Keep as before) ---
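        # Submit/click first records the user turn synchronously (queue=False),
        # then streams bot_response_generator's yielded (chat history, vis state)
        # pairs into the chatbot and the HighlightedText panel.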
        user_input.submit(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
                  .then(fn=bot_response_generator, inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])

        send_btn.click(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
                .then(fn=bot_response_generator, inputs=[chat_history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])

        clear_btn.click(fn=clear_conversation, inputs=[], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)

    return demo

# --- Launch the Gradio App ---
if __name__ == "__main__":
    print("Creating Gradio demo...")
    demo = create_chatbot_demo()
    print("Launching Gradio demo...")
    demo.queue().launch(share=True, debug=True)