NumberTokenLoss / src /streamlit_app.py
jannisborn's picture
update
2854521 unverified
import logging
import time
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_vertical_slider as svs
import torch
from scenarios import dirac, gauss, make_bimodal_scenarios
# Raise the level of Streamlit's local-sources watcher logger to ERROR.
logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
# Seconds between scenario steps; also the sleep before each polling rerun.
DEMO_INTERVAL = 0.75
# Multiplier applied to the cross-entropy value inside compute_losses.
CE_SCALING = 0.25
# Upper bound of the y-axis of the loss-history chart (the scale clamps to it).
MAX_LOSS_PLOT = 6
# NOTE(review): not referenced anywhere in the visible code — confirm before removing.
LAST_STEP = -1
# Define options globally as it's used in initialization and UI
options = [str(i) for i in range(10)] + ["Text"]
def compute_losses(probs: torch.Tensor, gt_token: str) -> tuple[float, float, float]:
    """Compute the CE, NTL-MAE and NTL-WAS losses for one predicted distribution.

    Args:
        probs: 1-D probability vector aligned with the module-level ``options``
            list; its first ten entries are read as the digit tokens 0-9.
        gt_token: Ground-truth token; must be an element of ``options``.

    Returns:
        ``(ce, mae, was)`` as Python floats, each rounded to three decimals.
        ``mae`` and ``was`` are ``0.0`` when the ground truth is ``"Text"`` or
        when the digit tokens carry (almost) no probability mass.

    Raises:
        ValueError: If ``gt_token`` is not in ``options``.
    """
    # Clamp before the log so a zero probability yields a large finite loss.
    ce_loss = CE_SCALING * -torch.log(
        torch.clamp(probs[options.index(gt_token)], min=1e-9)
    )
    # Round once so every return path reports CE with the same precision.
    ce = round(ce_loss.item(), 3)

    digit_probs = probs[:10]
    numeric_mass = digit_probs.sum()
    if gt_token == "Text" or numeric_mass < 1e-6:
        return ce, 0.0, 0.0

    gt_numeric = int(gt_token)
    # Match dtype/device of the input so torch.dot accepts non-float32 tensors.
    token_vals = torch.arange(10, dtype=probs.dtype, device=probs.device)
    # Distance between the expected digit value and the target, scaled by the digit mass.
    mae = numeric_mass * abs(torch.dot(token_vals, digit_probs) - gt_numeric)
    # Probability-weighted distance of each digit to the target, scaled by the digit mass.
    was = numeric_mass * torch.dot(digit_probs, torch.abs(token_vals - gt_numeric))
    return ce, round(mae.item(), 3), round(was.item(), 3)
# --- Session State Initialization ---
# Seed every session-state key before any widget or callback reads it. Keys that
# already exist (i.e. on reruns) are left untouched.
_session_defaults = {
    "running_demo": False,
    "demo_step": 0,
    "last_update_time": 0,
    "loss_container": None,
    "previous_chart_html": "",
    # scenario set that is loaded on first show
    "active_scenarios": dirac,
    "loss_history": [],
    # one slider per entry of `options` (digits 0-9 plus "Text")
    **{f"slider_{i}": 0 for i in range(len(options))},
    "ground_truth": options[5],
    "manual_ground_truth": options[5],
    "demo_name": "Dirac",
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
if "df_loss_plot" not in st.session_state:
    # Empty long-format frame backing the loss-history chart.
    st.session_state["df_loss_plot"] = pd.DataFrame(
        columns=["step", "x_val", "Loss Type", "Loss Value"]
    )
# Page title and introduction. The emoji and dashes below were mojibake
# (UTF-8 bytes shown through a legacy code page) and are restored here.
st.title("NTL -- The Number Token Loss 🚀")
st.markdown(
    """This is the interactive demo for our [ICML 2025](https://arxiv.org/abs/2411.02083) paper!🎉
➡️ NTL augments cross-entropy to help LMs reason better with numbers 🧠
"""
)
st.subheader("Demo 1 — NTL vs. Cross Entropy in 3 Scenarios")
# NOTE(review): the emoji after "Text" lost a byte in the garbled source;
# 📝 is the most plausible reconstruction — confirm.
st.markdown("""
1️⃣ Pick a ground truth token: a digit (0–9) or "Text" 📝 (simulates generic text tokens).
2️⃣ Choose a demo:
- **Dirac** ⚡: All probability mass on one token.
- **Gaussian** 🌊: Soft bell-curve around the true number.
- **Bimodal** 🎯: Two peaks moving away from the target.
Watch how losses evolve as predictions get worse — and see how NTL shines compared to CE! 🌟
""")
# "ground_truth" is already seeded during session-state initialisation earlier in
# this script, so the former second default ("4") could never take effect and is
# dropped. The selectbox reads and writes that key directly.
gt = st.selectbox("Ground Truth Token", options=options, key="ground_truth")
def apply_scenario(step_idx):
    """Copy the slider values of one scenario step into session state.

    Looks up entry ``step_idx`` of ``st.session_state.active_scenarios`` and
    writes each element of its ``"values"`` sequence to the matching
    ``slider_<position>`` key.
    """
    slider_values = st.session_state.active_scenarios[step_idx]["values"]
    for position, value in enumerate(slider_values):
        st.session_state[f"slider_{position}"] = value
def _start_demo(scenarios, demo_name):
    """Reset the demo state, install ``scenarios`` and show their first step."""
    state = st.session_state
    state.loss_history = []
    state.active_scenarios = scenarios
    state.demo_name = demo_name
    state.running_demo = True
    state.demo_step = 0
    # Start the step timer now so the first scenario stays visible for a full interval.
    state.last_update_time = time.time()
    apply_scenario(0)


def start_dirac_demo():
    """Start the demo that plays the ``dirac`` scenarios."""
    _start_demo(dirac, "Dirac")


def start_gauss_demo():
    """Start the demo that plays the ``gauss`` scenarios."""
    _start_demo(gauss, "Gauss")


def start_bimodal_demo():
    """Start the bimodal demo, built for the currently selected ground truth."""
    gt = st.session_state["ground_truth"]
    _start_demo(make_bimodal_scenarios(gt, options), f"Bimodal (GT={gt})")
def stop_demo():
    """Flag the demo as no longer running."""
    st.session_state["running_demo"] = False
# --- Demo State Advancement Logic ---
# While a demo is running, move on to the next scenario step once more than
# DEMO_INTERVAL seconds have passed since the last update. When the final step
# has already been shown for that long, switch the demo off instead.
if st.session_state.running_demo:
    now = time.time()
    interval_elapsed = now - st.session_state.last_update_time > DEMO_INTERVAL
    if interval_elapsed:
        final_step = len(st.session_state.active_scenarios) - 1
        if st.session_state.demo_step >= final_step:
            # the final case has just been displayed -> stop
            st.session_state.running_demo = False
        else:
            # the last scenario has not been shown yet -> advance
            st.session_state.demo_step += 1
            apply_scenario(st.session_state.demo_step)
            st.session_state.last_update_time = now
            # No st.rerun() here: it is not needed and led to too many reruns.
# --- UI Rendering ---
# During a demo: show progress plus a stop button. Otherwise: offer one launch
# button per demo, laid out in three columns.
if st.session_state.running_demo:
    shown_step = st.session_state.demo_step + 1
    step_count = len(st.session_state.active_scenarios)
    step_name = st.session_state.active_scenarios[st.session_state.demo_step]["name"]
    st.info(f"Showing scenario {shown_step}/{step_count}: {step_name}")
    if st.button("Stop Demo"):
        st.session_state.running_demo = False
        st.rerun()
else:
    demo_launchers = (
        ("Run: Dirac", start_dirac_demo),
        ("Run: Gauss", start_gauss_demo),
        ("Run: Bimodal", start_bimodal_demo),
    )
    for launcher_column, (button_label, launch_demo) in zip(
        st.columns(3), demo_launchers
    ):
        with launcher_column:
            if st.button(button_label):
                launch_demo()
                # Rerun immediately so the page is redrawn in demo mode.
                st.rerun()
# Raw (un-normalised) slider weights, one per entry of `options`.
slider_keys = [f"slider_{j}" for j in range(len(options))]
current_prob_values_from_state = [
    st.session_state.get(slider_key, 0) for slider_key in slider_keys
]
total_from_state = sum(current_prob_values_from_state)
if total_from_state == 0:
    # Nothing allocated: fall back to a uniform distribution.
    probs_for_charts = torch.ones(len(options)) / len(options)
else:
    probs_for_charts = torch.tensor(
        [v / total_from_state for v in current_prob_values_from_state]
    )
# The manual selector decides the ground truth unless a demo is running.
if st.session_state.running_demo:
    gt_choice_for_charts = st.session_state["ground_truth"]
else:
    gt_choice_for_charts = st.session_state["manual_ground_truth"]
if gt_choice_for_charts != "Text":
    gt_index_for_charts = int(gt_choice_for_charts)
    gt_numeric_for_charts = gt_index_for_charts
else:
    gt_index_for_charts = 10  # "Text" is the 11th item of `options` (index 10)
    gt_numeric_for_charts = None
# Bar chart of the current predicted distribution; the ground-truth bar is green.
# NOTE(review): the heading and highlight use the "ground_truth" key even when the
# losses are computed from "manual_ground_truth" — confirm this is intended.
gt = st.session_state["ground_truth"]
demo_name = st.session_state["demo_name"]
st.markdown(
    f'#### Predicted distribution (<span style="color:darkgreen;">ground truth: {gt}</span>)',
    unsafe_allow_html=True,
)
df_dist = pd.DataFrame(
    {"token": options, "probability": probs_for_charts.numpy().round(2)}
)
df_dist["is_gt"] = df_dist["token"] == gt

dist_x = alt.X(
    "token:N",
    title="Token",
    sort=options,
    axis=alt.Axis(
        labelAngle=0,
        labelFontSize=14,
        titleFontSize=16,
        labelAlign="center",
        labelFlush=False,
    ),
)
dist_y = alt.Y(
    "probability:Q",
    title="Probability",
    scale=alt.Scale(domain=[0, 1]),
    axis=alt.Axis(format=".2f", labelFontSize=14, titleFontSize=16),
)
dist_color = alt.condition(
    "datum.is_gt",
    alt.value("darkgreen"),  # ground-truth token
    alt.value("dodgerblue"),  # every other token
)
dist_tooltip = [
    alt.Tooltip("token:N", title="Token"),
    alt.Tooltip("probability:Q", title="Predicted Prob.", format=".2f"),
    alt.Tooltip("is_gt:N", title="Ground Truth"),
]
bars = (
    alt.Chart(df_dist)
    .mark_bar(color="dodgerblue", size=40)
    .encode(x=dist_x, color=dist_color, y=dist_y, tooltip=dist_tooltip)
)
st.altair_chart(bars.properties(height=200), use_container_width=True, theme="streamlit")
# Losses for the distribution currently on screen.
ce_val, mae_val, was_val = compute_losses(probs_for_charts, gt_choice_for_charts)
# While a demo runs, record one history entry per step; the length check keeps
# reruns of the same step from appending duplicates.
if (
    st.session_state.running_demo
    and len(st.session_state.loss_history) < st.session_state.demo_step + 1
):
    step = st.session_state.demo_step
    scenario = st.session_state.active_scenarios[step]
    # pick x_val differently for bimodal vs others
    if st.session_state.demo_name.startswith("Bimodal"):
        x_val = scenario["name"]  # e.g. "(4,4)", "(3,5)", ...
    else:
        # token carrying the largest scenario value: "0", "1", ..., or "Text"
        x_val = options[np.argmax(scenario["values"])]
    # Reuse the losses computed above instead of calling compute_losses a second
    # time with the same arguments.
    st.session_state.loss_history.append(
        {
            "step": step,
            "x_val": x_val,
            "Cross Entropy": ce_val,
            "NTL-MAE": mae_val,
            "NTL-WAS": was_val,
        }
    )
    # Long format (one row per step and loss type) for the grouped bar chart.
    st.session_state.df_loss_plot = pd.DataFrame(st.session_state.loss_history).melt(
        id_vars=["step", "x_val"],
        value_vars=["Cross Entropy", "NTL-MAE", "NTL-WAS"],
        var_name="Loss Type",
        value_name="Loss Value",
    )
# compute_losses only ever returns numbers, so the former `!= "N/A"` guards were
# always true; all three losses are listed unconditionally (same row order).
loss_df = pd.DataFrame(
    {
        "Loss": ["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        "Value": [ce_val, was_val, mae_val],
    }
)
# X-axis categories and title for the loss-history chart.
if st.session_state.demo_name.startswith("Bimodal"):
    # Bimodal steps are labelled by scenario name.
    domain = [sc["name"] for sc in st.session_state.active_scenarios]
    x_title = f"Offset from GT {st.session_state['ground_truth']}"
else:
    # Other demos are labelled by token.
    domain = options
    x_title = f"Maximum of predicted {st.session_state['demo_name']} distribution"
# ============== Chart Display ==============
# Grouped bars: one group per x position, one bar per loss type.
st.markdown("#### Loss as a function of predicted distribution")
history_x = alt.X(
    "x_val:N",
    title=x_title,
    sort=domain,
    scale=alt.Scale(domain=domain),
    axis=alt.Axis(labelAngle=0, labelFontSize=14, titleFontSize=16),
)
history_y = alt.Y(
    "Loss Value:Q",
    title="Loss Value",
    # Fixed axis range; larger values are clamped to MAX_LOSS_PLOT.
    scale=alt.Scale(domain=[0, MAX_LOSS_PLOT], nice=False, clamp=True),
    axis=alt.Axis(labelFontSize=14, titleFontSize=16),
)
history_color = alt.Color(
    "Loss Type:N",
    scale=alt.Scale(
        domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        range=["red", "limegreen", "blueviolet"],
    ),
    legend=alt.Legend(
        title="",
        orient="top",
        direction="horizontal",
        columns=3,
    ),
)
history_tooltip = [
    alt.Tooltip("x_val:N", title="Scenario"),
    alt.Tooltip("Loss Type:N", title="Loss Type"),
    alt.Tooltip("Loss Value:Q", title="Value", format=".3f"),
]
grouped_chart = (
    alt.Chart(st.session_state.df_loss_plot)
    .mark_bar()
    .encode(
        x=history_x,
        y=history_y,
        color=history_color,
        xOffset="Loss Type:N",  # places the loss types side by side
        tooltip=history_tooltip,
    )
    .properties(height=250)
)
st.altair_chart(grouped_chart, use_container_width=True, theme="streamlit")
# Outside a running demo, zero every slider key and rewind the step counter.
if not st.session_state.running_demo:
    for slider_index, _ in enumerate(options):
        st.session_state[f"slider_{slider_index}"] = 0.0
    st.session_state.demo_step = 0
# Demo 2 heading, instructions and its own ground-truth selector. The emoji and
# dashes below were mojibake and are restored here.
# NOTE(review): two consecutive subheaders are rendered for Demo 2 — confirm both
# are intended.
st.subheader("Demo 2 -- Manual loss comparison")
st.subheader("🧪 Demo 2 — Craft your own distribution")
st.markdown("""
This demo gives you more control but is harder to interpret. See it as a playground! 🎨
Manually adjust the sliders to change the predicted probabilities for each token.
The demo normalizes the values to form a valid probability distribution and calculates the losses.
👣 **Steps:**
- Use the **vertical sliders** to allocate probability to each token.
- Choose the correct **Ground Truth Token** (0–9 or "Text" 📜).
- Observe how each loss function reacts.
💡 **Tip:** Want to trick the loss? Try putting all mass on the wrong token or spread it wildly. See how NTL handles it! 😈
""")
# Stored under "manual_ground_truth", which is read when no demo is running.
manual_gt = st.selectbox(
    "Ground Truth Token",
    options=options,
    key="manual_ground_truth",
)
# One row per loss for the manual-comparison bar chart.
loss_names = ["Cross Entropy", "NTL-MAE", "NTL-WAS"]
loss_values = [ce_val, mae_val, was_val]
loss_df = pd.DataFrame({"Loss": loss_names, "Value": loss_values})
# One vertical slider per token, each in its own column. Every slider is keyed
# as `slider_<i>`, so its value lives in st.session_state under that key.
st.markdown("#### Adjust the predicted token probability")
cols = st.columns(len(options))
slider_settings = {
    "min_value": 0.0,
    "max_value": 1.0,
    "step": 0.01,
    "height": 50,
    "slider_color": "green",
    "track_color": "lightgray",
    "thumb_color": "black",
}
for i, (col, label) in enumerate(zip(cols, options)):
    with col:
        # The token name doubles as the slider label.
        svs.vertical_slider(label=label, key=f"slider_{i}", **slider_settings)
# Bar chart of the three current losses with the value printed above each bar.
# The y-axis reaches 20% above the largest loss, but never ends below 20 while
# a demo is running or below 0.5 otherwise.
axis_floor = 20 if st.session_state.running_demo else 0.5
y_upper = max(loss_df["Value"].max() * 1.2, axis_floor)
loss_colors = alt.Color(
    "Loss:N",
    scale=alt.Scale(
        domain=["Cross Entropy", "NTL-WAS", "NTL-MAE"],
        range=["orangered", "limegreen", "blueviolet"],
    ),
)
chart = (
    alt.Chart(loss_df)
    .mark_bar()
    .encode(
        # keep the bars in data-frame order
        x=alt.X("Loss:N", sort=loss_df["Loss"].tolist()),
        y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, y_upper])),
        color=loss_colors,
        tooltip=["Loss", "Value"],
    )
    .properties(height=300)
)
# Value labels sitting just above the bars.
text = chart.mark_text(
    align="center", baseline="bottom", dy=-5, fontSize=14
).encode(text=alt.Text("Value:Q", format=".3f"))
final_chart = chart + text
st.altair_chart(final_chart, use_container_width=True)
# # Add value labels on top of bars
# text = chart.mark_text(align="center", baseline="bottom", dy=-5, fontSize=14).encode(
# text=alt.Text("Value:Q", format=".3f")
# )
# # Combine chart and text
# final_chart = chart + text
# Display chart with the full container width
# st.altair_chart(final_chart, use_container_width=True)
# --- Polling Rerun for Demo Mode ---
# While a demo is running, wait one interval and rerun the script so that the
# advancement logic near the top gets a chance to move to the next step.
if st.session_state["running_demo"]:
    time.sleep(DEMO_INTERVAL)
    st.rerun()
# Closing summary and links. The emoji, dashes and apostrophes below were
# mojibake and are restored here.
st.markdown("""
### 🤔 TL;DR — Why NTL?
Cross Entropy only cares if the prediction is exactly right or wrong ❌✅ — it doesn’t care *how close* a guess is!
That’s bad for LLMs doing math and numeric reasoning 🧮.
💥 NTL fixes that: it behaves like a regression loss on the token head, rewarding predictions that are numerically close.
""")
st.markdown("#### 📚 Further Resources")
# NOTE(review): the landing-page emoji lost a byte in the garbled source;
# 🌐 is the most plausible reconstruction — confirm.
st.markdown("""
- 📄 [ICML 2025 Paper](https://arxiv.org/abs/2411.02083)
- 🌐 [NTL Landing Page](https://tum-ai.github.io/number-token-loss/)
- 💻 [GitHub Code](https://github.com/tum-ai/number-token-loss)
""")