Hack90 commited on
Commit
a513cc4
·
verified ·
1 Parent(s): d17b69e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -1000,27 +1000,27 @@ with ui.navset_card_tab(id="tab"):
1000
  selected='compliment'
1001
  )
1002
  def plot_loss_rates(df, param_types, loss_types, model_types):
1003
- # interplot each column to be same number of points
1004
- x = np.linspace(0, 1, 1000)
1005
- loss_rates = []
1006
- labels = []
1007
- #drop the column step
1008
- df = df.drop(columns=['Step'])
1009
- for param_type in param_types:
1010
- for loss_type in loss_types:
1011
- for model_type in model_types:
1012
- y = df[df['param_type']==param_type && df['loss_type']==loss_type && df['model_type']==model_type]['loss'].dropna().astype('float', errors = 'ignore').dropna().values
1013
- f = interp1d(np.linspace(0, 1, len(y)), y)
1014
- loss_rates.append(f(x))
1015
- labels.append(param_type +'_'+loss_type +'_'+model_type)
1016
- fig, ax = plt.subplots()
1017
- for i, loss_rate in enumerate(loss_rates):
1018
- ax.plot(x, loss_rate, label=labels[i])
1019
- ax.legend()
1020
- ax.set_title(f'Loss rates for a {type} parameter model across context windows')
1021
- ax.set_xlabel('Training steps')
1022
- ax.set_ylabel('Loss rate')
1023
- return fig
1024
 
1025
  import matplotlib as mpl
1026
  @render.plot()
 
1000
  selected='compliment'
1001
  )
1002
  def plot_loss_rates(df, param_types, loss_types, model_types):
1003
+ # interplot each column to be same number of points
1004
+ x = np.linspace(0, 1, 1000)
1005
+ loss_rates = []
1006
+ labels = []
1007
+ #drop the column step
1008
+ df = df.drop(columns=['Step'])
1009
+ for param_type in param_types:
1010
+ for loss_type in loss_types:
1011
+ for model_type in model_types:
1012
+ y = df[df['param_type']==param_type && df['loss_type']==loss_type && df['model_type']==model_type]['loss'].dropna().astype('float', errors = 'ignore').dropna().values
1013
+ f = interp1d(np.linspace(0, 1, len(y)), y)
1014
+ loss_rates.append(f(x))
1015
+ labels.append(param_type +'_'+loss_type +'_'+model_type)
1016
+ fig, ax = plt.subplots()
1017
+ for i, loss_rate in enumerate(loss_rates):
1018
+ ax.plot(x, loss_rate, label=labels[i])
1019
+ ax.legend()
1020
+ ax.set_title(f'Loss rates for a {type} parameter model across context windows')
1021
+ ax.set_xlabel('Training steps')
1022
+ ax.set_ylabel('Loss rate')
1023
+ return fig
1024
 
1025
  import matplotlib as mpl
1026
  @render.plot()