Spaces:
Running
Running
new data cache method!
Browse files- app.py +103 -45
- requirements.txt +2 -1
app.py
CHANGED
@@ -3,7 +3,6 @@ import numpy as np
|
|
3 |
import random
|
4 |
import pandas as pd
|
5 |
import glob
|
6 |
-
import csv
|
7 |
from PIL import Image
|
8 |
import datasets
|
9 |
from datasets import load_dataset, Dataset, load_from_disk
|
@@ -13,13 +12,28 @@ import requests
|
|
13 |
from bs4 import BeautifulSoup
|
14 |
import re
|
15 |
|
|
|
|
|
|
|
16 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
class GalleryApp:
|
20 |
-
def __init__(self, promptBook):
|
21 |
self.promptBook = promptBook
|
22 |
-
|
23 |
|
24 |
def gallery_masonry(self, items, col_num, info):
|
25 |
cols = st.columns(col_num)
|
@@ -27,7 +41,7 @@ class GalleryApp:
|
|
27 |
# items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
|
28 |
for idx in range(len(items)):
|
29 |
with cols[idx % col_num]:
|
30 |
-
image =
|
31 |
st.image(image,
|
32 |
use_column_width=True,
|
33 |
)
|
@@ -58,7 +72,7 @@ class GalleryApp:
|
|
58 |
if idx + j < len(items):
|
59 |
with cols[j]:
|
60 |
# show image
|
61 |
-
image =
|
62 |
|
63 |
st.image(image,
|
64 |
use_column_width=True,
|
@@ -184,11 +198,12 @@ class GalleryApp:
|
|
184 |
with sub_selecters[2]:
|
185 |
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
|
186 |
|
187 |
-
items.loc[:, 'weighted_score_sum'] = items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
188 |
-
'norm_pop'] * pop_weight
|
189 |
|
190 |
continue_idx = 3
|
191 |
|
|
|
192 |
with sub_selecters[continue_idx]:
|
193 |
order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
|
194 |
if order == 'Ascending':
|
@@ -211,6 +226,15 @@ class GalleryApp:
|
|
211 |
items = items[items['checked'] == True].reset_index(drop=True)
|
212 |
print(items)
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
info = st.multiselect('Show Info',
|
215 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
216 |
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
@@ -303,6 +327,7 @@ class GalleryApp:
|
|
303 |
|
304 |
if safety_check:
|
305 |
items, info, col_num = self.selection_panel_2(items)
|
|
|
306 |
# self.gallery_standard(items, col_num, info)
|
307 |
|
308 |
with st.form(key=f'{prompt_id}', clear_on_submit=False):
|
@@ -340,44 +365,77 @@ class GalleryApp:
|
|
340 |
dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
|
341 |
|
342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
if __name__ == '__main__':
|
344 |
login(token=os.environ.get("HF_TOKEN"))
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
#
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
#
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
|
|
|
|
|
|
|
|
383 |
app.app()
|
|
|
3 |
import random
|
4 |
import pandas as pd
|
5 |
import glob
|
|
|
6 |
from PIL import Image
|
7 |
import datasets
|
8 |
from datasets import load_dataset, Dataset, load_from_disk
|
|
|
12 |
from bs4 import BeautifulSoup
|
13 |
import re
|
14 |
|
15 |
+
import altair as alt
|
16 |
+
from streamlit_vega_lite import vega_lite_component, altair_component, _component_func
|
17 |
+
|
18 |
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
|
19 |
|
20 |
|
21 |
+
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
22 |
+
@st.cache_resource
def altair_histogram(hist_data, sort_by):
    """Build a binned bar histogram of one column with an x-axis brush.

    Args:
        hist_data: Data handed to ``alt.Chart``; it must expose a field
            named ``sort_by``.
        sort_by: Name of the field to bin along the x axis. It is encoded
            as a quantitative (``:Q``) channel.

    Returns:
        The Altair chart (600x300) with an interval selection named
        ``"brushed"`` attached.
    """
    # Interval selection limited to the x encoding, so dragging picks a
    # range of the binned value rather than a 2-D rectangle.
    x_brush = alt.selection_interval(encodings=['x'], name="brushed")

    binned_x = alt.X(f"{sort_by}:Q", bin=True)

    histogram = alt.Chart(hist_data).mark_bar()
    histogram = histogram.encode(binned_x, y="count()")
    histogram = histogram.add_selection(x_brush)
    return histogram.properties(width=600, height=300)
|
32 |
+
|
33 |
class GalleryApp:
|
34 |
+
def __init__(self, promptBook, images_ds):
|
35 |
self.promptBook = promptBook
|
36 |
+
self.images_ds = images_ds
|
37 |
|
38 |
def gallery_masonry(self, items, col_num, info):
|
39 |
cols = st.columns(col_num)
|
|
|
41 |
# items = items.sort_values(by=['brisque'], ascending=True).reset_index(drop=True)
|
42 |
for idx in range(len(items)):
|
43 |
with cols[idx % col_num]:
|
44 |
+
image = self.images_ds[items.iloc[idx]['row_idx'].item()]['image']
|
45 |
st.image(image,
|
46 |
use_column_width=True,
|
47 |
)
|
|
|
72 |
if idx + j < len(items):
|
73 |
with cols[j]:
|
74 |
# show image
|
75 |
+
image = self.images_ds[items.iloc[idx+j]['row_idx'].item()]['image']
|
76 |
|
77 |
st.image(image,
|
78 |
use_column_width=True,
|
|
|
198 |
with sub_selecters[2]:
|
199 |
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=1.0, step=0.1)
|
200 |
|
201 |
+
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
|
202 |
+
'norm_pop'] * pop_weight, 4)
|
203 |
|
204 |
continue_idx = 3
|
205 |
|
206 |
+
|
207 |
with sub_selecters[continue_idx]:
|
208 |
order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
|
209 |
if order == 'Ascending':
|
|
|
226 |
items = items[items['checked'] == True].reset_index(drop=True)
|
227 |
print(items)
|
228 |
|
229 |
+
if sort_type == 'Scores':
|
230 |
+
st.write('Select the range of scores to show')
|
231 |
+
hist_data = pd.DataFrame(items[sort_by])
|
232 |
+
event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
|
233 |
+
r = event_dict.get(sort_by)
|
234 |
+
if r:
|
235 |
+
items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
|
236 |
+
st.write(r)
|
237 |
+
|
238 |
info = st.multiselect('Show Info',
|
239 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
240 |
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
|
|
|
327 |
|
328 |
if safety_check:
|
329 |
items, info, col_num = self.selection_panel_2(items)
|
330 |
+
|
331 |
# self.gallery_standard(items, col_num, info)
|
332 |
|
333 |
with st.form(key=f'{prompt_id}', clear_on_submit=False):
|
|
|
365 |
dataset.push_to_hub('NYUSHPRP/ModelCofferMetadata', split='train')
|
366 |
|
367 |
|
368 |
+
@st.cache_data
def load_hf_dataset():
    """Load the roster, the prompt-book metadata and the image dataset.

    The roster and metadata are fetched with ``load_dataset`` (``train``
    split) and converted to DataFrames; the images are read with
    ``load_from_disk`` from ``<cwd>/data/promptbook``. The roster's model
    columns are then left-joined onto the prompt book on
    ``(model_id, modelVersion_id)``.

    Returns:
        tuple: ``(roster, promptBook, images_ds)`` where ``roster`` is the
        de-duplicated roster DataFrame, ``promptBook`` is the merged
        DataFrame (with ``checked``, ``weighted_score_sum`` and ``row_idx``
        columns guaranteed to exist) and ``images_ds`` is the on-disk
        dataset.
    """
    # load from huggingface
    roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
    promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
    images_ds = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))

    # keep only the model columns needed downstream and drop exact duplicates
    roster = roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
                     'model_download_count']].drop_duplicates().reset_index(drop=True)

    # add 'checked' column to promptBook if it does not exist
    if 'checked' not in promptBook.columns:
        promptBook.loc[:, 'checked'] = False

    # add 'weighted_score_sum' column to promptBook if it does not exist
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # Remember each row's position BEFORE merging. The gallery uses
    # 'row_idx' to index images_ds, and a left merge renumbers the index
    # (and repeats rows when the roster holds several entries for one
    # (model_id, modelVersion_id) pair), so the post-merge index is not a
    # reliable pointer into images_ds.
    source_row_col = '__source_row__'
    promptBook.loc[:, source_row_col] = promptBook.index

    # merge roster and promptbook (roster already holds exactly the columns to join)
    promptBook = promptBook.merge(roster, on=['model_id', 'modelVersion_id'], how='left')

    # add column to record the original row index, then drop the temporary one
    promptBook.loc[:, 'row_idx'] = promptBook.pop(source_row_col)

    return roster, promptBook, images_ds
|
395 |
+
|
396 |
+
|
397 |
if __name__ == '__main__':
    # Authenticate with the Hugging Face Hub; the token is read from the
    # HF_TOKEN environment variable.
    login(token=os.environ.get("HF_TOKEN"))
    st.set_page_config(layout="wide")

    # All dataset loading and preprocessing lives in load_hf_dataset(), which
    # is wrapped in st.cache_data; this replaces the earlier per-session
    # st.session_state loading code.
    roster, promptBook, images_ds = load_hf_dataset()

    app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
    app.app()
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ huggingface_hub
|
|
2 |
streamlit-elements==0.1.0
|
3 |
streamlit-extras
|
4 |
altair<5
|
5 |
-
streamlit-plotly-events
|
|
|
|
2 |
streamlit-elements==0.1.0
|
3 |
streamlit-extras
|
4 |
altair<5
|
5 |
+
streamlit-plotly-events
|
6 |
+
streamlit-vega-lite
|