jannisborn committed on
Commit be74e38 · unverified · 1 Parent(s): 0dc70d1
Files changed (1)
  1. src/streamlit_app.py +425 -437
src/streamlit_app.py CHANGED
@@ -1,193 +1,392 @@
1
- import logging
2
- import time
3
-
4
  import altair as alt
5
- import numpy as np
6
  import pandas as pd
7
- import streamlit as st
8
  import streamlit_vertical_slider as svs
9
  import torch
10
-
11
- from scenarios import dirac, gauss, make_bimodal_scenarios
12
-
13
- logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
14
-
15
- DEMO_INTERVAL = 1.5
16
- CE_SCALING = 0.25
17
- MAX_LOSS_PLOT = 6
18
- LAST_STEP = -1
19
-
20
 
21
  # Define options globally as it's used in initialization and UI
22
  options = [str(i) for i in range(10)] + ["Text"]
23
 
24
-
25
- def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
26
- """Compute CE, NTL-MAE, NTL-WAS losses for the given probability vector and ground truth token."""
27
- ce_loss = CE_SCALING * -torch.log(
28
- torch.clamp(probs[options.index(gt_token)], min=1e-9)
29
- )
30
-
31
- numeric_mass = probs[:10].sum()
32
-
33
- if gt_token == "Text" or numeric_mass < 1e-6:
34
- return ce_loss.item(), 0.0, 0.0
35
-
36
- gt_numeric = int(gt_token)
37
- token_vals = torch.arange(10, dtype=torch.float32)
38
- mae = numeric_mass * abs(torch.dot(token_vals, probs[:10]) - gt_numeric)
39
- was = numeric_mass * torch.dot(probs[:10], torch.abs(token_vals - gt_numeric))
40
- return round(ce_loss.item(), 3), round(mae.item(), 3), round(was.item(), 3)
41
-
42
-
43
  # --- Session State Initialization ---
44
  # Ensure all session state variables are initialized before first use, especially by widgets.
45
- if "running_demo" not in st.session_state:
46
  st.session_state.running_demo = False
47
- if "demo_step" not in st.session_state:
48
  st.session_state.demo_step = 0
49
- if "last_update_time" not in st.session_state:
50
  st.session_state.last_update_time = 0
51
- if "loss_container" not in st.session_state:
52
  st.session_state.loss_container = None
53
- if "previous_chart_html" not in st.session_state:
54
  st.session_state.previous_chart_html = ""
55
- if "active_scenarios" not in st.session_state:
56
- # default if you want one to load on first show
57
- st.session_state.active_scenarios = dirac
58
- if "loss_history" not in st.session_state:
59
- st.session_state.loss_history = []
60
-
61
 
62
  # Initialize states for sliders and ground_truth selector
63
  # Using len(options) to correctly size for 0-9 + "Text"
64
  for i in range(len(options)):
65
  if f"slider_{i}" not in st.session_state:
66
- st.session_state[f"slider_{i}"] = 0
67
- if "ground_truth" not in st.session_state:
68
- st.session_state["ground_truth"] = options[5]
69
- if "manual_ground_truth" not in st.session_state:
70
- st.session_state["manual_ground_truth"] = options[5]
71
- if "demo_name" not in st.session_state:
72
- st.session_state["demo_name"] = "Dirac"
73
 
74
 
75
- st.title("NTL -- The Number Token Loss 🚀")
76
-
77
- st.markdown(
78
- """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!🎉
79
- ➡️ NTL augments cross-entropy to help LMs reason better with numbers 🧠
80
- """
81
- )
82
-
83
- st.subheader("Demo 1 — NTL vs. Cross Entropy in 3 Scenarios")
84
 
85
  st.markdown("""
86
- 1️⃣ Pick a ground truth token: a digit (0–9) or "Text" 📝 (simulates generic text tokens).
87
- 2️⃣ Choose a demo:
88
- - **Dirac** ⚡: All probability mass on one token.
89
- - **Gaussian** 🌊: Soft bell-curve around the true number.
90
- - **Bimodal** 🎯: Two peaks moving away from the target.
91
-
92
- Watch how losses evolve as predictions get worse — and see how NTL shines compared to CE! 🌟
93
  """)
94
 
95
 
96
- if "ground_truth" not in st.session_state:
97
- st.session_state["ground_truth"] = "4"
98
- gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
99
-
100
-
101
  def apply_scenario(step_idx):
102
- scenario = st.session_state.active_scenarios[step_idx]
 
 
103
  for i, val in enumerate(scenario["values"]):
104
  st.session_state[f"slider_{i}"] = val
 
105
 
106
-
107
- def start_dirac_demo():
108
- st.session_state.loss_history = []
109
- st.session_state.active_scenarios = dirac
110
- st.session_state.demo_name = "Dirac"
111
  st.session_state.running_demo = True
112
  st.session_state.demo_step = 0
113
  st.session_state.last_update_time = time.time()
114
- apply_scenario(0)
115
-
116
-
117
- def start_gauss_demo():
118
- st.session_state.loss_history = []
119
- st.session_state.active_scenarios = gauss
120
- st.session_state.demo_name = "Gauss"
121
- st.session_state.running_demo = True
122
- st.session_state.demo_step = 0
123
- st.session_state.last_update_time = time.time()
124
- apply_scenario(0)
125
-
126
-
127
- def start_bimodal_demo():
128
- st.session_state.loss_history = []
129
- gt = st.session_state["ground_truth"]
130
- st.session_state.active_scenarios = make_bimodal_scenarios(gt, options)
131
-
132
- st.session_state.demo_name = f"Bimodal (GT={gt})"
133
- st.session_state.running_demo = True
134
- st.session_state.demo_step = 0
135
- st.session_state.last_update_time = time.time()
136
- apply_scenario(0)
137
-
138
 
139
  def stop_demo():
140
  st.session_state.running_demo = False
141
 
142
-
143
  # --- Demo State Advancement Logic ---
144
  # This block handles advancing the demo. If it advances, it updates session state
145
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
146
  if st.session_state.running_demo:
147
- scenario = st.session_state.active_scenarios
148
  current_time = time.time()
149
- if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
150
- # if we haven’t yet shown the last scenario, advance
151
- if st.session_state.demo_step < len(scenario) - 1:
152
- st.session_state.demo_step += 1
153
- apply_scenario(st.session_state.demo_step)
154
- st.session_state.last_update_time = current_time
155
- st.rerun()
156
- else:
157
- # we just displayed the final case → stop
158
- st.session_state.running_demo = False
159
 
160
  # --- UI Rendering ---
161
  # This section renders the main UI. It executes after any potential rerun from the block above.
162
 
163
  if st.session_state.running_demo:
164
- st.info(
165
- f"Showing scenario {st.session_state.demo_step + 1}"
166
- f"/{len(st.session_state.active_scenarios)}: "
167
- f"{st.session_state.active_scenarios[st.session_state.demo_step]['name']}"
168
- )
169
  if st.button("Stop Demo"):
170
- st.session_state.running_demo = False
171
  st.rerun()
172
- else:
173
- col1, col2, col3 = st.columns(3)
174
- with col1:
175
- if st.button("Run: Dirac"):
176
- start_dirac_demo()
177
- st.rerun()
178
- with col2:
179
- if st.button("Run: Gauss"):
180
- start_gauss_demo()
181
- st.rerun()
182
- with col3:
183
- if st.button("Run: Bimodal"):
184
- start_bimodal_demo()
185
- st.rerun()
186
-
187
- current_prob_values_from_state = [
188
- st.session_state.get(f"slider_{j}", 0)
189
- for j in range(len(options)) # 1.0 / len(options)) for j in range(len(options))
190
- ]
 
191
  total_from_state = sum(current_prob_values_from_state)
192
  probs_for_charts = (
193
  torch.ones(len(options)) / len(options)
@@ -195,322 +394,112 @@ probs_for_charts = (
195
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
196
  )
197
 
198
- # Use manual GT token when not in running demo
199
- gt_choice_for_charts = (
200
- st.session_state["manual_ground_truth"]
201
- if not st.session_state.running_demo
202
- else st.session_state["ground_truth"]
203
- )
204
  if gt_choice_for_charts == "Text":
205
- gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
206
  gt_numeric_for_charts = None
207
  else:
208
  gt_index_for_charts = int(gt_choice_for_charts)
209
  gt_numeric_for_charts = gt_index_for_charts
210
 
211
- gt = st.session_state["ground_truth"]
212
- demo_name = st.session_state["demo_name"]
213
-
214
- st.markdown(f"#### Predicted distribution — ground truth: {gt}")
215
- df_dist = pd.DataFrame(
216
- {"token": options, "probability": probs_for_charts.numpy().round(2)}
217
- )
218
- df_dist["type"] = [
219
- "Ground Truth" if token == gt_choice_for_charts else "Prediction"
220
- for token in options
221
- ]
222
-
223
- bars = (
224
- alt.Chart(df_dist)
225
- .mark_bar(color="dodgerblue", size=40)
226
- .encode(
227
- x=alt.X(
228
- "token:N",
229
- title="Token",
230
- sort=options,
231
- axis=alt.Axis(
232
- labelAngle=0,
233
- labelFontSize=14,
234
- titleFontSize=16,
235
- labelAlign="center",
236
- labelFlush=False,
237
- ),
238
- ),
239
- y=alt.Y(
240
- "probability:Q",
241
- title="Probability",
242
- scale=alt.Scale(domain=[0, 1]),
243
- axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
244
- ),
245
- tooltip=[
246
- alt.Tooltip("token:N", title="Token"),
247
- alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
248
- ],
249
- )
250
- )
251
-
252
- bg_bar = pd.DataFrame({"token": [gt], "height": [1.0]})
253
- gt_bar = (
254
- alt.Chart(bg_bar)
255
- .mark_bar(
256
- color="darkgreen",
257
- size=20,
258
- opacity=0.3,
259
- stroke="gray",
260
- strokeWidth=2,
261
- strokeDash=[4, 4],
262
- )
263
- .encode(
264
- x=alt.X("token:N", sort=options),
265
- y=alt.Y("height:Q", scale=alt.Scale(domain=[0, 1])),
266
- tooltip=[
267
- alt.Tooltip("token:N", title="Ground Truth"),
268
- alt.Tooltip("height:Q", title="Desired mass", format=".2f"),
269
- ],
270
- )
271
  )
 
272
 
273
- annot1 = (
274
- alt.Chart(pd.DataFrame({"token": [gt]}))
275
- .mark_text(
276
- text="⬇ Ground",
277
- dy=-25, # 10px above the top of the bar
278
- dx=25,
279
- fontSize=14,
280
- fontWeight="bold",
281
- color="darkgreen",
282
- )
283
- .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
284
- )
285
 
286
- annot2 = (
287
- alt.Chart(pd.DataFrame({"token": [gt]}))
288
- .mark_text(
289
- text=f"truth={gt}",
290
- dy=-10, # 25px above the top, so it sits above line 1
291
- dx=35,
292
- fontSize=14,
293
- fontWeight="bold",
294
- color="darkgreen",
295
- )
296
- .encode(x=alt.X("token:N", sort=options), y=alt.value(1))
297
- )
298
 
299
- # 4) Layer them in order: background, bars, annotation
300
- final_chart = (gt_bar + bars + annot1 + annot2).properties(height=200)
301
 
302
- st.altair_chart(final_chart, use_container_width=True)
303
- ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)
304
 
305
 
306
- if (
307
- st.session_state.running_demo
308
- and len(st.session_state.loss_history) < st.session_state.demo_step + 1
309
- ):
310
- step = st.session_state.demo_step
311
- scenario = st.session_state.active_scenarios[step]
312
- ce, mae, was = compute_losses(probs_for_charts, gt_choice_for_charts)
313
 
314
- # pick x_val differently for bimodal vs others
315
- if st.session_state.demo_name.startswith("Bimodal"):
316
- x_val = scenario["name"] # e.g. "(4,4)", "(3,5)", …
317
- else:
318
- # exactly like before:
319
- best_idx = np.argmax(scenario["values"])
320
- x_val = options[best_idx] # "0", "1", …, or "Text"
321
-
322
- st.session_state.loss_history.append(
323
- {
324
- "step": step,
325
- "x_val": x_val,
326
- "Cross Entropy": ce,
327
- "NTL-MAE": mae,
328
- "NTL-WAS": was,
329
- }
330
- )
331
-
332
-
333
- # 1) build a raw DF from histories
334
- df = pd.DataFrame(st.session_state.loss_history)
335
-
336
- if df.empty:
337
- # define an empty "melted" DataFrame with the right columns
338
- df_loss_plot = pd.DataFrame(columns=["step", "x_val", "Loss Type", "Loss Value"])
339
- else:
340
- # now it's safe to melt
341
- df_loss_plot = df.melt(
342
- id_vars=["step", "x_val"],
343
- value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
344
- var_name="Loss Type",
345
- value_name="Loss Value",
346
- )
347
 
348
 
349
  loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
350
  if was_val != "N/A":
351
  loss_data["Loss"].append("NTL-WAS")
352
  loss_data["Value"].append(was_val)
353
- if mae_val != "N/A":
354
- loss_data["Loss"].append("NTL-MAE")
355
- loss_data["Value"].append(mae_val)
356
 
357
  loss_df = pd.DataFrame(loss_data)
358
 
359
- if st.session_state.demo_name.startswith("Bimodal"):
360
- domain = [sc["name"] for sc in st.session_state.active_scenarios]
361
- x_title = f"Offset from GT {st.session_state['ground_truth']}"
362
- else:
363
- domain = options
364
- x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"
365
-
366
-
367
  # ============== Chart Display ==============
368
-
369
-
370
- st.markdown("#### Loss as a function of predicted distribution")
371
-
372
- grouped_chart = (
373
- alt.Chart(df_loss_plot)
374
- .mark_bar()
375
- .encode(
376
- x=alt.X(
377
- "x_val:N",
378
- title=x_title,
379
- sort=domain,
380
- scale=alt.Scale(domain=domain),
381
- axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
382
- ),
383
- y=alt.Y(
384
- "Loss Value:Q",
385
- title="Loss Value",
386
- scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
387
- axis=alt.Axis(labelFontSize=14, titleFontSize=16),
388
- ),
389
- color=alt.Color(
390
- "Loss Type:N",
391
- scale=alt.Scale(
392
- domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
393
- range=["red", "limegreen", "blueviolet"],
394
- ),
395
- legend=alt.Legend(
396
- title="",
397
- orient="top",
398
- direction="horizontal",
399
- columns=3,
400
- ),
401
- ),
402
- xOffset="Loss Type:N", # grouped bars
403
- tooltip=[
404
- alt.Tooltip("x_val:N", title="Scenario"),
405
- alt.Tooltip("Loss Type:N", title="Loss Type"),
406
- alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
407
- ],
408
- )
409
- .properties(height=250)
410
- )
411
- st.altair_chart(grouped_chart, use_container_width=True)
412
-
413
-
414
  # Create a single chart for loss visualization
415
- if not st.session_state.running_demo:
416
- for i in range(len(options)):
417
- st.session_state[f"slider_{i}"] = 0.0
418
- st.session_state.demo_step = 0
 
419
 
420
- st.subheader("Demo 2 -- Manual loss comparison")
421
- st.subheader("🧪 Demo 2 — Craft your own distribution")
422
- st.markdown("""
423
- This demo gives you more control but is harder to interpret. See it as a playground! 🎨
424
- Manually adjust the sliders to change the predicted probabilities for each token.
425
- The demo normalizes the values to form a valid probability distribution and calculates the losses.
426
-
427
- 👣 **Steps:**
428
- - Use the **vertical sliders** to allocate probability to each token.
429
- - Choose the correct **Ground Truth Token** (0–9 or "Text" 📜).
430
- - Observe how each loss function reacts.
431
-
432
- 💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it! 😈
433
- """)
434
-
435
- manual_gt = st.selectbox(
436
- "Ground Truth Token",
437
- options=options,
438
- key="manual_ground_truth",
439
- )
440
-
441
- loss_df = pd.DataFrame(
442
- {
443
- "Loss": ["Cross Entropy", "NTL-MAE", "NTL-WAS"],
444
- "Value": [ce_val, mae_val, was_val],
445
- }
446
- )
447
-
448
- # Sliders and Ground Truth Selector
449
- # These widgets will read their initial values from st.session_state.
450
- # User interactions will update st.session_state directly due to their keys.
451
- st.markdown("#### Adjust the predicted token probability")
452
- cols = st.columns(len(options))
453
- for i, col in enumerate(cols):
454
- label = options[i] # Use token name directly for label
455
- with col:
456
- svs.vertical_slider(
457
- label=label,
458
- min_value=0.0,
459
- max_value=1.0,
460
- step=0.01,
461
- height=50,
462
- key=f"slider_{i}",
463
- slider_color="green",
464
- track_color="lightgray",
465
- thumb_color="black",
466
- )
467
 
468
- chart = (
469
- alt.Chart(loss_df)
470
- .mark_bar()
471
- .encode(
472
- x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
473
- y=alt.Y(
474
- "Value:Q",
475
- scale=alt.Scale(
476
- domain=[
477
- 0,
478
- max(
479
- loss_df["Value"].max() * 1.2,
480
- 20 if st.session_state.running_demo else 0.5,
481
- ),
482
- ]
483
- ),
484
- ),
485
- color=alt.Color(
486
- "Loss:N",
487
- scale=alt.Scale(
488
- domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
489
- range=["orangered", "limegreen", "blueviolet"],
490
- ),
491
- ),
492
- tooltip=["Loss", "Value"],
493
- )
494
- .properties(height=300)
495
- )
496
-
497
- text = chart.mark_text(
498
- align="center", baseline="bottom", dy=-5, fontSize=14
499
- ).encode(text=alt.Text("Value:Q", format=".3f"))
500
- final_chart = chart + text
501
- st.altair_chart(final_chart, use_container_width=True)
502
-
503
-
504
- # # Add value labels on top of bars
505
- # text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
506
- # text=alt.Text("Value:Q", format=".3f")
507
- # )
508
-
509
- # # Combine chart and text
510
- # final_chart = chart + text
511
 
512
  # Display chart with the full container width
513
- # st.altair_chart(final_chart, use_container_width=True)
514
 
515
  # --- Polling Rerun for Demo Mode ---
516
  # If the demo is running and we haven't just advanced (which would have caused a rerun),
@@ -518,21 +507,20 @@ if not st.session_state.running_demo:
518
  if st.session_state.running_demo:
519
  # This check is implicitly: if we are here and demo is running, it means
520
  # the time-based advance condition was NOT met in the block at the top.
521
- time.sleep(0.1)
522
  st.rerun()
523
 
524
-
525
  st.markdown("""
526
- ### 🤔 TL;DR Why NTL?
527
- Cross Entropy only cares if the prediction is exactly right or wrong ❌✅ — it doesn’t care *how close* a guess is!
528
- That’s bad for LLMs doing math and numeric reasoning 🧮.
529
 
530
- 💥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
 
531
  """)
532
 
533
- st.markdown("#### 📚 Further Resources")
 
534
  st.markdown("""
535
- - 📄 [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
536
- - 🌐 [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
537
- - 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
538
  """)
 
 
 
 
1
  import altair as alt
 
2
  import pandas as pd
 
3
  import streamlit_vertical_slider as svs
4
  import torch
5
+ # from streamlit_vertical_slider import vertical_slider # Not directly used, svs.vertical_slider is
6
+ import streamlit as st
7
+ import time
8
+ import plotly.graph_objects as go # Add Plotly import
 
9
 
10
  # Define options globally as it's used in initialization and UI
11
  options = [str(i) for i in range(10)] + ["Text"]
12
 
13
  # --- Session State Initialization ---
14
  # Ensure all session state variables are initialized before first use, especially by widgets.
15
+ if 'running_demo' not in st.session_state:
16
  st.session_state.running_demo = False
17
+ if 'demo_step' not in st.session_state:
18
  st.session_state.demo_step = 0
19
+ if 'last_update_time' not in st.session_state:
20
  st.session_state.last_update_time = 0
21
+ if 'loss_container' not in st.session_state:
22
  st.session_state.loss_container = None
23
+ if 'previous_chart_html' not in st.session_state:
24
  st.session_state.previous_chart_html = ""
 
25
 
26
  # Initialize states for sliders and ground_truth selector
27
  # Using len(options) to correctly size for 0-9 + "Text"
28
  for i in range(len(options)):
29
  if f"slider_{i}" not in st.session_state:
30
+ st.session_state[f"slider_{i}"] = 1.0 / len(options)
31
+ if 'ground_truth' not in st.session_state:
32
+ st.session_state['ground_truth'] = options[0] # Default to "0"
 
33
 
34
 
35
+ st.title("Number Token Loss - Demo")
 
36
 
37
  st.markdown("""
38
+ Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
39
+ The sliders are vertical and compact. The app normalizes the slider values
40
+ to form a valid probability distribution, visualizes it, and computes the corresponding
41
+ Cross Entropy, NTL-MSE, and NTL-WAS losses.
 
42
  """)
43
 
44
+ # --- Scenario Definitions ---
45
+ scenarios = [
46
+ {
47
+ "name": "Probability mass at 0",
48
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
49
+ "ground_truth": "0",
50
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
51
+ },
52
+ {
53
+ "name": "Probability mass at 0",
54
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
55
+ "ground_truth": "1",
56
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
57
+ },
58
+ {
59
+ "name": "Probability mass at 0",
60
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
61
+ "ground_truth": "2",
62
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
63
+ },
64
+ {
65
+ "name": "Probability mass at 0",
66
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
67
+ "ground_truth": "3",
68
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
69
+ },
70
+ {
71
+ "name": "Probability mass at 0",
72
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
73
+ "ground_truth": "4",
74
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
75
+ },
76
+ {
77
+ "name": "Probability mass at 0",
78
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
79
+ "ground_truth": "5",
80
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
81
+ },
82
+ {
83
+ "name": "Probability mass at 0",
84
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
85
+ "ground_truth": "6",
86
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
87
+ },
88
+ {
89
+ "name": "Probability mass at 0",
90
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
91
+ "ground_truth": "7",
92
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
93
+ },
94
+ {
95
+ "name": "Probability mass at 0",
96
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
97
+ "ground_truth": "8",
98
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
99
+ },
100
+ {
101
+ "name": "Probability mass at 0",
102
+ "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
103
+ "ground_truth": "9",
104
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
105
+ },
106
+
107
+
108
+ {
109
+ "name": "Probability mass around 5",
110
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
111
+ "ground_truth": "0",
112
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
113
+ },
114
+ {
115
+ "name": "Probability mass around 5",
116
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
117
+ "ground_truth": "1",
118
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
119
+ },
120
+ {
121
+ "name": "Probability mass around 5",
122
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
123
+ "ground_truth": "2",
124
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
125
+ },
126
+ {
127
+ "name": "Probability mass around 5",
128
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
129
+ "ground_truth": "3",
130
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
131
+ },
132
+ {
133
+ "name": "Probability mass around 5",
134
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
135
+ "ground_truth": "4",
136
+ "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
137
+ },
138
+ {
139
+ "name": "Probability mass around ground truth (5)",
140
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
141
+ "ground_truth": "5",
142
+ "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
143
+ },
144
+ {
145
+ "name": "Probability mass around 5",
146
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
147
+ "ground_truth": "6",
148
+ "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
149
+ },
150
+ {
151
+ "name": "Probability mass around 5",
152
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
153
+ "ground_truth": "7",
154
+ "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
155
+ },
156
+ {
157
+ "name": "Probability mass around 5",
158
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
159
+ "ground_truth": "8",
160
+ "explanation": "Cross Entropy is high, NTL is higher but still penalizes less than CE because distribution knows it's a number."
161
+ },
162
+ {
163
+ "name": "Probability mass around 5",
164
+ "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
165
+ "ground_truth": "9",
166
+ "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
167
+ },
168
+
169
+ {
170
+ "name": "Probability mass concentrated on 5",
171
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
172
+ "ground_truth": "0",
173
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
174
+ },
175
+ {
176
+ "name": "Probability mass concentrated on 5",
177
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
178
+ "ground_truth": "1",
179
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
180
+ },
181
+ {
182
+ "name": "Probability mass concentrated on 5",
183
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
184
+ "ground_truth": "2",
185
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
186
+ },
187
+ {
188
+ "name": "Probability mass concentrated on 5",
189
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
190
+ "ground_truth": "3",
191
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
192
+ },
193
+ {
194
+ "name": "Probability mass concentrated on 5",
195
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
196
+ "ground_truth": "4",
197
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
198
+ },
199
+ {
200
+ "name": "Probability mass concentrated on 5",
201
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
202
+ "ground_truth": "5",
203
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
204
+ },
205
+ {
206
+ "name": "Probability mass concentrated on 5",
207
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
208
+ "ground_truth": "6",
209
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
210
+ },
211
+ {
212
+ "name": "Probability mass concentrated on 5",
213
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
214
+ "ground_truth": "7",
215
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
216
+ },
217
+ {
218
+ "name": "Probability mass concentrated on 5",
219
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
220
+ "ground_truth": "8",
221
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
222
+ },
223
+ {
224
+ "name": "Probability mass concentrated on 5",
225
+ "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
226
+ "ground_truth": "9",
227
+ "explanation": "Both CE and NTL are high because the prediction is far from correct."
228
+ },
229
+
230
+
231
+ {
232
+ "name": "Probability mass concentrated on 1",
233
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
234
+ "ground_truth": "0",
235
+ "explanation": "Both losses are low because the prediction is correct."
236
+ },
237
+ {
238
+ "name": "Probability mass concentrated on 1",
239
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
240
+ "ground_truth": "1",
241
+ "explanation": "Both losses are low because the prediction is correct."
242
+ },
243
+ {
244
+ "name": "Probability mass concentrated on 1",
245
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
246
+ "ground_truth": "2",
247
+ "explanation": "Both losses are low because the prediction is correct."
248
+ },
249
+ {
250
+ "name": "Probability mass concentrated on 1",
251
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
252
+ "ground_truth": "3",
253
+ "explanation": "Both losses are low because the prediction is correct."
254
+ },
255
+ {
256
+ "name": "Probability mass concentrated on 1",
257
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
258
+ "ground_truth": "4",
259
+ "explanation": "Both losses are low because the prediction is correct."
260
+ },
261
+ {
262
+ "name": "Probability mass concentrated on 1",
263
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
264
+ "ground_truth": "5",
265
+ "explanation": "Both losses are low because the prediction is correct."
266
+ },
267
+ {
268
+ "name": "Probability mass concentrated on 1",
269
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
270
+ "ground_truth": "6",
271
+ "explanation": "Both losses are low because the prediction is correct."
272
+ },
273
+ {
274
+ "name": "Probability mass concentrated on 1",
275
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
276
+ "ground_truth": "7",
277
+ "explanation": "Both losses are low because the prediction is correct."
278
+ },
279
+ {
280
+ "name": "Probability mass concentrated on 1",
281
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
282
+ "ground_truth": "8",
283
+ "explanation": "Both losses are low because the prediction is correct."
284
+ },
285
+ {
286
+ "name": "Probability mass concentrated on 1",
287
+ "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
288
+ "ground_truth": "9",
289
+ "explanation": "Both losses are low because the prediction is correct."
290
+ },
291
+
292
+
293
+ {
294
+ "name": "Almost correct (1 vs 2)",
295
+ "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
296
+ "ground_truth": "0",
297
+ "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
298
+ },
299
+ {
300
+ "name": "Almost correct (1 vs 2)",
301
+ "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
302
+ "ground_truth": "1",
303
+ "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
304
+ },
305
+ {
306
+ "name": "Almost correct (1 vs 2)",
307
+ "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
308
+ "ground_truth": "2",
309
+ "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
310
+ },
311
+ {
312
+ "name": "Almost correct (1 vs 2)",
313
+ "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
314
+ "ground_truth": "3",
315
+ "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
316
+ }
317
+ ]
318
 
319
+ # --- Helper Functions ---
 
 
320
  def apply_scenario(step_idx):
321
+ scenario = scenarios[step_idx]
322
+ # These assignments modify session state. They must be done *before* the widgets
323
+ # are rendered in the script run that should display these new values.
324
  for i, val in enumerate(scenario["values"]):
325
  st.session_state[f"slider_{i}"] = val
326
+ st.session_state['ground_truth'] = scenario["ground_truth"]
327
 
328
+ def start_demo():
 
329
  st.session_state.running_demo = True
330
  st.session_state.demo_step = 0
331
  st.session_state.last_update_time = time.time()
332
+ apply_scenario(0) # Apply the first scenario's state
333
+ # The button click that calls start_demo() will itself cause a rerun.
 
 
334
 
335
  def stop_demo():
336
  st.session_state.running_demo = False
337
 
 
338
  # --- Demo State Advancement Logic ---
339
  # This block handles advancing the demo. If it advances, it updates session state
340
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
341
  if st.session_state.running_demo:
 
342
  current_time = time.time()
343
+ if current_time - st.session_state.last_update_time > 3.0: # 3 seconds per scenario
344
+ next_step = (st.session_state.demo_step + 1) % len(scenarios)
345
+ st.session_state.demo_step = next_step
346
+ apply_scenario(next_step) # Update session state for the new scenario
347
+ st.session_state.last_update_time = time.time() # Reset timer
348
+ st.rerun() # Crucial: Rerun to reflect changes in widgets and charts
 
349
 
350
  # --- UI Rendering ---
351
  # This section renders the main UI. It executes after any potential rerun from the block above.
352
 
353
  if st.session_state.running_demo:
354
+ st.info(f"Showing scenario {st.session_state.demo_step + 1}/{len(scenarios)}: {scenarios[st.session_state.demo_step]['name']}")
355
+ st.markdown(f"**Explanation:** {scenarios[st.session_state.demo_step]['explanation']}")
 
356
  if st.button("Stop Demo"):
357
+ stop_demo()
358
  st.rerun()
359
+ else: # Not st.session_state.running_demo
360
+ if st.button("Start Automated Demo"):
361
+ start_demo() # This calls apply_scenario(0)
362
+ st.rerun() # Rerun to enter demo mode and draw scenario 0 correctly
363
+
364
+ # Sliders and Ground Truth Selector
365
+ # These widgets will read their initial values from st.session_state.
366
+ # User interactions will update st.session_state directly due to their keys.
367
+ if not st.session_state.running_demo:
368
+ st.markdown("#### Predicted Token Probabilities")
369
+ cols = st.columns(len(options))
370
+ for i, col in enumerate(cols):
371
+ label = options[i] # Use token name directly for label
372
+ with col:
373
+ svs.vertical_slider(
374
+ label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
375
+ key=f"slider_{i}", # This key links the widget to st.session_state[f"slider_{i}"]
376
+ slider_color="green", track_color="lightgray", thumb_color="black"
377
+ )
378
+
379
+ # Ground truth selectbox
380
+ st.selectbox(
381
+ "Ground Truth Token", options=options,
382
+ index=options.index(st.session_state['ground_truth']), # Display value from session state
383
+ key='ground_truth' # Links widget to st.session_state['ground_truth']
384
+ )
385
+
386
+ # Placeholder for charts and loss calculations that will be updated
387
+ # This section always reads the current st.session_state to generate its content.
388
+
389
+ current_prob_values_from_state = [st.session_state.get(f"slider_{j}", 1.0/len(options)) for j in range(len(options))]
390
  total_from_state = sum(current_prob_values_from_state)
391
  probs_for_charts = (
392
  torch.ones(len(options)) / len(options)
 
394
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
395
  )
396
 
397
+ gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
 
398
  if gt_choice_for_charts == "Text":
399
+ gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
400
  gt_numeric_for_charts = None
401
  else:
402
  gt_index_for_charts = int(gt_choice_for_charts)
403
  gt_numeric_for_charts = gt_index_for_charts
404
 
405
+ st.markdown("#### Input Probability Distribution")
406
+ df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
407
+ df_dist["type"] = ["Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options]
408
+ chart = (
409
+ alt.Chart(df_dist).mark_bar().encode(
410
+ x=alt.X("token:N", title="Token", sort=options), # Ensure consistent sort order
411
+ y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
412
+ color=alt.Color("type:N", scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]), legend=alt.Legend(title="Token Type"))
413
+ ).properties(height=300)
 
 
414
  )
415
+ st.altair_chart(chart, use_container_width=True)
416
+
417
+ ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
418
+ if gt_numeric_for_charts is None: # Text token
419
+ ntl_mse_loss = torch.tensor(float('nan')) # MSE not applicable for text
420
+ ntl_was_loss = torch.tensor(float('nan')) # WAS not applicable for text
421
+ else: # Numeric token
422
+ numeric_probs_for_loss = probs_for_charts[:10] # Probabilities for 0-9
423
+ # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
424
+ numeric_probs_sum = torch.sum(numeric_probs_for_loss)
425
+ if numeric_probs_sum > 1e-6 : # Avoid division by zero
426
+ normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
427
+ else:
428
+ normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
429
 
 
430
 
431
+ loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
 
 
432
 
433
+ # Use normalized probabilities for NTL if only considering numeric tokens
434
+ if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6 :
435
+ pred_value = torch.sum( (probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * loss_values_tensor)
436
+ elif gt_choice_for_charts != "Text": # if sum is zero, pred_value is ill-defined or 0
437
+ pred_value = torch.tensor(0.0)
438
+ else: # Should not happen if gt_numeric_for_charts is not None
439
+ pred_value = torch.tensor(float('nan'))
440
 
 
 
441
 
442
+ if not torch.isnan(pred_value):
443
+ ntl_mse_loss = (pred_value - float(gt_numeric_for_charts)) ** 2
444
+ abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
445
+ if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
446
+ ntl_was_loss = torch.sum((probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * abs_diff)
447
+ elif gt_choice_for_charts != "Text":
448
+ ntl_was_loss = torch.tensor(0.0) # Or some other default if all numeric probs are zero
449
+ else:
450
+ ntl_was_loss = torch.tensor(float('nan'))
451
+ else:
452
+ ntl_mse_loss = torch.tensor(float('nan'))
453
+ ntl_was_loss = torch.tensor(float('nan'))
454
 
455
 
456
+ ce_val = round(ce_loss.item(), 3)
457
+ mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
458
+ was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
 
 
459
 
460
 
461
  loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
462
  if was_val != "N/A":
463
  loss_data["Loss"].append("NTL-WAS")
464
  loss_data["Value"].append(was_val)
465
+ if mse_val != "N/A":
466
+ loss_data["Loss"].append("NTL-MSE")
467
+ loss_data["Value"].append(mse_val)
468
 
469
  loss_df = pd.DataFrame(loss_data)
470
 
 
 
 
 
 
 
 
 
471
  # ============== Chart Display ==============
 
 
472
  # Create a single chart for loss visualization
473
+ st.subheader("Loss Comparison")
474
+
475
+ # Create an Altair chart that will look good and redraw cleanly
476
+ chart = alt.Chart(loss_df).mark_bar().encode(
477
+ x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),
478
+ y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)])),
479
+ color=alt.Color('Loss:N', scale=alt.Scale(
480
+ domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
481
+ range=['steelblue', 'red', 'forestgreen']
482
+ )),
483
+ tooltip=['Loss', 'Value']
484
+ ).properties(
485
+ height=300
486
+ )
487
 
488
+ # Add value labels on top of bars
489
+ text = chart.mark_text(
490
+ align='center',
491
+ baseline='bottom',
492
+ dy=-5,
493
+ fontSize=14
494
+ ).encode(
495
+ text=alt.Text('Value:Q', format='.3f')
496
+ )
 
 
497
 
498
+ # Combine chart and text
499
+ final_chart = (chart + text)
 
 
500
 
501
  # Display chart with the full container width
502
+ st.altair_chart(final_chart, use_container_width=True)
503
 
504
  # --- Polling Rerun for Demo Mode ---
505
  # If the demo is running and we haven't just advanced (which would have caused a rerun),
 
507
  if st.session_state.running_demo:
508
  # This check is implicitly: if we are here and demo is running, it means
509
  # the time-based advance condition was NOT met in the block at the top.
510
+ time.sleep(0.1) # Adjusted from 0.2 to 0.5 (or try 1.0)
511
  st.rerun()
512
 
513
+ # Add explanation of the demonstration
514
  st.markdown("""
515
+ ### What Does This Demo Show?
 
 
516
 
517
+ - **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
518
+ - **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
519
  """)
520
 
521
+ # References / resources section with links (common to both modes)
522
+ st.markdown("### Resources")
523
  st.markdown("""
524
+ - [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
525
+ - [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
 
526
  """)
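
As a quick reference for the loss behaviour both versions of the app demonstrate, the minimal sketch below mirrors the removed `compute_losses` helper, restricted to the digit tokens 0-9: the numeric-mass weighting, the "Text" token handling, and the illustrative `CE_SCALING` plotting factor are omitted, and the helper name and example distributions are purely illustrative, not part of the commit.

```python
import torch

# Digit tokens 0-9 (the "Text" token from the demo is ignored here for brevity).
token_vals = torch.arange(10, dtype=torch.float32)


def ce_and_ntl(probs: torch.Tensor, gt: int) -> tuple[float, float, float]:
    """Cross entropy, NTL-MAE and NTL-WAS for a distribution over the digits 0-9."""
    ce = -torch.log(torch.clamp(probs[gt], min=1e-9))    # standard token-level CE
    mae = torch.abs(torch.dot(token_vals, probs) - gt)   # |E[token value] - gt|
    was = torch.dot(probs, torch.abs(token_vals - gt))   # E[|token value - gt|]
    return ce.item(), mae.item(), was.item()


# Ground truth is 4; both guesses assign the true token the same probability (0.2),
# but one puts the remaining mass on 3 (a near miss) and the other on 9 (a far miss).
near = torch.tensor([0.0, 0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0])
far = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.8])
print(ce_and_ntl(near, 4))  # identical CE, small NTL values
print(ce_and_ntl(far, 4))   # identical CE, much larger NTL values
```

Because both guesses give the ground-truth token the same probability, cross entropy cannot tell them apart, while NTL-MAE and NTL-WAS come out five times larger for the far miss.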