Hack90 commited on
Commit
9c70cba
·
verified ·
1 Parent(s): 654d483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
5
  from pathlib import Path
6
  import matplotlib
7
  import numpy as np
8
- import gradio
9
 
10
  ############################################################# 2D Line Plot ########################################################
11
  ### dvq stuff, obvs this will just be an import in the final version
@@ -883,6 +883,39 @@ with ui.navset_card_tab(id="tab"):
883
  return fig
884
  with ui.nav_panel("Viral Model"):
885
  gr.load("models/Hack90/virus_pythia_31_1024").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  # @render.image
887
  # def image():
888
  # img = None
 
5
  from pathlib import Path
6
  import matplotlib
7
  import numpy as np
8
+ import gradio as gr
9
 
10
  ############################################################# 2D Line Plot ########################################################
11
  ### dvq stuff, obvs this will just be an import in the final version
 
883
  return fig
884
  with ui.nav_panel("Viral Model"):
885
  gr.load("models/Hack90/virus_pythia_31_1024").launch()
886
+
887
+ with ui.nav_panel("Viral Model Training"):
888
+ ui.page_opts(fillable=True)
889
+ ui.panel_title("Does context size matter for a nucleotide model?")
890
+
891
+ def plot_loss_rates(df, type):
892
+ # interplot each column to be same number of points
893
+ x = np.linspace(0, 1, 1000)
894
+ loss_rates = []
895
+ labels = ['32', '64', '128', '256', '512', '1024']
896
+ #drop the column step
897
+ df = df.drop(columns=['Step'])
898
+ for col in df.columns:
899
+ y = df[col].dropna().astype('float', errors = 'ignore').dropna().values
900
+ f = interp1d(np.linspace(0, 1, len(y)), y)
901
+ loss_rates.append(f(x))
902
+ fig, ax = plt.subplots()
903
+ for i, loss_rate in enumerate(loss_rates):
904
+ ax.plot(x, loss_rate, label=labels[i])
905
+ ax.legend()
906
+ ax.set_title(f'Loss rates for a {type} parameter model')
907
+ ax.set_xlabel('Training steps')
908
+ ax.set_ylabel('Loss rate')
909
+ return fig
910
+
911
+ @render.plot
912
+ def plot():
913
+ fig = None
914
+ df = pd.read_csv('14m.csv')
915
+ mpl.rcParams.update(mpl.rcParamsDefault)
916
+ fig = plot_loss_rates(df, '14M')
917
+ return fig
918
+
919
  # @render.image
920
  # def image():
921
  # img = None