"""Streamlit app that visualises every tree of a fitted HistGradientBoosting model as a Plotly treemap."""
import time
from collections import deque

import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st

# Feature names in the column order the model was trained on; a split's
# `feature_idx` indexes into this list.
FEATS = [
    'srcip', 'sport', 'dstip', 'dsport', 'proto',
    #'state', I dropped this one when I trained the model
    'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'service',
    'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb',
    'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit',
    'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat',
    'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login',
    'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm',
    'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
]

# CSS colour names; the colour for a split on FEATS[k] is COLORS[k].
COLORS = [
    'aliceblue','aqua','aquamarine','azure',
    'bisque','black','blanchedalmond','blue',
    'blueviolet','brown','burlywood','cadetblue',
    'chartreuse','chocolate','coral','cornflowerblue',
    'cornsilk','crimson','cyan','darkblue','darkcyan',
    'darkgoldenrod','darkgray','darkgreen',
    'darkkhaki','darkmagenta','darkolivegreen','darkorange',
    'darkorchid','darkred','darksalmon','darkseagreen',
    'darkslateblue','darkslategray',
    'darkturquoise','darkviolet','deeppink','deepskyblue',
    'dimgray','dodgerblue',
    'forestgreen','fuchsia','gainsboro',
    'gold','goldenrod','gray','green',
    'greenyellow','honeydew','hotpink','indianred','indigo',
    'ivory','khaki','lavender','lavenderblush','lawngreen',
    'lemonchiffon','lightblue','lightcoral','lightcyan',
    'lightgoldenrodyellow','lightgray',
    'lightgreen','lightpink','lightsalmon','lightseagreen',
    'lightskyblue','lightslategray',
    'lightsteelblue','lightyellow','lime','limegreen',
    'linen','magenta','maroon','mediumaquamarine',
    'mediumblue','mediumorchid','mediumpurple',
    'mediumseagreen','mediumslateblue','mediumspringgreen',
    'mediumturquoise','mediumvioletred','midnightblue',
    'mintcream','mistyrose','moccasin','navy',
    'oldlace','olive','olivedrab','orange','orangered',
    'orchid','palegoldenrod','palegreen','paleturquoise',
    'palevioletred','papayawhip','peachpuff','peru','pink',
    'plum','powderblue','purple','red','rosybrown',
    'royalblue','saddlebrown','salmon','sandybrown',
    'seagreen','seashell','sienna','silver','skyblue',
    'slateblue','slategray','slategrey','snow','springgreen',
    'steelblue','tan','teal','thistle','tomato','turquoise',
    'violet','wheat','yellow','yellowgreen'
]


def build_parents(tree, visit_order, node_id2plot_id):
    """Find the parent of every node in ``visit_order``.

    Args:
        tree: DataFrame of tree nodes indexed by node id, with ``left`` and
            ``right`` child-id columns (0 meaning "no child").
        visit_order: node ids, root (node 0) first.
        node_id2plot_id: maps a node id to its treemap id.

    Returns:
        Three lists aligned with ``visit_order``:
        ``(parents, parent_plot_ids, directions)``. They hold the parent node
        id, the parent's treemap id as a string, and ``'l'``/``'r'`` for
        which side of the parent the node hangs off. The root's entries are
        all ``None``.
    """
    # One pass over the table instead of a full column scan per node.
    # Node 0 is the root and never a child, so a 0 child id means "no child".
    parent_of = {}
    for node_id, left, right in zip(tree.index, tree['left'], tree['right']):
        if left != 0:
            parent_of[left] = (node_id, 'l')
        if right != 0:
            parent_of[right] = (node_id, 'r')

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for i in visit_order[1:]:
        parent, direction = parent_of[i]
        parents.append(parent)
        parent_plot_ids.append(str(node_id2plot_id[parent]))
        directions.append(direction)
    return parents, parent_plot_ids, directions


def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """Build the treemap label and colour for every node, in ``visit_order``.

    The root gets a fixed title and white. Each other node is labelled with
    the split its parent made to reach it, e.g.
    ``[<parent id>.L] <feature> <= <threshold>``. It is coloured by that
    split's feature.

    Args:
        tree: DataFrame of tree nodes with ``feature_idx`` and
            ``num_threshold`` columns.
        visit_order: node ids, root (node 0) first.
        parents, parent_plot_ids, directions: parallel lists as returned by
            ``build_parents``.

    Returns:
        ``(labels, colors)``, both aligned with ``visit_order``.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order, parents, parent_plot_ids, directions
    ):
        # the root already got its label and colour above
        if i == 0:
            continue
        feat_idx = int(tree.loc[int(parent), 'feature_idx'])
        feat = FEATS[feat_idx]
        thresh = tree.loc[int(parent), 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")
        # FEATS names are unique, so the feature's position is its colour index
        colors.append(COLORS[feat_idx])
    return labels, colors


def build_plot(tree):
    """Turn one tree's node table into a ``go.Treemap`` trace.

    Box sizes come from the ``count`` column. Boxes are nested parent/child
    as in the tree.
    """
    #https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    # if you use `ids`, then `parents` has to be in terms of `ids`
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {node: i for i, node in enumerate(visit_order)}
    parents, parent_plot_ids, directions = build_parents(tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(
        tree, visit_order, parents, parent_plot_ids, directions
    )

    # this should just be ['0', '1', '2', . . .]
    plot_ids = [str(node_id2plot_id[x]) for x in visit_order]

    # All per-node arrays handed to Treemap must be in the same order.
    # labels/ids/parents follow visit_order, so select the counts in that
    # order too. Taking the DataFrame's row order would attach counts to the
    # wrong boxes whenever node ids are not numbered breadth-first.
    values = tree.loc[[int(n) for n in visit_order], 'count'].to_numpy()

    return go.Treemap(
        values=values,
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )


def breadth_first_traverse(tree):
    """Return node ids in breadth-first order, starting at the root (node 0).

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    Iterative version makes more sense since I have the whole tree in a table
    instead of just nodes and pointers. A child id of 0 means "no child".
    """
    queue = deque([0])
    visited_nodes = []
    while queue:
        cur = queue.popleft()  # O(1), unlike list.pop(0)
        visited_nodes.append(cur)
        for side in ('left', 'right'):
            child = tree.loc[cur, side]
            if child != 0:
                queue.append(child)
    return visited_nodes


def main():
    """Load the model, then show one tree at a time plus an animation of all trees."""
    # joblib.load unpickles the file: only point this at a model you trust.
    hgb = joblib.load('hgb_classifier.joblib')
    # `_predictors` is a private scikit-learn attribute (one list per
    # iteration); only the first predictor of each iteration is shown.
    # NOTE(review): presumably that is the only one for binary
    # classification; multiclass models would have more - confirm.
    trees = [pd.DataFrame(x[0].nodes) for x in hgb._predictors]

    # make the plots
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(graph_obj) for graph_obj in graph_objs]
    frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]

    # One tree at a time, chosen with a slider. Calling st.plotly_chart in a
    # loop stacked every figure below the previous one instead of replacing
    # it. Known issues: the plot is small, and Streamlit reruns the whole
    # script (so rebuilds every plot) whenever the slider moves. Caching
    # build_plot() didn't work because its DataFrame argument is unhashable.
    if len(figures) > 1:
        idx = st.slider(
            label='which step to show',
            min_value=0,
            max_value=len(figures) - 1,
            value=0,
            step=1,
        )
    else:
        # Streamlit's slider needs min_value < max_value; with a single tree
        # there is nothing to choose.
        idx = 0
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    # All trees as a Plotly animation with a Play button.
    # https://plotly.com/python/animations/#using-a-slider-and-buttons
    # The animation options are documented in the plotly.js schema / JS docs:
    # https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    # https://plotly.com/javascript/animations/
    # https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu
    # The frame/transition durations (the playback speed) came from:
    # https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type': 'buttons',
                'buttons': [{
                    'label': 'Play',
                    'method': 'animate',
                    'args': [None, {
                        'frame': {'duration': 5000},
                        'transition': {'duration': 2500},
                    }],
                }],
            }],
        ),
    )
    st.plotly_chart(ani_fig)


if __name__ == '__main__':
    main()