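"""Streamlit helpers for loading wandb runs and their event histories."""
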
import os
import re
import pandas as pd
import streamlit as st

import opendashboards.utils.utils as utils


@st.cache_data
def load_runs(project, filters, min_steps=10):
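    """Fetch wandb runs for `project` matching `filters`, one row per run.

    Runs with fewer than `min_steps` logged events are skipped. Tags matching
    the hotkey/version/spec_version patterns become their own columns, and a
    boolean column is added for each known feature-flag tag.
    """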
    runs = []
    msg = st.empty()
    for run in utils.get_runs(project, filters, api_key=st.secrets['WANDB_API_KEY']):
        step = run.summary.get('_step', 0)
        if step < min_steps:
            msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue

        duration = run.summary.get('_runtime')
        end_time = run.summary.get('_timestamp')
        # extract values for selected tags
        rules = {
            'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
            'version': re.compile(r'^\d\.\d+\.\d+$'),
            'spec_version': re.compile(r'\d{4}$'),
        }
        # take the first matching tag for each rule, defaulting to None so the
        # columns always exist for the categorical cast below
        tags = {k: next((tag for tag in run.tags if rule.match(tag)), None) for k, rule in rules.items()}
        # include bool flag for remaining tags
        tags.update({k: k in run.tags for k in ('mock', 'custom_gating_model', 'nsfw_filter', 'outsource_scoring', 'disable_set_weights')})

        runs.append({
            'state': run.state,
            'num_steps': step,
            'entity': run.entity,
            'id': run.id,
            'name': run.name,
            'project': run.project,
            'url': run.url,
            'path': os.path.join(run.entity, run.project, run.id),
            'start_time': pd.to_datetime(end_time - duration, unit="s"),
            'end_time': pd.to_datetime(end_time, unit="s"),
            # a duration is a time span, not a point in time
            'duration': pd.to_timedelta(duration, unit="s"),
            **tags
        })
    msg.empty()
    return pd.DataFrame(runs).astype({'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category'})


@st.cache_data
def load_data(selected_runs, load=True, save=False):
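    """Fetch the event history of each run in `selected_runs` and concatenate.

    When `load` is True, a cached CSV under `data/` is reused if present;
    otherwise the history is downloaded from wandb and, when `save` is True
    and the run is no longer running, written back to `data/` for reuse.
    """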

    frames = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Loading data')
    info = st.empty()
    os.makedirs('data', exist_ok=True)
    for i, idx in enumerate(selected_runs.index):
        run = selected_runs.loc[idx]
        prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({successful}/{len(selected_runs)} runs, {n_events} events)'

        file_path = os.path.join('data',f'history-{run.id}.csv')

        if load and os.path.exists(file_path):
            progress.progress(i/len(selected_runs),f'{prog_msg}... **reading** `{file_path}`')
            try:
                df = utils.load_data(file_path)
            except Exception as e:
                info.warning(f'Failed to load history from `{file_path}`')
                st.exception(e)
                continue
        else:
            progress.progress(i/len(selected_runs),f'{prog_msg}... **downloading** `{run.path}`')
            try:
                # Download the history from wandb
                df = utils.download_data(run.path)
                # Attach run metadata to every event row; assign returns a new frame
                df = df.assign(**run.to_dict())

                if save and run.state != 'running':
                    df.to_csv(file_path, index=False)
                    # st.info(f'Saved history to {file_path}')
            except Exception as e:
                info.warning(f'Failed to download history for `{run.path}`')
                st.exception(e)
                continue

        frames.append(df)
        n_events += df.shape[0]
        successful += 1

    progress.empty()
    if not frames:
        info.error('No data loaded')
        st.stop()
    # TODO: drop rows containing chain weights; they break the schema on concat
    return pd.concat(frames)
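

# Example usage (a minimal sketch, not part of this module): the project path
# and filter below are hypothetical placeholders, and `filters` is assumed to
# be forwarded by `utils.get_runs` to the wandb API's MongoDB-style run query.
#
#     runs = load_runs('my-entity/my-project', filters={'tags': 'validator'})
#     finished = runs[runs.state == 'finished']
#     events = load_data(finished, load=True, save=True)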