# hgbdt-viz / streamlit_viz.py
# Commit 045d7d4: "Working version of the streamlit animation" (7.47 kB)
import time
from collections import deque

import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
# Feature names in training-column order. A tree node's `feature_idx` is an
# index into this list, so the order must match the columns the model was
# fit on. NOTE: the 'state' column (originally between 'proto' and 'dur')
# was dropped before training, so it is intentionally absent here.
FEATS = [
    'srcip', 'sport', 'dstip', 'dsport', 'proto',
    'dur', 'sbytes', 'dbytes', 'sttl', 'dttl',
    'sloss', 'dloss', 'service', 'Sload', 'Dload',
    'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb',
    'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len',
    'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt',
    'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports',
    'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src',
    'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
]
# CSS named colours in alphabetical order, grouped here by initial letter.
# Entry k colours every treemap node whose parent split on feature index k.
COLORS = [
    'aliceblue', 'aqua', 'aquamarine', 'azure',
    'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown',
    'burlywood',
    'cadetblue', 'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
    'cornsilk', 'crimson', 'cyan',
    'darkblue', 'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen',
    'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange', 'darkorchid',
    'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray',
    'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue', 'dimgray',
    'dodgerblue',
    'forestgreen', 'fuchsia',
    'gainsboro', 'gold', 'goldenrod', 'gray', 'green', 'greenyellow',
    'honeydew', 'hotpink',
    'indianred', 'indigo', 'ivory',
    'khaki',
    'lavender', 'lavenderblush', 'lawngreen', 'lemonchiffon', 'lightblue',
    'lightcoral', 'lightcyan', 'lightgoldenrodyellow', 'lightgray',
    'lightgreen', 'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue',
    'lightslategray', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen',
    'linen',
    'magenta', 'maroon', 'mediumaquamarine', 'mediumblue', 'mediumorchid',
    'mediumpurple', 'mediumseagreen', 'mediumslateblue', 'mediumspringgreen',
    'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream',
    'mistyrose', 'moccasin',
    'navy',
    'oldlace', 'olive', 'olivedrab', 'orange', 'orangered', 'orchid',
    'palegoldenrod', 'palegreen', 'paleturquoise', 'palevioletred',
    'papayawhip', 'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple',
    'red', 'rosybrown', 'royalblue',
    'saddlebrown', 'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna',
    'silver', 'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow',
    'springgreen', 'steelblue',
    'tan', 'teal', 'thistle', 'tomato', 'turquoise',
    'violet',
    'wheat',
    'yellow', 'yellowgreen',
]
def build_parents(tree, visit_order, node_id2plot_id):
    """Find each node's parent, the parent's plot id, and which side it hangs on.

    Args:
        tree: node table whose ``left``/``right`` columns hold child row ids
            (0 means "no child"; node 0 is the root and is never a child).
        visit_order: node ids in plotting order; element 0 must be the root.
        node_id2plot_id: maps a node id to its treemap id (its position in
            ``visit_order``).

    Returns:
        ``(parents, parent_plot_ids, directions)``: three lists parallel to
        ``visit_order`` holding the parent's node id, the parent's treemap id
        as a string, and ``'l'``/``'r'`` for a left/right child. The root's
        entry in each list is ``None``.

    Raises:
        KeyError: if a non-root node in ``visit_order`` is nobody's child.
    """
    # Invert the child pointers once (O(n)) instead of scanning both whole
    # columns for every node (O(n^2)).
    child2parent = {}
    for node_id, left, right in zip(tree.index, tree['left'], tree['right']):
        if left != 0:
            child2parent[left] = (node_id, 'l')
        if right != 0:
            child2parent[right] = (node_id, 'r')

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for i in visit_order[1:]:
        parent, direction = child2parent[i]
        parents.append(parent)
        parent_plot_ids.append(str(node_id2plot_id[parent]))
        directions.append(direction)
    return parents, parent_plot_ids, directions
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions,
                        feats=None, palette=None):
    """Build the treemap label and fill colour of every node, in plotting order.

    The root gets a fixed title and ``'white'``. Every other node is labelled
    with the split its parent made to reach it:
    ``"[<parent plot id>.L] <feature> <= <threshold>"`` for a left child and
    ``"[<parent plot id>.R] <feature> > <threshold>"`` for a right child. It is
    coloured by the feature that split used.

    Args:
        tree: node table of one tree; the ``feature_idx`` and
            ``num_threshold`` columns of each parent row are read.
        visit_order: node ids in plotting order; element 0 must be the root.
        parents, parent_plot_ids, directions: lists parallel to
            ``visit_order`` as returned by ``build_parents`` (the root's
            entry, element 0, is ignored).
        feats: feature names indexed by ``feature_idx``; defaults to FEATS.
        palette: colour names indexed by ``feature_idx``; defaults to COLORS.
            Wraps around if there are more features than colours.

    Returns:
        ``(labels, colors)``: two lists with one entry per plotted node.
    """
    feats = FEATS if feats is None else feats
    palette = COLORS if palette is None else palette

    labels = ['Histogram Gradient-Boosted Decision Tree']
    node_colors = ['white']
    # Element 0 is the root, which has no parent split to describe.
    for _node, parent, parent_plot_id, direction in zip(
        visit_order[1:], parents[1:], parent_plot_ids[1:], directions[1:]
    ):
        parent = int(parent)
        feat_idx = int(tree.loc[parent, 'feature_idx'])
        feat = feats[feat_idx]
        thresh = tree.loc[parent, 'num_threshold']
        if direction == 'l':
            labels.append(f"[{parent_plot_id}.L] {feat} <= {thresh}")
        else:
            labels.append(f"[{parent_plot_id}.R] {feat} > {thresh}")
        # Index the palette by the feature index directly; no need to
        # reverse-look-up the feature name.
        node_colors.append(palette[feat_idx % len(palette)])
    return labels, node_colors
def build_plot(tree):
    """Return a Plotly treemap trace for one tree's node table.

    Nodes are laid out in breadth-first order. Each node's treemap id is its
    position in that order, so the ids are ``'0', '1', '2', ...``. The size
    of each box comes from the node's ``count`` column.

    When ``ids`` is given, ``parents`` must be expressed in terms of those
    ids, not node ids. See
    https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    """
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {node: i for i, node in enumerate(visit_order)}
    parents, parent_plot_ids, directions = build_parents(tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions)
    plot_ids = [str(node_id2plot_id[x]) for x in visit_order]
    # values must be in the same (breadth-first) order as labels/ids/parents.
    # The table's own row order is not the traversal order (rows are not
    # stored level by level), so taking `tree['count']` as-is would give
    # nodes each other's sizes.
    values = tree.loc[visit_order, 'count'].to_numpy()
    return go.Treemap(
        values=values,
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
        # A parent's `count` already covers its children's samples, so treat
        # it as the branch total. The default 'remainder' would add it on top
        # of the children and inflate deeper subtrees.
        # NOTE(review): assumes each split partitions the parent's samples
        # (children's counts sum to the parent's). Plotly may not draw a
        # branch whose children exceed the parent's value.
        branchvalues='total',
    )
def breadth_first_traverse(tree):
    """Return the node ids of ``tree`` in breadth-first (level) order.

    Traversal starts at the root, node 0. The ``left``/``right`` columns hold
    child row ids, and 0 means "no child": node 0 is the root and can never
    be anyone's child.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    The iterative version fits here because the whole tree is a table rather
    than linked node objects.
    """
    # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
    queue = deque([0])
    visited_nodes = []
    while queue:
        cur = queue.popleft()
        visited_nodes.append(cur)
        for side in ('left', 'right'):
            # int() so callers get plain ids instead of numpy scalars.
            child = int(tree.loc[cur, side])
            if child != 0:
                queue.append(child)
    return visited_nodes
def _load_trees(path='hgb_classifier.joblib'):
    """Load the saved classifier and return one node DataFrame per boosting iteration.

    NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    only point this at a model file you trust.
    """
    hgb = joblib.load(path)
    # NOTE(review): `_predictors` is private scikit-learn API. Only the first
    # predictor (x[0]) of each entry is shown; for a multiclass model that
    # would presumably be a single class's trees. Confirm this is intended.
    return [pd.DataFrame(x[0].nodes) for x in hgb._predictors]


def _build_animation(graph_objs):
    """Return a figure that steps through every trace in `graph_objs` via a Play button.

    Plotly's Python docs barely cover the animation options; these helped:
      https://plotly.com/python/animations/#using-a-slider-and-buttons
      https://plotly.com/javascript/animations/
      https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
      https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html
      https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
        (the source of the frame/transition speed settings below)
    """
    frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
    return go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(
            updatemenus=[{
                'type': 'buttons',
                'buttons': [{
                    'label': 'Play',
                    'method': 'animate',
                    # Durations are in milliseconds per Plotly's animate docs:
                    # hold each frame 5 s, with a 2.5 s transition.
                    'args': [None, {
                        'frame': {'duration': 5000},
                        'transition': {'duration': 2500},
                    }],
                }],
            }],
        ),
    )


def main():
    """Render the Streamlit page.

    The page has a slider to browse the trees one at a time (treemap plus the
    raw node table), followed by an animated treemap of all trees.
    """
    trees = _load_trees()
    graph_objs = [build_plot(tree) for tree in trees]

    # Streamlit reruns the whole script whenever the slider moves, so the
    # loading and plot building above is redone on every change.
    # TODO: cache it, e.g. with Streamlit's resource cache on a zero-argument
    # loader, which avoids hashing DataFrame arguments altogether.
    # NOTE: the chart renders small at its default size.
    idx = st.slider(
        label='which step to show',
        min_value=0,
        max_value=len(graph_objs) - 1,
        value=0,
        step=1,
    )
    # Only the selected tree is displayed, so build just that one figure.
    st.plotly_chart(go.Figure(graph_objs[idx]))
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])

    st.plotly_chart(_build_animation(graph_objs))
# Script entry point (presumably launched with `streamlit run`).
if __name__ == "__main__":
    main()