Spaces:
Sleeping
Sleeping
update bar chart
Browse files- app.py +46 -52
- test_altair.py +22 -47
app.py
CHANGED
|
@@ -20,16 +20,38 @@ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_do
|
|
| 20 |
|
| 21 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
| 22 |
@st.cache_resource
|
| 23 |
-
def altair_histogram(hist_data, sort_by):
|
| 24 |
brushed = alt.selection_interval(encodings=['x'], name="brushed")
|
| 25 |
-
|
|
|
|
| 26 |
alt.Chart(hist_data)
|
| 27 |
-
.mark_bar()
|
| 28 |
-
.encode(alt.X(f"{sort_by}:Q", bin=
|
| 29 |
-
.add_selection(brushed)
|
| 30 |
-
.properties(width=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class GalleryApp:
|
| 34 |
def __init__(self, promptBook, images_ds):
|
| 35 |
self.promptBook = promptBook
|
|
@@ -169,7 +191,6 @@ class GalleryApp:
|
|
| 169 |
|
| 170 |
return items, info, col_num
|
| 171 |
|
| 172 |
-
|
| 173 |
def selection_panel_2(self, items):
|
| 174 |
selecters = st.columns([1, 5])
|
| 175 |
|
|
@@ -226,14 +247,25 @@ class GalleryApp:
|
|
| 226 |
items = items[items['checked'] == True].reset_index(drop=True)
|
| 227 |
print(items)
|
| 228 |
|
|
|
|
| 229 |
if sort_type == 'Scores':
|
| 230 |
-
st.
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
info = st.multiselect('Show Info',
|
| 239 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
|
@@ -308,7 +340,6 @@ class GalleryApp:
|
|
| 308 |
except:
|
| 309 |
pass
|
| 310 |
|
| 311 |
-
|
| 312 |
# add safety check for some prompts
|
| 313 |
safety_check = True
|
| 314 |
unsafe_prompts = {}
|
|
@@ -398,44 +429,7 @@ if __name__ == '__main__':
|
|
| 398 |
login(token=os.environ.get("HF_TOKEN"))
|
| 399 |
st.set_page_config(layout="wide")
|
| 400 |
|
| 401 |
-
# if 'roster' not in st.session_state:
|
| 402 |
-
# print('loading roster')
|
| 403 |
-
# # st.session_state.roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
|
| 404 |
-
# st.session_state.roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
|
| 405 |
-
# st.session_state.roster = st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
|
| 406 |
-
# 'model_download_count']].drop_duplicates().reset_index(drop=True)
|
| 407 |
-
# # add model download count from roster to promptbook dataframe
|
| 408 |
-
# if 'promptBook' not in st.session_state:
|
| 409 |
-
# print('loading promptBook')
|
| 410 |
-
#
|
| 411 |
-
# st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
|
| 412 |
-
# # add 'checked' column to promptBook if not exist
|
| 413 |
-
# if 'checked' not in st.session_state.promptBook.columns:
|
| 414 |
-
# st.session_state.promptBook.loc[:, 'checked'] = False
|
| 415 |
-
#
|
| 416 |
-
# # add 'custom_score_weights' column to promptBook if not exist
|
| 417 |
-
# if 'weighted_score_sum' not in st.session_state.promptBook.columns:
|
| 418 |
-
# st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
|
| 419 |
-
#
|
| 420 |
-
# st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
| 421 |
-
# # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
|
| 422 |
-
# print(st.session_state.images)
|
| 423 |
-
# print('images loaded')
|
| 424 |
-
# # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train'))
|
| 425 |
-
# st.session_state.promptBook = st.session_state.promptBook.merge(st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']], on=['model_id', 'modelVersion_id'], how='left')
|
| 426 |
-
#
|
| 427 |
-
# # add column to record current row index
|
| 428 |
-
# st.session_state.promptBook['row_idx'] = st.session_state.promptBook.index
|
| 429 |
-
# print('promptBook loaded')
|
| 430 |
-
# # print(st.session_state.promptBook)
|
| 431 |
-
#
|
| 432 |
-
# check_roster_error = False
|
| 433 |
-
# if check_roster_error:
|
| 434 |
-
# # print all rows with the same model_id and modelVersion_id but different model_download_count in roster
|
| 435 |
-
# print(st.session_state.roster[st.session_state.roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)].sort_values(by=['model_id', 'modelVersion_id']))
|
| 436 |
roster, promptBook, images_ds = load_hf_dataset()
|
| 437 |
-
# if 'images' not in st.session_state:
|
| 438 |
-
# st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
| 439 |
|
| 440 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
| 441 |
app.app()
|
|
|
|
| 20 |
|
| 21 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
| 22 |
@st.cache_resource
|
| 23 |
+
def altair_histogram(hist_data, sort_by, mini, maxi):
|
| 24 |
brushed = alt.selection_interval(encodings=['x'], name="brushed")
|
| 25 |
+
|
| 26 |
+
chart = (
|
| 27 |
alt.Chart(hist_data)
|
| 28 |
+
.mark_bar(opacity=0.7, cornerRadius=2)
|
| 29 |
+
.encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
|
| 30 |
+
# .add_selection(brushed)
|
| 31 |
+
# .properties(width=800, height=300)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Create a transparent rectangle for highlighting the range
|
| 35 |
+
highlight = (
|
| 36 |
+
alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
|
| 37 |
+
.mark_rect(opacity=0.3)
|
| 38 |
+
.encode(x='x1', x2='x2')
|
| 39 |
+
# .properties(width=800, height=300)
|
| 40 |
)
|
| 41 |
|
| 42 |
+
# Layer the chart and the highlight rectangle
|
| 43 |
+
layered_chart = alt.layer(chart, highlight)
|
| 44 |
+
|
| 45 |
+
return layered_chart
|
| 46 |
+
|
| 47 |
+
# return (
|
| 48 |
+
# alt.Chart(hist_data)
|
| 49 |
+
# .mark_bar()
|
| 50 |
+
# .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
|
| 51 |
+
# .add_selection(brushed)
|
| 52 |
+
# .properties(width=600, height=300)
|
| 53 |
+
# )
|
| 54 |
+
|
| 55 |
class GalleryApp:
|
| 56 |
def __init__(self, promptBook, images_ds):
|
| 57 |
self.promptBook = promptBook
|
|
|
|
| 191 |
|
| 192 |
return items, info, col_num
|
| 193 |
|
|
|
|
| 194 |
def selection_panel_2(self, items):
|
| 195 |
selecters = st.columns([1, 5])
|
| 196 |
|
|
|
|
| 247 |
items = items[items['checked'] == True].reset_index(drop=True)
|
| 248 |
print(items)
|
| 249 |
|
| 250 |
+
# draw a distribution histogram
|
| 251 |
if sort_type == 'Scores':
|
| 252 |
+
with st.expander('Show score distribution histogram and select score range'):
|
| 253 |
+
st.write('**Score distribution histogram**')
|
| 254 |
+
chart_space = st.container()
|
| 255 |
+
# st.write('Select the range of scores to show')
|
| 256 |
+
hist_data = pd.DataFrame(items[sort_by])
|
| 257 |
+
mini = hist_data[sort_by].min().item()
|
| 258 |
+
maxi = hist_data[sort_by].max().item()
|
| 259 |
+
st.write('**Select the range of scores to show**')
|
| 260 |
+
r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), label_visibility='collapsed')
|
| 261 |
+
with chart_space:
|
| 262 |
+
st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
|
| 263 |
+
# event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
|
| 264 |
+
# r = event_dict.get(sort_by)
|
| 265 |
+
if r:
|
| 266 |
+
items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
|
| 267 |
+
# st.write(r)
|
| 268 |
+
|
| 269 |
|
| 270 |
info = st.multiselect('Show Info',
|
| 271 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
|
|
|
| 340 |
except:
|
| 341 |
pass
|
| 342 |
|
|
|
|
| 343 |
# add safety check for some prompts
|
| 344 |
safety_check = True
|
| 345 |
unsafe_prompts = {}
|
|
|
|
| 429 |
login(token=os.environ.get("HF_TOKEN"))
|
| 430 |
st.set_page_config(layout="wide")
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
roster, promptBook, images_ds = load_hf_dataset()
|
|
|
|
|
|
|
| 433 |
|
| 434 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
| 435 |
app.app()
|
test_altair.py
CHANGED
|
@@ -1,50 +1,25 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
import streamlit as st
|
|
|
|
| 3 |
import pandas as pd
|
| 4 |
-
import numpy as np
|
| 5 |
-
|
| 6 |
-
from streamlit_vega_lite import vega_lite_component, altair_component, _component_func
|
| 7 |
-
|
| 8 |
-
hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["abc"])
|
| 9 |
-
print(hist_data)
|
| 10 |
-
|
| 11 |
-
@st.cache_resource
|
| 12 |
-
def altair_histogram():
|
| 13 |
-
brushed = alt.selection_interval(encodings=["x"], name="brushed")
|
| 14 |
-
|
| 15 |
-
return (
|
| 16 |
-
alt.Chart(hist_data)
|
| 17 |
-
.mark_bar()
|
| 18 |
-
.encode(alt.X("abc:Q", bin=True), y="count()")
|
| 19 |
-
.add_selection(brushed)
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
chart = altair_histogram()
|
| 23 |
-
res = st.altair_chart(chart, use_container_width=True)
|
| 24 |
-
# print(res)
|
| 25 |
-
event_dict = altair_component(altair_chart=altair_histogram())
|
| 26 |
-
chart_dict = chart.to_dict()
|
| 27 |
-
print(chart_dict)
|
| 28 |
-
altair_chart = chart.copy()
|
| 29 |
-
datasets = {}
|
| 30 |
-
|
| 31 |
-
def id_transform(data):
|
| 32 |
-
"""Altair data transformer that returns a fake named dataset with the
|
| 33 |
-
object id."""
|
| 34 |
-
name = f"d{id(data)}"
|
| 35 |
-
datasets[name] = data
|
| 36 |
-
return {"name": name}
|
| 37 |
-
|
| 38 |
-
alt.data_transformers.register("id", id_transform)
|
| 39 |
-
|
| 40 |
-
with alt.data_transformers.enable("id"):
|
| 41 |
-
chart_dict = altair_chart.to_dict()
|
| 42 |
-
# st.write(event_dict)
|
| 43 |
-
|
| 44 |
-
event_dict = _component_func(spec=chart_dict, **datasets, key=None, default={})
|
| 45 |
-
# print(chart_dict)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import altair as alt
|
| 3 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
# Generate random data for the chart
|
| 6 |
+
data = pd.DataFrame({
|
| 7 |
+
'Category': ['A', 'B', 'C', 'D', 'E'],
|
| 8 |
+
'Value': [0.2, 0.5, 0.8, 1.2, 1.5]
|
| 9 |
+
})
|
| 10 |
+
|
| 11 |
+
# Define the color scale for the bars
|
| 12 |
+
color_scale = alt.Scale(
|
| 13 |
+
domain=[0, 1], # Values between 0 and 1 will be blue
|
| 14 |
+
range=['steelblue', 'lightgray']
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Create the bar chart using Altair
|
| 18 |
+
chart = alt.Chart(data).mark_bar().encode(
|
| 19 |
+
x='Category',
|
| 20 |
+
y='Value',
|
| 21 |
+
color=alt.Color('Value', scale=color_scale)
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Render the chart using Streamlit
|
| 25 |
+
st.altair_chart(chart, use_container_width=True)
|