Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
|
|
5 |
from pathlib import Path
|
6 |
import matplotlib
|
7 |
import numpy as np
|
8 |
-
import gradio
|
9 |
|
10 |
############################################################# 2D Line Plot ########################################################
|
11 |
### dvq stuff, obvs this will just be an import in the final version
|
@@ -883,6 +883,39 @@ with ui.navset_card_tab(id="tab"):
|
|
883 |
return fig
|
884 |
with ui.nav_panel("Viral Model"):
|
885 |
gr.load("models/Hack90/virus_pythia_31_1024").launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
886 |
# @render.image
|
887 |
# def image():
|
888 |
# img = None
|
|
|
5 |
from pathlib import Path
|
6 |
import matplotlib
|
7 |
import numpy as np
|
8 |
+
import gradio as gr
|
9 |
|
10 |
############################################################# 2D Line Plot ########################################################
|
11 |
### dvq stuff, obvs this will just be an import in the final version
|
|
|
883 |
return fig
|
884 |
with ui.nav_panel("Viral Model"):
|
885 |
gr.load("models/Hack90/virus_pythia_31_1024").launch()
|
886 |
+
|
887 |
+
with ui.nav_panel("Viral Model Training"):
|
888 |
+
ui.page_opts(fillable=True)
|
889 |
+
ui.panel_title("Does context size matter for a nucleotide model?")
|
890 |
+
|
891 |
+
def plot_loss_rates(df, type):
|
892 |
+
# interplot each column to be same number of points
|
893 |
+
x = np.linspace(0, 1, 1000)
|
894 |
+
loss_rates = []
|
895 |
+
labels = ['32', '64', '128', '256', '512', '1024']
|
896 |
+
#drop the column step
|
897 |
+
df = df.drop(columns=['Step'])
|
898 |
+
for col in df.columns:
|
899 |
+
y = df[col].dropna().astype('float', errors = 'ignore').dropna().values
|
900 |
+
f = interp1d(np.linspace(0, 1, len(y)), y)
|
901 |
+
loss_rates.append(f(x))
|
902 |
+
fig, ax = plt.subplots()
|
903 |
+
for i, loss_rate in enumerate(loss_rates):
|
904 |
+
ax.plot(x, loss_rate, label=labels[i])
|
905 |
+
ax.legend()
|
906 |
+
ax.set_title(f'Loss rates for a {type} parameter model')
|
907 |
+
ax.set_xlabel('Training steps')
|
908 |
+
ax.set_ylabel('Loss rate')
|
909 |
+
return fig
|
910 |
+
|
911 |
+
@render.plot
|
912 |
+
def plot():
|
913 |
+
fig = None
|
914 |
+
df = pd.read_csv('14m.csv')
|
915 |
+
mpl.rcParams.update(mpl.rcParamsDefault)
|
916 |
+
fig = plot_loss_rates(df, '14M')
|
917 |
+
return fig
|
918 |
+
|
919 |
# @render.image
|
920 |
# def image():
|
921 |
# img = None
|