Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
### CHINCHILLA PARAMS:
# Coefficients of a Chinchilla-style parametric loss fit,
#     L(N, D) = E + A / N**alpha + B / D**beta
# with N = parameter count and D = training tokens.
# NOTE(review): values are taken as given here; confirm them against the
# fit they were copied from before relying on absolute numbers.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion: the UI takes sizes in billions, the formulas use raw counts.
Bn = 10**9

# Scaling constant shared by n_opt / d_opt; by construction
# G ** (alpha + beta) == (alpha * A) / (beta * B).
G = ((alpha * A) / (beta * B)) ** (1 / (alpha + beta))
###
|
16 |
+
|
17 |
+
def to_flops(N, D):
    """Return the training compute for a model of size N trained on D tokens.

    Uses the approximation C = 6 * N * D, where N is the parameter count
    and D is the number of training tokens (both raw counts, not billions).
    """
    return 6 * N * D
|
19 |
+
|
20 |
+
def n_opt(C):
    """Return the compute-optimal parameter count for a compute budget C.

    C is expressed in the same units as ``to_flops`` returns. Together with
    ``d_opt`` this satisfies ``n_opt(C) * d_opt(C) == C / 6``.
    """
    return G * ((C / 6) ** (beta / (alpha + beta)))
|
22 |
+
|
23 |
+
def d_opt(C):
    """Return the compute-optimal number of training tokens for budget C.

    C is expressed in the same units as ``to_flops`` returns. Together with
    ``n_opt`` this satisfies ``n_opt(C) * d_opt(C) == C / 6``.
    """
    return (1 / G) * ((C / 6) ** (alpha / (alpha + beta)))
|
25 |
+
|
26 |
+
def get_kd(kn):
    """Return the data scale factor that offsets a model-size factor ``kn``.

    For a model ``kn`` times the compute-optimal size, the returned ``kd`` is
    the multiple of the compute-optimal token count at which the parametric
    loss ``A / N**alpha + B / D**beta`` equals its compute-optimal value:

        kd ** -beta = 1 - (kn ** -alpha - 1) * (A / B) * G ** (-alpha - beta)

    Args:
        kn: Model size as a fraction of the compute-optimal size. Must be
            positive. NumPy arrays are passed through element-wise.

    Returns:
        The data scale factor ``kd``. For a scalar ``kn`` so small that no
        finite amount of data reaches the target loss, ``float("inf")``.

    Raises:
        ValueError: If a scalar ``kn`` is not positive.
    """
    if np.ndim(kn) == 0 and kn <= 0:
        raise ValueError("kn must be positive")

    frac = (A / B) * (G ** (-alpha - beta))
    base = 1 - (kn ** -alpha - 1) * frac

    # A non-positive base would make the fractional power below complex
    # (Python floats) or undefined; it means the model is too small for any
    # finite dataset to compensate. Arrays keep the original behaviour.
    if np.ndim(base) == 0 and base <= 0:
        return float("inf")

    return base ** (1 / (-beta))
|
30 |
+
|
31 |
+
def compute_overhead(kn, kd):
    """Return the fractional extra compute for scale factors ``kn`` and ``kd``.

    Compute is proportional to model size times token count, so scaling the
    model by ``kn`` and the data by ``kd`` scales compute by ``kn * kd``.
    The result is that factor minus one (0.0 means no overhead; multiply by
    100 for a percentage).
    """
    return kn * kd - 1
|
33 |
+
|
34 |
+
### PRECOMPUTE CURVE:
# Range of model-size fractions shown on the x-axis. The lower bound stays
# well above the point (kn of roughly 0.1 with the constants in this file)
# below which get_kd has no finite real solution.
kn_min = 0.2
kn_max = 2

# x and y must be sampled on the same grid, otherwise the plotted curve
# pairs each overhead with the wrong model-size fraction.
kns = np.linspace(kn_min, kn_max, 100)
# Compute overhead in percent for every sampled model-size fraction.
overheads = [compute_overhead(kn, get_kd(kn)) * 100 for kn in kns]
|
43 |
+
|
44 |
+
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark one operating point.

    Args:
        kn: x-coordinate of the marker (model size as a fraction of the
            compute-optimal size).
        kd: y-coordinate of the marker. Despite the name, it is plotted on
            the "Compute overhead (%)" axis, so callers should pass an
            overhead percentage.

    Returns:
        The newly created matplotlib figure. It is also left as pyplot's
        current figure, so code that reads the global ``plt`` state after
        calling this function keeps working.
    """
    # A fresh figure per call: drawing on the implicit global axes would
    # stack a new curve and marker on top of every previous call.
    fig, ax = plt.subplots()
    ax.plot(kns, overheads)
    ax.scatter([kn], [kd])
    ax.set_xlabel("Fraction of compute optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    return fig
|
49 |
+
|
50 |
+
def _update_outputs(n_billion, d_billion):
    """Compute the plot and summary text for the current UI inputs.

    Args:
        n_billion: Model size in billions of parameters (may be None when
            the input box is empty).
        d_billion: Dataset size in billions of tokens (may be None when the
            input box is empty).

    Returns:
        A ``(figure, markdown_text)`` pair; ``figure`` is None when the
        inputs cannot be evaluated.
    """
    if n_billion is None or d_billion is None or n_billion <= 0 or d_billion <= 0:
        return None, "Model size and dataset size must both be positive numbers."

    n_params = n_billion * Bn
    C = to_flops(n_params, d_billion * Bn)

    # Compare like with like: n_opt returns a raw parameter count, so the
    # model size must be converted out of billions before dividing.
    kn = n_params / n_opt(C)
    kd = get_kd(kn)

    # For very small models no finite dataset reaches the compute-optimal
    # loss; depending on the input type get_kd then yields a complex,
    # infinite or NaN value.
    if isinstance(kd, complex) or not np.isfinite(kd):
        return None, (
            f"Compute budget (FLOPs): {C:.2E}\n\n"
            "This model is too small relative to the compute-optimal size "
            "for the overhead to be defined."
        )

    overhead_pct = 100 * compute_overhead(kn, kd)

    # Drop figures from earlier updates so they neither accumulate in
    # memory nor get drawn over, then pick up whatever plot_curve drew.
    plt.close("all")
    plot_curve(kn, overhead_pct)
    fig = plt.gcf()

    # Blank lines between entries: single newlines collapse in Markdown.
    text = (
        f"Compute budget (FLOPs): {C:.2E}\n\n"
        f"Training compute overhead (%): {overhead_pct:.2f}\n\n"
        f"Inference cost fraction (%): {kn * 100:.2f}"
    )
    return fig, text


with gr.Blocks() as demo:
    N = gr.Number(value=1, label="Model size (in B parameters)")
    D = gr.Number(value=100, label="Dataset size (in B tokens)")

    plot = gr.Plot()
    summary = gr.Markdown()

    # Component values only exist inside event handlers, so the computation
    # runs on page load and again whenever either input changes.
    for register in (demo.load, N.change, D.change):
        register(_update_outputs, inputs=[N, D], outputs=[plot, summary])

demo.launch()
|