chenmouxiang's picture
Update app.py
3fe5b5b verified
raw
history blame
5.51 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import math
from datetime import datetime
from matplotlib.ticker import FuncFormatter
# Predefined hyperparameter sets, keyed by dataset name.
# Each entry holds the four coefficients consumed by pred_loss:
#     Loss = E + (A / (N * (1 + k * log P))) ** alpha
PARAM_SETS = {
    "Stack-V2-Python": {
        "E": 0.69123678,
        "A": 0.01130616 * 1e9,
        "k": 0.393463,
        "alpha": 0.18937067,
    },
    "Pile": {
        "E": 1.28254036,
        "A": 0.2035367 * 1e9,
        "k": 0.33027934,
        "alpha": 0.19479807,
    },
}
def pred_loss(E, A, k, alpha, n, p):
    """Evaluate the scaling law ``E + (A / (n * (1 + k * ln p))) ** alpha``.

    Args:
        E, A, k, alpha: Coefficients of the law.
        n: Parameter count; a scalar or a NumPy array (the result then has
            the same shape).
        p: Number of parallel streams; enters through its natural logarithm.

    Returns:
        The predicted loss as a NumPy scalar or array.
    """
    # Multiplier applied to the parameter count; equals 1 when p == 1.
    stream_factor = 1 + np.log(p) * k
    scaled_params = n * stream_factor
    return E + (A / scaled_params) ** alpha
def generate_plot(E, A, k, alpha):
    """Plot predicted loss against parameter count for P = 1, 2, 4 and 8.

    The curves are drawn on pyplot's *current* figure, which is cleared
    first, so each call replaces whatever the previous call drew.

    Args:
        E, A, k, alpha: Scaling-law coefficients forwarded to ``pred_loss``;
            they are also printed in the lower-left corner of the axes.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the plot.
    """
    # Hard-coded x range, widened by 10% on both sides.
    # NOTE(review): the two base values look like the smallest and largest
    # model sizes the law was fitted on -- confirm.
    x_min = 535813376 * 0.9
    x_max = 4353203200 * 1.1

    # One colour per stream count, paired by position. This avoids deriving
    # the index from a floating-point log2(p), which could truncate wrongly.
    stream_counts = (1, 2, 4, 8)
    colors = ('#2B83BA', '#7BB7D6', '#ED7D5F', '#D7191C')

    # NOTE(review): this uses pyplot's shared global figure; overlapping calls
    # from several threads would draw onto the same figure -- confirm this is
    # acceptable for the deployment.
    plt.clf()
    ax = plt.gca()

    # The x grid does not depend on p, so build it once.
    x_plot = np.linspace(x_min, x_max, 100)
    for p, color in zip(stream_counts, colors):
        y_plot = pred_loss(E, A, k, alpha, x_plot, p)
        ax.plot(x_plot, y_plot, marker=None, markersize=1, linewidth=3, color=color, label=f"$P={p}$")
    ax.legend(fontsize=12)

    def billions(x, pos):
        """Format a tick as billions ('1.5B'); ticks below 1e9 get no label."""
        if x < 1e9:
            return ""
        return f'{x * 1e-9:.1f}B'

    def two_decimals(y, pos):
        """Format a tick with two decimal places."""
        return f"{y:.2f}"

    # Each axis slot gets its own FuncFormatter instance, as in the original.
    ax.xaxis.set_major_formatter(FuncFormatter(billions))
    ax.xaxis.set_minor_formatter(FuncFormatter(billions))
    ax.yaxis.set_major_formatter(FuncFormatter(two_decimals))
    ax.yaxis.set_minor_formatter(FuncFormatter(two_decimals))

    ax.set_xlim(x_min, x_max)
    # Leave 1% headroom above the highest curve.
    y_low, y_high = ax.get_ylim()
    ax.set_ylim(y_low, y_high * 1.01)

    ax.text(0.03, 0.03, f"$E={E}$\n$A={A}$\n$k={k}$\n$\\alpha={alpha}$", transform=ax.transAxes, fontsize=10, verticalalignment='bottom', multialignment='left')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Parameters (Non-Embedding)', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    return plt
# Markdown shown in the result panel; filled in by process_inputs via
# str.format with the fields n, p, loss, n1, n2, n4 and n8.
# The blank lines keep the bullet list, the note and the sign-off as separate
# Markdown blocks; without them the note is folded into the last bullet.
OUTPUT_TEMPLATE = """Loss for a {n}B model when P={p} is: **{loss:.5f}**. It is equivalent to:

- A **{n1}B** model with **P=1**;
- A **{n2}B** model with **P=2**;
- A **{n4}B** model with **P=4**;
- A **{n8}B** model with **P=8**;

Note: The equivalent parameters are for reference only. In some reasoning tasks, scaling the parallel streams will obtain more performance gains than the loss benefits!

Enjoy it! 😊"""
def process_inputs(E, A, k, alpha, n, p):
    """Build the scaling-law plot and the Markdown summary for one query.

    Args:
        E, A, k, alpha: Scaling-law coefficients.
        n: Model size in billions of parameters.
        p: Number of parallel streams.

    Returns:
        A ``(plot, text)`` tuple: the value returned by ``generate_plot`` and
        ``OUTPUT_TEMPLATE`` filled with the predicted loss and the model
        sizes that give the same loss at P = 1, 2, 4 and 8.
    """
    # The law works on absolute parameter counts, not billions.
    n = n * 1e9
    plot = generate_plot(E, A, k, alpha)
    loss = pred_loss(E, A, k, alpha, n, p)

    # A model of size m with q streams matches the loss when
    # m * (1 + k ln q) == n * (1 + k ln p); solve for m, in billions.
    effective = n * (k * np.log(p) + 1)
    equivalents = {
        f"n{q}": round(effective / (k * np.log(q) + 1) / 1e9, 2)
        for q in (1, 2, 4, 8)
    }

    print(f"[{datetime.now()}] {E = }, {A = }, {k = }, {alpha = }, {n = }, {p = }")
    summary = OUTPUT_TEMPLATE.format(n=round(n / 1e9, 2), p=p, loss=loss, **equivalents)
    return plot, summary
# Create interface
# Page header rendered through gr.Markdown. The blank lines after the opening
# <div> and before the closing </div> end the raw-HTML block, so the heading
# and the badge in between are parsed as Markdown rather than literal text.
HEAD = """<div align="center">

# Parallel Scaling Law Visualization

[![Paper](https://img.shields.io/badge/arXiv-2505.10475-red)](https://arxiv.org/abs/2505.10475)

</div>
"""
# Values shown before the user touches anything.
_coeff_keys = ("E", "A", "k", "alpha")
_initial_preset = PARAM_SETS["Stack-V2-Python"]
_initial_n, _initial_p = 2.8, 4

with gr.Blocks() as demo:
    gr.Markdown(HEAD)
    with gr.Row():
        # Left column: formula, model inputs and law coefficients.
        with gr.Column():
            gr.Markdown("""$$
\\text{Loss}=E+\\left(
\\frac{A}{\\text{Parameters}\\times (1+k\\log P)}
\\right)^{\\alpha}
$$""")

            # Input values
            N = gr.Number(value=_initial_n, label="N: Number of Non-Embedding Model Parameters (in Billion)")
            P = gr.Number(value=_initial_p, label="P: Number of Parallel Streams")
            gr.Markdown("---")

            # Hyperparameter selection section; the first preset is preselected.
            param_set = gr.Dropdown(
                choices=["Custom", *PARAM_SETS],
                value=next(iter(PARAM_SETS)),
                label="Select our pre-fitted parameters for two datasets",
            )

            # Custom parameter inputs, created in E, A, k, alpha order and
            # labelled with the coefficient name.
            param_E, param_A, param_k, param_alpha = (
                gr.Number(value=_initial_preset[key], label=key)
                for key in _coeff_keys
            )

            # Pre-compute the outputs so the page is not empty on first load.
            plot, output = process_inputs(
                *(_initial_preset[key] for key in _coeff_keys),
                _initial_n,
                _initial_p,
            )

        # Right column: trigger button and results.
        with gr.Column():
            submit_btn = gr.Button("Calculate")

            # Output section
            plot_output = gr.Plot(label="Scaling Law Curve", value=plot)
            result_output = gr.Markdown(label="Result", value=output)

    # Auto-fill parameters when selecting predefined sets
    def update_params(selected):
        """Return the preset's E, A, k, alpha, or leave the fields untouched.

        Any choice that is not a key of PARAM_SETS (i.e. "Custom") yields
        ``gr.skip()`` for every field so the user's values are kept.
        """
        if selected not in PARAM_SETS:
            return [gr.skip() for _ in _coeff_keys]
        preset = PARAM_SETS[selected]
        return [preset[key] for key in _coeff_keys]

    _coeff_fields = [param_E, param_A, param_k, param_alpha]

    param_set.change(update_params, inputs=[param_set], outputs=_coeff_fields)

    # Submit button event
    click_event = submit_btn.click(
        process_inputs,
        inputs=_coeff_fields + [N, P],
        outputs=[plot_output, result_output],
    )

demo.launch()