File size: 1,330 Bytes
d418945
 
 
376bb46
d418945
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376bb46
 
d418945
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from shiny import render
from shiny.express import input, ui
from datasets import load_dataset
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
from scipy.interpolate import interp1d
import numpy as np

# Page chrome: let the content fill the window and show the app's question.
ui.page_opts(fillable=True)
ui.panel_title("Does context size matter for viral dna models?")

# Single-choice selector for the model size; its value is read by the plot
# output through ``input.plot_type()``.
with ui.card():
    ui.input_selectize(
        id="plot_type",
        label="Select your model size:",
        choices=["14M", "31M", "70M", "160M", "410M"],
        multiple=False,
    )

def plot_loss_rates(df, type, *, labels=None, n_points=1000):
    """Overlay every column of *df* as a loss curve on a shared [0, 1] axis.

    Columns may contain NaNs (presumably padding where runs logged a
    different number of steps — confirm against the CSV producer). NaNs are
    dropped and each remaining series is linearly resampled onto
    ``n_points`` evenly spaced points in [0, 1], so curves of different
    lengths can be compared on the same axis.

    Args:
        df: DataFrame whose columns are individual loss series.
        type: Model-size text used in the plot title (e.g. ``'14M'``).
            Shadows the builtin ``type``; kept for backward compatibility.
        labels: Legend label per column, in column order. Defaults to the
            context sizes ``'32'`` … ``'1024'``. Columns beyond the supplied
            labels are labelled with their column name.
        n_points: Number of resampling points per curve.

    Returns:
        The matplotlib Figure containing the plot.
    """
    if labels is None:
        labels = ['32', '64', '128', '256', '512', '1024']
    x = np.linspace(0, 1, n_points)
    fig, ax = plt.subplots()
    for i, col in enumerate(df.columns):
        # Drop NaNs on the Series *before* converting to NumPy: ndarrays have
        # no ``dropna`` method.
        y = df[col].dropna().to_numpy()
        if len(y) == 0:
            continue  # nothing recorded for this column; nothing to draw
        if len(y) == 1:
            # interp1d requires at least two samples; one value is a flat line.
            resampled = np.full_like(x, y[0], dtype=float)
        else:
            resampled = interp1d(np.linspace(0, 1, len(y)), y)(x)
        # Label by column position so a skipped column doesn't shift labels.
        label = labels[i] if i < len(labels) else str(col)
        ax.plot(x, resampled, label=label)
    ax.legend()
    ax.set_title(f'Loss rates for {type} model')
    # x is the resampled [0, 1] grid, i.e. relative training progress.
    ax.set_xlabel('Training steps')
    ax.set_ylabel('Loss rate')
    return fig

 
    
    
    
@render.plot
def plot():
    """Render the loss-curve plot for the model size chosen in the selector.

    Reads ``<size>.csv`` with the size lower-cased (e.g. ``14m.csv`` for
    ``"14M"``) from the current working directory and draws it with
    ``plot_loss_rates``. Returns ``None``, leaving the plot area empty, when
    no CSV exists for the selected size.
    """
    size = input.plot_type()
    # Derive the file name from the selection so every size offered by the
    # selector is handled, not just "14M".
    csv_path = Path(f"{size.lower()}.csv")
    if not csv_path.is_file():
        return None
    df = pd.read_csv(csv_path)
    return plot_loss_rates(df, size)