File size: 9,378 Bytes
101093d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
import warnings
import re
import tqdm
import wandb
from traceback import print_exc
import plotly.express as px
import pandas as pd
from concurrent.futures import ProcessPoolExecutor

import opendashboards.utils.utils as utils

from IPython.display import display

# Shared wandb API client used by all functions below; 60 s timeout for slow queries.
api= wandb.Api(timeout=60)
# Anonymous login so the script can run without a configured wandb account.
wandb.login(anonymous="allow")

def pull_wandb_runs(project='openvalidators', filters=None, min_steps=50, max_steps=100_000, ntop=10, summary_filters=None ):
    """Fetch metadata for up to `ntop` wandb runs from `project` as a DataFrame.

    Runs are skipped when `summary_filters(summary)` is falsy or when their
    step count falls outside [min_steps, max_steps].

    Args:
        project: wandb project path to query.
        filters: optional MongoDB-style filter dict passed to `api.runs`.
        min_steps / max_steps: inclusive bounds on the run's `_step` count.
        ntop: stop after this many runs have been accepted.
        summary_filters: optional predicate over `run.summary`.

    Returns:
        pd.DataFrame with one row per accepted run (state, timing, tag columns, ...).
    """
    # TODO: speed this up by storing older runs

    all_runs = api.runs(project, filters=filters)
    print(f'Using {ntop}/{len(all_runs)} runs with more than {min_steps} events')
    pbar = tqdm.tqdm(all_runs)

    # Tag-extraction rules, compiled once (they are loop-invariant).
    # Raw strings fix the invalid '\.' escape sequences of the originals,
    # which are a DeprecationWarning today and a SyntaxError in future Python.
    rules = {
        'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
        'version': re.compile(r'^\d\.\d+\.\d+$'),
        'spec_version': re.compile(r'\d{4}$'),
    }

    runs = []
    n_events = 0
    successful = 0
    for i, run in enumerate(pbar):

        summary = run.summary
        if summary_filters is not None and not summary_filters(summary):
            continue
        step = summary.get('_step',0)
        if step < min_steps or step > max_steps:
            # run outside the requested step range
            continue

        prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
        pbar.set_description(f'{prog_msg}... **fetching** `{run.name}`')

        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')
        # extract values for selected tags
        tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
        # include bool flag for remaining tags
        tags.update({k: True for k in run.tags if k not in tags.keys() and k not in tags.values()})

        runs.append({
            'state': run.state,
            'num_steps': step,
            'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
            'entity': run.entity,
            'user': run.user.name,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'project': run.project,
            'run_url': run.url,
            'run_path': os.path.join(run.entity, run.project, run.id),
            'start_time': pd.to_datetime(end_time-duration, unit="s"),
            'end_time': pd.to_datetime(end_time, unit="s"),
            'duration': pd.to_timedelta(duration, unit="s").round('s'),
            **tags
        })
        n_events += step
        successful += 1
        if successful >= ntop:
            break

    df = pd.DataFrame(runs)
    # Only cast columns that actually exist: the tag-derived columns
    # ('hotkey', 'version', 'spec_version') are optional, and astype on a
    # missing column raises KeyError.
    cats = {c: 'category' for c in ('state', 'hotkey', 'version', 'spec_version') if c in df.columns}
    return df.astype(cats)

def plot_gantt(df_runs):
    """Show a Gantt-style timeline of wandb runs, one row per username, colored by run state."""
    state_colors = {
        'running': 'green',
        'finished': 'grey',
        'killed': 'blue',
        'crashed': 'orange',
        'failed': 'red',
    }
    timeline = px.timeline(
        df_runs,
        x_start="start_time",
        x_end="end_time",
        y="username",
        color="state",
        title="Timeline of Runs",
        category_orders={'run_name': df_runs.run_name.unique()},
        hover_name="run_name",
        hover_data=['hotkey', 'user', 'username', 'run_id', 'num_steps', 'num_completions'],
        color_discrete_map=state_colors,
        opacity=0.3,
        width=1200,
        height=800,
        template="plotly_white",
    )
    # Shrink tick labels and drop the axis title so long usernames fit.
    timeline.update_yaxes(tickfont_size=8, title='')
    timeline.show()

def load_data(run_id, run_path=None, load=True, save=False, timeout=30):
    """Load the event history for one run, preferring a local CSV cache.

    Args:
        run_id: run identifier used to name the cache file.
        run_path: wandb `entity/project/id` path; required when the cache misses.
        load: if True, use `data/runs/history-<run_id>.csv` when it exists.
        save: if True, write a freshly-downloaded history back to the cache.
        timeout: unused; kept for backward compatibility of the signature.

    Returns:
        pd.DataFrame of events sorted by `_timestamp` (converted to datetime).
    """
    file_path = os.path.join('data/runs/', f'history-{run_id}.csv')

    if load and os.path.exists(file_path):
        df = pd.read_csv(file_path)
        # filter out events with missing step length
        df = df.loc[df.step_length.notna()]

        # Detect list columns which are stored as strings in the CSV.
        list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
        # Convert the string representation of each list back to a list.
        # NOTE(review): eval on file contents is unsafe if the cache can be
        # tampered with — ast.literal_eval would be the safer drop-in.
        df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')

    else:
        # Download the history from wandb and add metadata
        run = api.run(run_path)
        df = pd.DataFrame(list(run.scan_history()))

        print(f'Downloaded {df.shape[0]} events from {run_path!r} with id {run_id!r}')

        if save:
            # Saved before the datetime conversion below, so the cache keeps
            # raw epoch seconds — matching what the load branch expects.
            df.to_csv(file_path, index=False)

    # Convert timestamp to datetime. Use column indexing rather than
    # attribute assignment (`df._timestamp = ...`), which pandas discourages
    # and which silently sets an instance attribute if the column is absent.
    df['_timestamp'] = pd.to_datetime(df['_timestamp'], unit="s")
    return df.sort_values("_timestamp")


def calculate_stats(df_long, rm_failed=True, rm_zero_reward=True, freq='H', save_path=None ):
    """Aggregate per-(time bucket, run) completion/reward stats from long-format events.

    Args:
        df_long: long-format dataframe with `_timestamp`, `run_id` and either
            `completions`/`rewards` columns or the wide
            `followup_*`/`answer_*` pairs (which are stacked into long form).
        rm_failed: drop rows with empty completions.
        rm_zero_reward: drop rows with non-positive rewards.
        freq: pandas Grouper frequency for the time buckets.
        save_path: optional CSV path to also write the stats to.

    Returns:
        pd.DataFrame with nunique/count of completions, sum/mean/std of
        rewards, and a `completions_diversity` ratio per bucket and run.
    """
    # Work on a copy: the original mutated the caller's frame via the
    # in-place set_index and the dtype-changing column assignment below.
    df_long = df_long.copy()
    df_long['_timestamp'] = pd.to_datetime(df_long['_timestamp'])

    # If the frame is still wide (followup_*/answer_* pairs), stack the two
    # pairs into a single completions/rewards schema.
    if 'completions' not in df_long.columns:
        df_long.set_index(['_timestamp','run_id'], inplace=True)
        df_schema = pd.concat([
            df_long[['followup_completions','followup_rewards']].rename(columns={'followup_completions':'completions', 'followup_rewards':'rewards'}),
            df_long[['answer_completions','answer_rewards']].rename(columns={'answer_completions':'completions', 'answer_rewards':'rewards'})
        ])
        df_long = df_schema.reset_index()

    if rm_failed:
        # empty completion string means the query failed
        df_long = df_long.loc[ df_long.completions.str.len()>0 ]

    if rm_zero_reward:
        df_long = df_long.loc[ df_long.rewards>0 ]

    print(f'Calculating stats for dataframe with shape {df_long.shape}')

    g = df_long.groupby([pd.Grouper(key='_timestamp', axis=0, freq=freq), 'run_id'])

    stats = g.agg({'completions':['nunique','count'], 'rewards':['sum','mean','std']})

    # Flatten the MultiIndex columns: ('completions','nunique') -> 'completions_nunique'.
    stats.columns = ['_'.join(c) for c in stats.columns]
    stats['completions_diversity'] = stats['completions_nunique'] / stats['completions_count']
    stats = stats.reset_index()

    if save_path:
        stats.to_csv(save_path, index=False)

    return stats


def clean_data(df):
    """Drop rows missing any completions/rewards value, then drop all-empty columns."""
    required = df.filter(regex='completions|rewards').columns
    trimmed = df.dropna(subset=required, how='any')
    return trimmed.dropna(axis=1, how='all')

def explode_data(df):
    """Explode every list-valued column (per utils.get_list_col_lengths) into long format, coercing to numeric where possible."""
    lengths_by_col = utils.get_list_col_lengths(df)
    exploded = utils.explode_data(df, list(lengths_by_col))
    return exploded.apply(pd.to_numeric, errors='ignore')


def process(run, load=True, save=False, freq='H'):
    """Compute (or load cached) aggregate stats for a single run row.

    Args:
        run: one row of the runs dataframe (pd.Series) with at least
            `run_id`, `run_path`, `state` and `end_time`.
        load: forwarded to `load_data` (use the local history cache).
        save: allow saving the downloaded history to the cache.
        freq: time-bucket frequency forwarded to `calculate_stats`.

    Returns:
        pd.DataFrame of stats, or None if processing failed.
    """
    try:

        stats_path = f'data/aggs/stats-{run["run_id"]}.csv'
        if os.path.exists(stats_path):
            print(f'Loaded stats file {stats_path}')
            return pd.read_csv(stats_path)

        # Only persist the raw history when the caller allows saving AND the
        # run has finished with a known end time. (BUG FIX: the original
        # passed the `save` keyword twice, which is a SyntaxError, and used
        # `&` between a bool and a Timestamp.)
        should_save = save and run['state'] != 'running' and pd.notna(run['end_time'])

        # Load data and add extra columns from the wandb run row
        df = load_data(run_id=run['run_id'],
                    run_path=run['run_path'],
                    load=load,
                    save=should_save,
                    ).assign(**run.to_dict())
        # Clean and explode dataframe
        df_long = explode_data(clean_data(df))
        # Remove original dataframe from memory
        del df
        # Get and save stats
        return calculate_stats(df_long, freq=freq, save_path=stats_path)

    except Exception as e:
        print(f'Error processing run {run["run_id"]}: {e}')

if __name__ == '__main__':

    # TODO: flag to overwrite runs that were running when downloaded and saved:
    # check if file date is older than run end time.

    filters = None# {"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
    # filters={'tags': {'$in': ['5F4tQyWrhfGVcNhoqeiNsR6KjD4wMZ2kfhLj4oHYuyHbZAc3']}} # Is foundation validator
    df_runs = pull_wandb_runs(ntop=500, filters=filters)#summary_filters=lambda s: s.get('augment_prompt'))

    # Make sure the cache directories exist before any worker writes to them.
    os.makedirs('data/runs/', exist_ok=True)
    os.makedirs('data/aggs/', exist_ok=True)
    df_runs.to_csv('data/wandb.csv', index=False)

    display(df_runs)
    plot_gantt(df_runs)

    # max(1, ...) guards against ProcessPoolExecutor(max_workers=0), which
    # raises ValueError when no runs matched.
    with ProcessPoolExecutor(max_workers=max(1, min(32, df_runs.shape[0]))) as executor:
        futures = [executor.submit(process, run, load=True, save=True) for _, run in df_runs.iterrows()]

        # Use tqdm to add a progress bar
        results = []
        with tqdm.tqdm(total=len(futures)) as pbar:
            for future in futures:
                try:
                    result = future.result()
                    # `process` returns None when it failed internally; keep
                    # only real dataframes so pd.concat below doesn't raise.
                    if result is not None:
                        results.append(result)
                except Exception:
                    # BUG FIX: traceback.print_exc() takes no exception
                    # argument — the original passed `e` as the `limit`
                    # parameter, which raises a TypeError.
                    print('generated an exception:')
                    print_exc()
                pbar.update(1)

    if not results:
        raise ValueError('No runs were successfully processed.')

    # Concatenate the results into a single dataframe
    df = pd.concat(results, ignore_index=True)

    df.to_csv('data/processed.csv', index=False)

    display(df)

    fig = px.line(df.astype({'_timestamp':str}),
              x='_timestamp',
              y='completions_diversity',
        line_group='run_id',
        title='Completion Diversity over Time',
        labels={'_timestamp':'', 'completions_diversity':'Diversity', 'uids':'UID','value':'counts', 'variable':'Completions'},
        width=800, height=600,
        template='plotly_white',
        ).update_traces(opacity=0.3)
    fig.show()