jannisborn committed on
Commit
2854521
·
unverified ·
1 Parent(s): be74e38
Files changed (1) hide show
  1. src/streamlit_app.py +386 -426
src/streamlit_app.py CHANGED
@@ -1,392 +1,198 @@
 
 
 
1
  import altair as alt
 
2
  import pandas as pd
 
3
  import streamlit_vertical_slider as svs
4
  import torch
5
- # from streamlit_vertical_slider import vertical_slider # Not directly used, svs.vertical_slider is
6
- import streamlit as st
7
- import time
8
- import plotly.graph_objects as go # Add Plotly import
 
 
 
 
 
 
9
 
10
  # Define options globally as it's used in initialization and UI
11
  options = [str(i) for i in range(10)] + ["Text"]
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # --- Session State Initialization ---
14
  # Ensure all session state variables are initialized before first use, especially by widgets.
15
- if 'running_demo' not in st.session_state:
16
  st.session_state.running_demo = False
17
- if 'demo_step' not in st.session_state:
18
  st.session_state.demo_step = 0
19
- if 'last_update_time' not in st.session_state:
20
  st.session_state.last_update_time = 0
21
- if 'loss_container' not in st.session_state:
22
  st.session_state.loss_container = None
23
- if 'previous_chart_html' not in st.session_state:
24
  st.session_state.previous_chart_html = ""
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Initialize states for sliders and ground_truth selector
27
  # Using len(options) to correctly size for 0-9 + "Text"
28
  for i in range(len(options)):
29
  if f"slider_{i}" not in st.session_state:
30
- st.session_state[f"slider_{i}"] = 1.0 / len(options)
31
- if 'ground_truth' not in st.session_state:
32
- st.session_state['ground_truth'] = options[0] # Default to "0"
 
 
 
 
 
33
 
 
 
 
 
 
 
 
34
 
35
- st.title("Number Token Loss - Demo")
36
 
37
  st.markdown("""
38
- Adjust the sliders to set a predicted probability for each token (0-9 and "Text").
39
- The sliders are vertical and compact. The app normalizes the slider values
40
- to form a valid probability distribution, visualizes it, and computes the corresponding
41
- Cross Entropy, NTL-MSE, and NTL-WAS losses.
 
 
 
42
  """)
43
 
44
- # --- Scenario Definitions ---
45
- scenarios = [
46
- {
47
- "name": "Probability mass at 0",
48
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
49
- "ground_truth": "0",
50
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
51
- },
52
- {
53
- "name": "Probability mass at 0",
54
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
55
- "ground_truth": "1",
56
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
57
- },
58
- {
59
- "name": "Probability mass at 0",
60
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
61
- "ground_truth": "2",
62
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
63
- },
64
- {
65
- "name": "Probability mass at 0",
66
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
67
- "ground_truth": "3",
68
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
69
- },
70
- {
71
- "name": "Probability mass at 0",
72
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
73
- "ground_truth": "4",
74
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
75
- },
76
- {
77
- "name": "Probability mass at 0",
78
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
79
- "ground_truth": "5",
80
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
81
- },
82
- {
83
- "name": "Probability mass at 0",
84
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
85
- "ground_truth": "6",
86
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
87
- },
88
- {
89
- "name": "Probability mass at 0",
90
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
91
- "ground_truth": "7",
92
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
93
- },
94
- {
95
- "name": "Probability mass at 0",
96
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
97
- "ground_truth": "8",
98
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
99
- },
100
- {
101
- "name": "Probability mass at 0",
102
- "values": [0.3, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.0], # 11 values
103
- "ground_truth": "9",
104
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
105
- },
106
-
107
-
108
- {
109
- "name": "Probability mass around 5",
110
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
111
- "ground_truth": "0",
112
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
113
- },
114
- {
115
- "name": "Probability mass around 5",
116
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
117
- "ground_truth": "1",
118
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
119
- },
120
- {
121
- "name": "Probability mass around 5",
122
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
123
- "ground_truth": "2",
124
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
125
- },
126
- {
127
- "name": "Probability mass around 5",
128
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
129
- "ground_truth": "3",
130
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
131
- },
132
- {
133
- "name": "Probability mass around 5",
134
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
135
- "ground_truth": "4",
136
- "explanation": "Cross Entropy does not penalize if the prediction is far from the ground truth."
137
- },
138
- {
139
- "name": "Probability mass around ground truth (5)",
140
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
141
- "ground_truth": "5",
142
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
143
- },
144
- {
145
- "name": "Probability mass around 5",
146
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
147
- "ground_truth": "6",
148
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
149
- },
150
- {
151
- "name": "Probability mass around 5",
152
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
153
- "ground_truth": "7",
154
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
155
- },
156
- {
157
- "name": "Probability mass around 5",
158
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
159
- "ground_truth": "8",
160
- "explanation": "Cross Entropy is high, NTL is higher but still penalizes less than CE because distribution knows it's a number."
161
- },
162
- {
163
- "name": "Probability mass around 5",
164
- "values": [0.05, 0.05, 0.05, 0.1, 0.2, 0.3, 0.15, 0.05, 0.03, 0.02, 0.0], # 11 values
165
- "ground_truth": "9",
166
- "explanation": "Cross Entropy is moderate, NTL is low because predictions are close to ground truth."
167
- },
168
-
169
- {
170
- "name": "Probability mass concentrated on 5",
171
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
172
- "ground_truth": "0",
173
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
174
- },
175
- {
176
- "name": "Probability mass concentrated on 5",
177
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
178
- "ground_truth": "1",
179
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
180
- },
181
- {
182
- "name": "Probability mass concentrated on 5",
183
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
184
- "ground_truth": "2",
185
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
186
- },
187
- {
188
- "name": "Probability mass concentrated on 5",
189
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
190
- "ground_truth": "3",
191
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
192
- },
193
- {
194
- "name": "Probability mass concentrated on 5",
195
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
196
- "ground_truth": "4",
197
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
198
- },
199
- {
200
- "name": "Probability mass concentrated on 5",
201
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
202
- "ground_truth": "5",
203
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
204
- },
205
- {
206
- "name": "Probability mass concentrated on 5",
207
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
208
- "ground_truth": "6",
209
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
210
- },
211
- {
212
- "name": "Probability mass concentrated on 5",
213
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
214
- "ground_truth": "7",
215
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
216
- },
217
- {
218
- "name": "Probability mass concentrated on 5",
219
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
220
- "ground_truth": "8",
221
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
222
- },
223
- {
224
- "name": "Probability mass concentrated on 5",
225
- "values": [0.05, 0.05, 0.05, 0.05, 0.05, 0.3, 0.2, 0.15, 0.05, 0.05, 0.0], # 11 values
226
- "ground_truth": "9",
227
- "explanation": "Both CE and NTL are high because the prediction is far from correct."
228
- },
229
-
230
-
231
- {
232
- "name": "Probability mass concentrated on 1",
233
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
234
- "ground_truth": "0",
235
- "explanation": "Both losses are low because the prediction is correct."
236
- },
237
- {
238
- "name": "Probability mass concentrated on 1",
239
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
240
- "ground_truth": "1",
241
- "explanation": "Both losses are low because the prediction is correct."
242
- },
243
- {
244
- "name": "Probability mass concentrated on 1",
245
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
246
- "ground_truth": "2",
247
- "explanation": "Both losses are low because the prediction is correct."
248
- },
249
- {
250
- "name": "Probability mass concentrated on 1",
251
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
252
- "ground_truth": "3",
253
- "explanation": "Both losses are low because the prediction is correct."
254
- },
255
- {
256
- "name": "Probability mass concentrated on 1",
257
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
258
- "ground_truth": "4",
259
- "explanation": "Both losses are low because the prediction is correct."
260
- },
261
- {
262
- "name": "Probability mass concentrated on 1",
263
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
264
- "ground_truth": "5",
265
- "explanation": "Both losses are low because the prediction is correct."
266
- },
267
- {
268
- "name": "Probability mass concentrated on 1",
269
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
270
- "ground_truth": "6",
271
- "explanation": "Both losses are low because the prediction is correct."
272
- },
273
- {
274
- "name": "Probability mass concentrated on 1",
275
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
276
- "ground_truth": "7",
277
- "explanation": "Both losses are low because the prediction is correct."
278
- },
279
- {
280
- "name": "Probability mass concentrated on 1",
281
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
282
- "ground_truth": "8",
283
- "explanation": "Both losses are low because the prediction is correct."
284
- },
285
- {
286
- "name": "Probability mass concentrated on 1",
287
- "values": [0.05, 0.7, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.0], # 11 values
288
- "ground_truth": "9",
289
- "explanation": "Both losses are low because the prediction is correct."
290
- },
291
-
292
-
293
- {
294
- "name": "Almost correct (1 vs 2)",
295
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
296
- "ground_truth": "0",
297
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
298
- },
299
- {
300
- "name": "Almost correct (1 vs 2)",
301
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
302
- "ground_truth": "1",
303
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
304
- },
305
- {
306
- "name": "Almost correct (1 vs 2)",
307
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
308
- "ground_truth": "2",
309
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
310
- },
311
- {
312
- "name": "Almost correct (1 vs 2)",
313
- "values": [0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 11 values
314
- "ground_truth": "3",
315
- "explanation": "CE penalizes harshly, but NTL-WAS remains low because prediction is numerically close."
316
- }
317
- ]
318
 
319
- # --- Helper Functions ---
 
 
 
 
320
  def apply_scenario(step_idx):
321
- scenario = scenarios[step_idx]
322
- # These assignments modify session state. They must be done *before* the widgets
323
- # are rendered in the script run that should display these new values.
324
  for i, val in enumerate(scenario["values"]):
325
  st.session_state[f"slider_{i}"] = val
326
- st.session_state['ground_truth'] = scenario["ground_truth"]
327
 
328
- def start_demo():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  st.session_state.running_demo = True
330
  st.session_state.demo_step = 0
331
  st.session_state.last_update_time = time.time()
332
- apply_scenario(0) # Apply the first scenario's state
333
- # The button click that calls start_demo() will itself cause a rerun.
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
  def stop_demo():
336
  st.session_state.running_demo = False
337
 
 
338
  # --- Demo State Advancement Logic ---
339
  # This block handles advancing the demo. If it advances, it updates session state
340
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
341
  if st.session_state.running_demo:
 
342
  current_time = time.time()
343
- if current_time - st.session_state.last_update_time > 3.0: # 3 seconds per scenario
344
- next_step = (st.session_state.demo_step + 1) % len(scenarios)
345
- st.session_state.demo_step = next_step
346
- apply_scenario(next_step) # Update session state for the new scenario
347
- st.session_state.last_update_time = time.time() # Reset timer
348
- st.rerun() # Crucial: Rerun to reflect changes in widgets and charts
 
 
 
 
349
 
350
  # --- UI Rendering ---
351
  # This section renders the main UI. It executes after any potential rerun from the block above.
352
 
353
  if st.session_state.running_demo:
354
- st.info(f"Showing scenario {st.session_state.demo_step + 1}/{len(scenarios)}: {scenarios[st.session_state.demo_step]['name']}")
355
- st.markdown(f"**Explanation:** {scenarios[st.session_state.demo_step]['explanation']}")
 
 
 
356
  if st.button("Stop Demo"):
357
- stop_demo()
358
  st.rerun()
359
- else: # Not st.session_state.running_demo
360
- if st.button("Start Automated Demo"):
361
- start_demo() # This calls apply_scenario(0)
362
- st.rerun() # Rerun to enter demo mode and draw scenario 0 correctly
363
-
364
- # Sliders and Ground Truth Selector
365
- # These widgets will read their initial values from st.session_state.
366
- # User interactions will update st.session_state directly due to their keys.
367
- if not st.session_state.running_demo:
368
- st.markdown("#### Predicted Token Probabilities")
369
- cols = st.columns(len(options))
370
- for i, col in enumerate(cols):
371
- label = options[i] # Use token name directly for label
372
- with col:
373
- svs.vertical_slider(
374
- label=label, min_value=0.0, max_value=1.0, step=0.01, height=50,
375
- key=f"slider_{i}", # This key links the widget to st.session_state[f"slider_{i}"]
376
- slider_color="green", track_color="lightgray", thumb_color="black"
377
- )
378
-
379
- # Ground truth selectbox
380
- st.selectbox(
381
- "Ground Truth Token", options=options,
382
- index=options.index(st.session_state['ground_truth']), # Display value from session state
383
- key='ground_truth' # Links widget to st.session_state['ground_truth']
384
- )
385
-
386
- # Placeholder for charts and loss calculations that will be updated
387
- # This section always reads the current st.session_state to generate its content.
388
-
389
- current_prob_values_from_state = [st.session_state.get(f"slider_{j}", 1.0/len(options)) for j in range(len(options))]
390
  total_from_state = sum(current_prob_values_from_state)
391
  probs_for_charts = (
392
  torch.ones(len(options)) / len(options)
@@ -394,112 +200,265 @@ probs_for_charts = (
394
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
395
  )
396
 
397
- gt_choice_for_charts = st.session_state.get('ground_truth', options[0])
 
 
 
 
 
398
  if gt_choice_for_charts == "Text":
399
- gt_index_for_charts = 10 # Assuming "Text" is the 11th item (index 10)
400
  gt_numeric_for_charts = None
401
  else:
402
  gt_index_for_charts = int(gt_choice_for_charts)
403
  gt_numeric_for_charts = gt_index_for_charts
404
 
405
- st.markdown("#### Input Probability Distribution")
406
- df_dist = pd.DataFrame({"token": options, "probability": probs_for_charts.numpy()})
407
- df_dist["type"] = ["Ground Truth" if token == gt_choice_for_charts else "Prediction" for token in options]
408
- chart = (
409
- alt.Chart(df_dist).mark_bar().encode(
410
- x=alt.X("token:N", title="Token", sort=options), # Ensure consistent sort order
411
- y=alt.Y("probability:Q", title="Probability", scale=alt.Scale(domain=[0, 1])),
412
- color=alt.Color("type:N", scale=alt.Scale(domain=["Ground Truth", "Prediction"], range=["green", "steelblue"]), legend=alt.Legend(title="Token Type"))
413
- ).properties(height=300)
414
- )
415
- st.altair_chart(chart, use_container_width=True)
416
-
417
- ce_loss = -torch.log(torch.clamp(probs_for_charts[gt_index_for_charts], min=1e-9))
418
- if gt_numeric_for_charts is None: # Text token
419
- ntl_mse_loss = torch.tensor(float('nan')) # MSE not applicable for text
420
- ntl_was_loss = torch.tensor(float('nan')) # WAS not applicable for text
421
- else: # Numeric token
422
- numeric_probs_for_loss = probs_for_charts[:10] # Probabilities for 0-9
423
- # Ensure numeric_probs_for_loss sums to 1 for NTL calculations if it's a subset
424
- numeric_probs_sum = torch.sum(numeric_probs_for_loss)
425
- if numeric_probs_sum > 1e-6 : # Avoid division by zero
426
- normalized_numeric_probs = numeric_probs_for_loss / numeric_probs_sum
427
- else:
428
- normalized_numeric_probs = torch.zeros_like(numeric_probs_for_loss)
429
 
 
430
 
431
- loss_values_tensor = torch.arange(0, 10, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
- # Use normalized probabilities for NTL if only considering numeric tokens
434
- if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6 :
435
- pred_value = torch.sum( (probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * loss_values_tensor)
436
- elif gt_choice_for_charts != "Text": # if sum is zero, pred_value is ill-defined or 0
437
- pred_value = torch.tensor(0.0)
438
- else: # Should not happen if gt_numeric_for_charts is not None
439
- pred_value = torch.tensor(float('nan'))
440
 
441
 
442
- if not torch.isnan(pred_value):
443
- ntl_mse_loss = (pred_value - float(gt_numeric_for_charts)) ** 2
444
- abs_diff = torch.abs(loss_values_tensor - float(gt_numeric_for_charts))
445
- if gt_choice_for_charts != "Text" and torch.sum(probs_for_charts[:10]) > 1e-6:
446
- ntl_was_loss = torch.sum((probs_for_charts[:10]/torch.sum(probs_for_charts[:10])) * abs_diff)
447
- elif gt_choice_for_charts != "Text":
448
- ntl_was_loss = torch.tensor(0.0) # Or some other default if all numeric probs are zero
449
- else:
450
- ntl_was_loss = torch.tensor(float('nan'))
451
- else:
452
- ntl_mse_loss = torch.tensor(float('nan'))
453
- ntl_was_loss = torch.tensor(float('nan'))
454
 
455
 
456
- ce_val = round(ce_loss.item(), 3)
457
- mse_val = round(ntl_mse_loss.item(), 3) if not torch.isnan(ntl_mse_loss) else "N/A"
458
- was_val = round(ntl_was_loss.item(), 3) if not torch.isnan(ntl_was_loss) else "N/A"
 
 
 
 
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
  loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
462
  if was_val != "N/A":
463
  loss_data["Loss"].append("NTL-WAS")
464
  loss_data["Value"].append(was_val)
465
- if mse_val != "N/A":
466
- loss_data["Loss"].append("NTL-MSE")
467
- loss_data["Value"].append(mse_val)
468
 
469
  loss_df = pd.DataFrame(loss_data)
470
 
 
 
 
 
 
 
 
 
471
  # ============== Chart Display ==============
472
- # Create a single chart for loss visualization
473
- st.subheader("Loss Comparison")
474
-
475
- # Create an Altair chart that will look good and redraw cleanly
476
- chart = alt.Chart(loss_df).mark_bar().encode(
477
- x=alt.X('Loss:N', sort=loss_df["Loss"].tolist()),
478
- y=alt.Y('Value:Q', scale=alt.Scale(domain=[0, max(loss_df["Value"].max() * 1.2, 20 if st.session_state.running_demo else 0.5)])),
479
- color=alt.Color('Loss:N', scale=alt.Scale(
480
- domain=['Cross Entropy', 'NTL-WAS', 'NTL-MSE'],
481
- range=['steelblue', 'red', 'forestgreen']
482
- )),
483
- tooltip=['Loss', 'Value']
484
- ).properties(
485
- height=300
486
- )
487
 
488
- # Add value labels on top of bars
489
- text = chart.mark_text(
490
- align='center',
491
- baseline='bottom',
492
- dy=-5,
493
- fontSize=14
494
- ).encode(
495
- text=alt.Text('Value:Q', format='.3f')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  )
 
497
 
498
- # Combine chart and text
499
- final_chart = (chart + text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
  # Display chart with the full container width
502
- st.altair_chart(final_chart, use_container_width=True)
503
 
504
  # --- Polling Rerun for Demo Mode ---
505
  # If the demo is running and we haven't just advanced (which would have caused a rerun),
@@ -507,20 +466,21 @@ st.altair_chart(final_chart, use_container_width=True)
507
  if st.session_state.running_demo:
508
  # This check is implicitly: if we are here and demo is running, it means
509
  # the time-based advance condition was NOT met in the block at the top.
510
- time.sleep(0.1) # Adjusted from 0.2 to 0.5 (or try 1.0)
511
  st.rerun()
512
 
513
- # Add explanation of the demonstration
514
  st.markdown("""
515
- ### What Does This Demo Show?
 
 
516
 
517
- - **Cross Entropy Loss**: Only cares if the prediction is exactly right or wrong - it doesn't consider how "close" a numerical prediction is.
518
- - **Number Token Loss (NTL)**: Considers numerical proximity - predicting "7" when the true value is "8" is better than predicting "2".
519
  """)
520
 
521
- # References / resources section with links (common to both modes)
522
- st.markdown("### Resources")
523
  st.markdown("""
524
- - [Paper: Number Token Loss (ArXiv)](https://arxiv.org/abs/2411.02083)
525
- - [GitHub: Number Token Loss](https://github.com/tum-ai/number-token-loss)
 
526
  """)
 
1
+ import logging
2
+ import time
3
+
4
  import altair as alt
5
+ import numpy as np
6
  import pandas as pd
7
+ import streamlit as st
8
  import streamlit_vertical_slider as svs
9
  import torch
10
+
11
+ from scenarios import dirac, gauss, make_bimodal_scenarios
12
+
13
+ logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
14
+
15
+ DEMO_INTERVAL = 0.75
16
+ CE_SCALING = 0.25
17
+ MAX_LOSS_PLOT = 6
18
+ LAST_STEP = -1
19
+
20
 
21
  # Define options globally as it's used in initialization and UI
22
  options = [str(i) for i in range(10)] + ["Text"]
23
 
24
+
25
+ def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
26
+ """Compute CE, NTL-MAE, NTL-WAS losses for the given probability vector and ground truth token."""
27
+ ce_loss = CE_SCALING * -torch.log(
28
+ torch.clamp(probs[options.index(gt_token)], min=1e-9)
29
+ )
30
+
31
+ numeric_mass = probs[:10].sum()
32
+
33
+ if gt_token == "Text" or numeric_mass < 1e-6:
34
+ return ce_loss.item(), 0.0, 0.0
35
+
36
+ gt_numeric = int(gt_token)
37
+ token_vals = torch.arange(10, dtype=torch.float32)
38
+ mae = numeric_mass * abs(torch.dot(token_vals, probs[:10]) - gt_numeric)
39
+ was = numeric_mass * torch.dot(probs[:10], torch.abs(token_vals - gt_numeric))
40
+ return round(ce_loss.item(), 3), round(mae.item(), 3), round(was.item(), 3)
41
+
42
+
43
  # --- Session State Initialization ---
44
  # Ensure all session state variables are initialized before first use, especially by widgets.
45
+ if "running_demo" not in st.session_state:
46
  st.session_state.running_demo = False
47
+ if "demo_step" not in st.session_state:
48
  st.session_state.demo_step = 0
49
+ if "last_update_time" not in st.session_state:
50
  st.session_state.last_update_time = 0
51
+ if "loss_container" not in st.session_state:
52
  st.session_state.loss_container = None
53
+ if "previous_chart_html" not in st.session_state:
54
  st.session_state.previous_chart_html = ""
55
+ if "active_scenarios" not in st.session_state:
56
+ # default if you want one to load on first show
57
+ st.session_state.active_scenarios = dirac
58
+ if "loss_history" not in st.session_state:
59
+ st.session_state.loss_history = []
60
+ if "df_loss_plot" not in st.session_state:
61
+ # Initialize an empty DataFrame for loss history
62
+ st.session_state.df_loss_plot = pd.DataFrame(
63
+ columns=["step", "x_val", "Loss Type", "Loss Value"]
64
+ )
65
+
66
 
67
  # Initialize states for sliders and ground_truth selector
68
  # Using len(options) to correctly size for 0-9 + "Text"
69
  for i in range(len(options)):
70
  if f"slider_{i}" not in st.session_state:
71
+ st.session_state[f"slider_{i}"] = 0
72
+ if "ground_truth" not in st.session_state:
73
+ st.session_state["ground_truth"] = options[5]
74
+ if "manual_ground_truth" not in st.session_state:
75
+ st.session_state["manual_ground_truth"] = options[5]
76
+ if "demo_name" not in st.session_state:
77
+ st.session_state["demo_name"] = "Dirac"
78
+
79
 
80
+ st.title("NTL -- The Number Token Loss 🚀")
81
+
82
+ st.markdown(
83
+ """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!🎉
84
+ ➡️ NTL augments cross-entropy to help LMs reason better with numbers 🧠
85
+ """
86
+ )
87
 
88
+ st.subheader("Demo 1 — NTL vs. Cross Entropy in 3 Scenarios")
89
 
90
  st.markdown("""
91
+ 1๏ธโƒฃ Pick a ground truth token: a digit (0โ€“9) or "Text" ๐Ÿ“ (simulates generic text tokens).
92
+ 2๏ธโƒฃ Choose a demo:
93
+ - **Dirac** ⚡: All probability mass on one token.
94
+ - **Gaussian** 🌊: Soft bell-curve around the true number.
95
+ - **Bimodal** 🎯: Two peaks moving away from the target.
96
+
97
+ Watch how losses evolve as predictions get worse — and see how NTL shines compared to CE! 🌟
98
  """)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ if "ground_truth" not in st.session_state:
102
+ st.session_state["ground_truth"] = "4"
103
+ gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
104
+
105
+
106
  def apply_scenario(step_idx):
107
+ scenario = st.session_state.active_scenarios[step_idx]
 
 
108
  for i, val in enumerate(scenario["values"]):
109
  st.session_state[f"slider_{i}"] = val
 
110
 
111
+
112
+ def start_dirac_demo():
113
+ st.session_state.loss_history = []
114
+ st.session_state.active_scenarios = dirac
115
+ st.session_state.demo_name = "Dirac"
116
+ st.session_state.running_demo = True
117
+ st.session_state.demo_step = 0
118
+ st.session_state.last_update_time = time.time()
119
+ apply_scenario(0)
120
+
121
+
122
+ def start_gauss_demo():
123
+ st.session_state.loss_history = []
124
+ st.session_state.active_scenarios = gauss
125
+ st.session_state.demo_name = "Gauss"
126
  st.session_state.running_demo = True
127
  st.session_state.demo_step = 0
128
  st.session_state.last_update_time = time.time()
129
+ apply_scenario(0)
130
+
131
+
132
+ def start_bimodal_demo():
133
+ st.session_state.loss_history = []
134
+ gt = st.session_state["ground_truth"]
135
+ st.session_state.active_scenarios = make_bimodal_scenarios(gt, options)
136
+
137
+ st.session_state.demo_name = f"Bimodal (GT={gt})"
138
+ st.session_state.running_demo = True
139
+ st.session_state.demo_step = 0
140
+ st.session_state.last_update_time = time.time()
141
+ apply_scenario(0)
142
+
143
 
144
  def stop_demo():
145
  st.session_state.running_demo = False
146
 
147
+
148
  # --- Demo State Advancement Logic ---
149
  # This block handles advancing the demo. If it advances, it updates session state
150
  # and then reruns. This ensures widgets are drawn with the new state in the next run.
151
  if st.session_state.running_demo:
152
+ scenario = st.session_state.active_scenarios
153
  current_time = time.time()
154
+ if current_time - st.session_state.last_update_time > DEMO_INTERVAL:
155
+ # if we haven’t yet shown the last scenario, advance
156
+ if st.session_state.demo_step < len(scenario) - 1:
157
+ st.session_state.demo_step += 1
158
+ apply_scenario(st.session_state.demo_step)
159
+ st.session_state.last_update_time = current_time
160
+ # st.rerun() # not needed, leading to too many reruns
161
+ else:
162
+ # we just displayed the final case → stop
163
+ st.session_state.running_demo = False
164
 
165
# --- UI Rendering ---
# Demo controls: while a demo is running show a status banner and a stop
# button; otherwise show one launch button per demo type, side by side.
if st.session_state.running_demo:
    st.info(
        f"Showing scenario {st.session_state.demo_step + 1}"
        f"/{len(st.session_state.active_scenarios)}: "
        f"{st.session_state.active_scenarios[st.session_state.demo_step]['name']}"
    )
    if st.button("Stop Demo"):
        # Use the shared helper rather than duplicating its body here.
        stop_demo()
        st.rerun()
else:
    demo_launchers = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    launcher_columns = st.columns(len(demo_launchers))
    for launcher_column, (button_label, launch_demo) in zip(
        launcher_columns, demo_launchers
    ):
        with launcher_column:
            if st.button(button_label):
                # Prepare the session state for the chosen demo, then rerun
                # so the page is drawn with that state.
                launch_demo()
                st.rerun()
191
+
192
# Raw (un-normalised) slider values as stored in the session state, one per
# entry of `options`; a slider key that is absent counts as 0.
current_prob_values_from_state = [
    st.session_state.get(f"slider_{idx}", 0) for idx, _ in enumerate(options)
]
 
 
 
 
 
 
 
 
 
 
 
 
196
  total_from_state = sum(current_prob_values_from_state)
197
  probs_for_charts = (
198
  torch.ones(len(options)) / len(options)
 
200
  else torch.tensor([v / total_from_state for v in current_prob_values_from_state])
201
  )
202
 
203
# Ground truth used for the loss computation: the demo's own ground truth
# while a demo is running, otherwise the manually selected token.
if st.session_state.running_demo:
    gt_choice_for_charts = st.session_state["ground_truth"]
else:
    gt_choice_for_charts = st.session_state["manual_ground_truth"]

if gt_choice_for_charts == "Text":
    # Look the position up instead of hard-coding 10, so it stays correct
    # if `options` ever changes.
    gt_index_for_charts = options.index("Text")
    gt_numeric_for_charts = None
else:
    gt_index_for_charts = int(gt_choice_for_charts)
    gt_numeric_for_charts = gt_index_for_charts

# NOTE(review): `gt` always comes from the demo ground truth, even when the
# losses above use the manual one -- confirm this is intended.
gt = st.session_state["ground_truth"]
demo_name = st.session_state["demo_name"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
# Heading and bar chart of the predicted distribution; the bar whose token
# equals `gt` is drawn in dark green, all others in blue.
st.markdown(
    f'#### Predicted distribution (<span style="color:darkgreen;">ground truth: {gt}</span>)',
    unsafe_allow_html=True,
)

df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
).assign(is_gt=lambda frame: frame["token"] == gt)

dist_x = alt.X(
    "token:N",
    title="Token",
    sort=options,
    axis=alt.Axis(
        labelAngle=0,
        labelFontSize=14,
        titleFontSize=16,
        labelAlign="center",
        labelFlush=False,
    ),
)
dist_y = alt.Y(
    "probability:Q",
    title="Probability",
    scale=alt.Scale(domain=[0, 1]),
    axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
)
dist_color = alt.condition(
    "datum.is_gt",
    alt.value("darkgreen"),  # ground-truth bar
    alt.value("dodgerblue"),  # every other bar
)
dist_tooltip = [
    alt.Tooltip("token:N", title="Token"),
    alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
    alt.Tooltip("is_gt:N", title="Ground Truth"),
]

bars = (
    alt.Chart(df_dist)
    .mark_bar(color="dodgerblue", size=40)
    .encode(x=dist_x, color=dist_color, y=dist_y, tooltip=dist_tooltip)
)

st.altair_chart(
    bars.properties(height=200), use_container_width=True, theme="streamlit"
)
 
 
 
 
 
 
262
 
263
 
264
# Losses for the distribution and ground truth currently on display.
ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)

# While a demo is running, record one history entry per demo step: the guard
# skips the append when the history already holds an entry for this step.
if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    step = st.session_state.demo_step
    scenario = st.session_state.active_scenarios[step]

    # x-axis label of this entry: the scenario name for the bimodal demo,
    # otherwise the token with the largest value in the scenario.
    if st.session_state.demo_name.startswith("Bimodal"):
        x_val = scenario["name"]
    else:
        best_idx = np.argmax(scenario["values"])
        x_val = options[best_idx]  # "0", "1", ..., or "Text"

    # Reuse the losses computed above instead of calling compute_losses a
    # second time with identical arguments.
    # NOTE(review): assumes compute_losses is deterministic -- confirm.
    st.session_state.loss_history.append(
        {
            "step": step,
            "x_val": x_val,
            "Cross Entropy": ce_val,
            "NTL-MAE": mae_val,
            "NTL-WAS": was_val,
        }
    )
    # Long format (one row per step and loss type) for the grouped bar chart.
    st.session_state.df_loss_plot = pd.DataFrame(
        st.session_state.loss_history
    ).melt(
        id_vars=["step", "x_val"],
        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        var_name="Loss Type",
        value_name="Loss Value",
    )
296
 
297
# Table of the current loss values; Cross Entropy is always listed, the two
# NTL variants only when their value is not the "N/A" marker.
loss_data = {"Loss": ["Cross Entropy"], "Value": [ce_val]}
for loss_name, loss_value in (("NTL-WAS", was_val), ("NTL-MAE", mae_val)):
    if loss_value != "N/A":
        loss_data["Loss"].append(loss_name)
        loss_data["Value"].append(loss_value)

loss_df = pd.DataFrame(loss_data)

# x-axis categories and title of the loss-history chart: scenario names for
# the bimodal demo, the plain token list for every other demo.
is_bimodal_demo = st.session_state.demo_name.startswith("Bimodal")
domain = (
    [sc["name"] for sc in st.session_state.active_scenarios]
    if is_bimodal_demo
    else options
)
x_title = (
    f"Offset from GT {st.session_state['ground_truth']}"
    if is_bimodal_demo
    else f"Maximum of predicted {st.session_state['demo_name']} distribution"
)
313
+
314
+
315
  # ============== Chart Display ==============
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
+
318
st.markdown("#### Loss as a function of predicted distribution")

# Grouped bar chart of the recorded loss history: one group per x value,
# one bar per loss type.
history_axis_fonts = {"labelFontSize": 14, "titleFontSize": 16}
history_x = alt.X(
    "x_val:N",
    title=x_title,
    sort=domain,
    scale=alt.Scale(domain=domain),
    axis=alt.Axis(labelAngle=0, **history_axis_fonts),
)
history_y = alt.Y(
    "Loss Value:Q",
    title="Loss Value",
    # Fixed range; values above MAX_LOSS_PLOT are clamped to the axis.
    scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
    axis=alt.Axis(**history_axis_fonts),
)
history_color = alt.Color(
    "Loss Type:N",
    scale=alt.Scale(
        domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        range=["red", "limegreen", "blueviolet"],
    ),
    legend=alt.Legend(
        title="",
        orient="top",
        direction="horizontal",
        columns=3,
    ),
)
history_tooltip = [
    alt.Tooltip("x_val:N", title="Scenario"),
    alt.Tooltip("Loss Type:N", title="Loss Type"),
    alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
]

grouped_chart = (
    alt.Chart(st.session_state.df_loss_plot)
    .mark_bar()
    .encode(
        x=history_x,
        y=history_y,
        color=history_color,
        xOffset="Loss Type:N",  # places the loss types side by side
        tooltip=history_tooltip,
    )
    .properties(height=250)
)
st.altair_chart(grouped_chart, use_container_width=True, theme="streamlit")
360
 
361
+
362
# When no demo is running, set every slider value in the session state to 0.0
# and reset the demo step counter.
# NOTE(review): the indentation of this region was lost in this view;
# presumably only the reset below belongs to this `if` -- confirm against the
# real file.
if not st.session_state.running_demo:
    for i in range(len(options)):
        st.session_state[f"slider_{i}"] = 0.0
    st.session_state.demo_step = 0
367
+
368
# Heading and introduction of the manual demo.
# NOTE(review): two consecutive subheaders both announce "Demo 2"; the first
# looks like a leftover -- confirm before removing it.
st.subheader("Demo 2 -- Manual loss comparison")
st.subheader("🧪 Demo 2 — Craft your own distribution")
st.markdown("""
This demo gives you more control but is harder to interpret. See it as a playground! 🎨
Manually adjust the sliders to change the predicted probabilities for each token.
The demo normalizes the values to form a valid probability distribution and calculates the losses.

👣 **Steps:**
- Use the **vertical sliders** to allocate probability to each token.
- Choose the correct **Ground Truth Token** (0–9 or "Text" 📜).
- Observe how each loss function reacts.

💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it! 😈
""")
382
+
383
# Ground-truth selector of the manual demo; its value is stored in the
# session state under the key "manual_ground_truth".
manual_gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="manual_ground_truth",
)

# Table of the current losses for the bar chart. Entries carrying the "N/A"
# marker are left out, as in the earlier loss table, because a string in the
# "Value" column would break the numeric axis computation of the chart.
# NOTE(review): assumes compute_losses reports an undefined loss as the
# string "N/A", as the checks on was_val / mae_val suggest -- confirm.
manual_loss_rows = [
    (loss_name, loss_value)
    for loss_name, loss_value in (
        ("Cross Entropy", ce_val),
        ("NTL-MAE", mae_val),
        ("NTL-WAS", was_val),
    )
    if not (isinstance(loss_value, str) and loss_value == "N/A")
]
loss_df = pd.DataFrame(
    {
        "Loss": [loss_name for loss_name, _ in manual_loss_rows],
        "Value": [loss_value for _, loss_value in manual_loss_rows],
    }
)
394
+ )
395
+
396
# One vertical slider per token, laid out side by side. Each slider is keyed
# "slider_<index>", the same keys under which the values are read from the
# session state.
st.markdown("#### Adjust the predicted token probability")

# Settings shared by every slider.
slider_settings = dict(
    min_value=0.0,
    max_value=1.0,
    step=0.01,
    height=50,
    slider_color="green",
    track_color="lightgray",
    thumb_color="black",
)

cols = st.columns(len(options))
for i, (col, label) in enumerate(zip(cols, options)):
    # The token name itself serves as the slider label.
    with col:
        svs.vertical_slider(label=label, key=f"slider_{i}", **slider_settings)
415
+
416
# Bar chart of the current losses with the numeric value printed above each bar.
# Upper end of the y-axis: 20% headroom over the largest value, but never
# below 20 while a demo is running or 0.5 otherwise.
y_axis_minimum_top = 20 if st.session_state.running_demo else 0.5
y_axis_top = max(loss_df["Value"].max() * 1.2, y_axis_minimum_top)

loss_bar_colors = alt.Scale(
    domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
    range=["orangered", "limegreen", "blueviolet"],
)

chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, y_axis_top])),
        color=alt.Color("Loss:N", scale=loss_bar_colors),
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)

# Text layer derived from the bar chart: the value, three decimals, placed
# just above the bar.
text = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))

final_chart = chart + text
st.altair_chart(final_chart, use_container_width=True)
450
+
451
+
452
+ # # Add value labels on top of bars
453
+ # text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
454
+ # text=alt.Text("Value:Q", format=".3f")
455
+ # )
456
+
457
+ # # Combine chart and text
458
+ # final_chart = chart + text
459
 
460
  # Display chart with the full container width
461
+ # st.altair_chart(final_chart, use_container_width=True)
462
 
463
  # --- Polling Rerun for Demo Mode ---
464
  # If the demo is running and we haven't just advanced (which would have caused a rerun),
 
466
if st.session_state.running_demo:
    # The page has been fully drawn for the current demo step. Wait one demo
    # interval, then rerun the script so the advancement logic near the top
    # can move on to the next scenario (or stop the demo).
    time.sleep(DEMO_INTERVAL)
    st.rerun()
471
 
472
+
473
# Closing summary shown below the demos.
st.markdown("""
### 🤔 TL;DR — Why NTL?
Cross Entropy only cares if the prediction is exactly right or wrong ❌✅ — it doesn’t care *how close* a guess is!
That’s bad for LLMs doing math and numeric reasoning 🧮.

💥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
""")
480
 
481
# Links to the paper, project page and code.
st.markdown("#### 📚 Further Resources")
st.markdown("""
- 📄 [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
- 🌐 [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
- 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
""")