# GEMRec-Gallery / pages / 1_🖼️_Gallery.py
# Last commit bca2bcb by Ricercar: "new version! multiple pages!" (15.7 kB)
import streamlit as st
import numpy as np
import pandas as pd
import glob
from datasets import load_dataset, Dataset, load_from_disk
from huggingface_hub import login
import os
import requests
from bs4 import BeautifulSoup
import altair as alt
from streamlit_extras.switch_page_button import switch_page
# Short score aliases mapped to longer score names. The values match names
# offered in the gallery's 'Show Info' options ('clip_score', 'avg_rank',
# 'model_download_count').
# NOTE(review): not referenced by the code visible in this file - confirm
# whether other pages import it before removing.
SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_download_count'}
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
@st.cache_resource
def altair_histogram(hist_data, sort_by, mini, maxi):
    """Build a binned histogram of one column with a highlighted x-range.

    Args:
        hist_data: data handed to ``alt.Chart``; it must contain the
            ``sort_by`` column.
        sort_by: name of the column plotted on the x axis. It is encoded as
            quantitative (``:Q``) and binned into at most 25 bins.
        mini: left edge of the highlighted band.
        maxi: right edge of the highlighted band.

    Returns:
        A layered Altair chart: the bar histogram with a translucent
        rectangle spanning ``mini``..``maxi`` drawn on top of it.
    """
    histogram = (
        alt.Chart(hist_data)
        .mark_bar(opacity=0.7, cornerRadius=2)
        .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=25)), y="count()")
    )

    # A one-row frame carrying the band edges is enough to draw the
    # transparent rectangle that marks the currently selected range.
    highlight = (
        alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
        .mark_rect(opacity=0.3)
        .encode(x='x1', x2='x2')
    )

    # Layer the highlight rectangle over the histogram.
    return alt.layer(histogram, highlight)
class GalleryApp:
def __init__(self, promptBook, images_ds):
self.promptBook = promptBook
self.images_ds = images_ds
def gallery_standard(self, items, col_num, info):
rows = len(items) // col_num + 1
containers = [st.container() for _ in range(rows)]
for idx in range(0, len(items), col_num):
row_idx = idx // col_num
with containers[row_idx]:
cols = st.columns(col_num)
for j in range(col_num):
if idx + j < len(items):
with cols[j]:
# show image
image = self.images_ds[items.iloc[idx + j]['row_idx'].item()]['image']
st.image(image, use_column_width=True)
# handel checkbox information
prompt_id = items.iloc[idx + j]['prompt_id']
modelVersion_id = items.iloc[idx + j]['modelVersion_id']
check_init = True if modelVersion_id in st.session_state.selected_dict.get(prompt_id, []) else False
# show checkbox
checked = st.checkbox('Select', key=f'select_{idx + j}', value=check_init)
if checked:
st.session_state.selected_dict[prompt_id] = st.session_state.selected_dict.get(prompt_id, []) + [modelVersion_id]
else:
try:
st.session_state.selected_dict[prompt_id].remove(modelVersion_id)
except:
pass
# show selected info
for key in info:
st.write(f"**{key}**: {items.iloc[idx + j][key]}")
def selection_panel(self, items):
selecters = st.columns([1, 4])
# select sort type
with selecters[0]:
sort_type = st.selectbox('Sort by', ['Scores', 'IDs and Names'])
if sort_type == 'Scores':
sort_by = 'weighted_score_sum'
# select other options
with selecters[1]:
if sort_type == 'IDs and Names':
sub_selecters = st.columns([3, 1])
# select sort by
with sub_selecters[0]:
sort_by = st.selectbox('Sort by',
['model_name', 'model_id', 'modelVersion_name', 'modelVersion_id'],
label_visibility='hidden')
continue_idx = 1
else:
# add custom weights
sub_selecters = st.columns([1, 1, 1, 1])
if 'score_weights' not in st.session_state:
st.session_state.score_weights = [1.0, 0.8, 0.2, 0.84]
with sub_selecters[0]:
clip_weight = st.number_input('Clip Score Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[0], step=0.1, help='the weight for normalized clip score')
with sub_selecters[1]:
rank_weight = st.number_input('Distinctiveness Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[1], step=0.1, help='the weight for average rank')
with sub_selecters[2]:
pop_weight = st.number_input('Popularity Weight', min_value=-100.0, max_value=100.0, value=st.session_state.score_weights[2], step=0.1, help='the weight for normalized popularity score')
items.loc[:, 'weighted_score_sum'] = round(items['norm_clip'] * clip_weight + items['avg_rank'] * rank_weight + items[
'norm_pop'] * pop_weight, 4)
continue_idx = 3
# select threshold
with sub_selecters[continue_idx]:
dist_threshold = st.number_input('Distinctiveness Threshold', min_value=0.0, max_value=1.0, value=st.session_state.score_weights[3], step=0.01, help='Only show models with distinctiveness score lower than this threshold, set 1.0 to show all images')
items = items[items['avg_rank'] < dist_threshold].reset_index(drop=True)
# save latest weights
st.session_state.score_weights = [clip_weight, rank_weight, pop_weight, dist_threshold]
# draw a distribution histogram
if sort_type == 'Scores':
try:
with st.expander('Show score distribution histogram and select score range'):
st.write('**Score distribution histogram**')
chart_space = st.container()
# st.write('Select the range of scores to show')
hist_data = pd.DataFrame(items[sort_by])
mini = hist_data[sort_by].min().item()
mini = mini//0.1 * 0.1
maxi = hist_data[sort_by].max().item()
maxi = maxi//0.1 * 0.1 + 0.1
st.write('**Select the range of scores to show**')
r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), step=0.05, label_visibility='collapsed')
with chart_space:
st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
# event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
# r = event_dict.get(sort_by)
if r:
items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
# st.write(r)
except:
pass
display_options = st.columns([1, 4])
with display_options[0]:
# select order
order = st.selectbox('Order', ['Ascending', 'Descending'], index=1 if sort_type == 'Scores' else 0)
if order == 'Ascending':
order = True
else:
order = False
with display_options[1]:
# select info to show
info = st.multiselect('Show Info',
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
'modelVersion_name', 'modelVersion_id', 'clip+rank', 'clip+pop', 'rank+pop',
'clip+rank+pop', 'weighted_score_sum'],
default=sort_by)
# apply sorting to dataframe
items = items.sort_values(by=[sort_by], ascending=order).reset_index(drop=True)
# select number of columns
col_num = st.slider('Number of columns', min_value=1, max_value=9, value=4, step=1, key='col_num')
return items, info, col_num
def sidebar(self):
with st.sidebar:
prompt_tags = self.promptBook['tag'].unique()
# sort tags by alphabetical order
prompt_tags = np.sort(prompt_tags)[::-1]
tag = st.selectbox('Select a tag', prompt_tags)
items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
original_prompts = np.sort(items['prompt'].unique())[::-1]
# remove the first four items in the prompt, which are mostly the same
if tag != 'abstract':
prompts = [', '.join(x.split(', ')[4:]) for x in original_prompts]
prompt = st.selectbox('Select prompt', prompts)
idx = prompts.index(prompt)
prompt_full = ', '.join(original_prompts[idx].split(', ')[:4]) + ', ' + prompt
else:
prompt_full = st.selectbox('Select prompt', original_prompts)
items = items[items['prompt'] == prompt_full].reset_index(drop=True)
prompt_id = items['prompt_id'].unique()[0]
# show image metadata
image_metadatas = ['prompt_id', 'prompt', 'negativePrompt', 'sampler', 'cfgScale', 'size', 'seed']
for key in image_metadatas:
label = ' '.join(key.split('_')).capitalize()
st.write(f"**{label}**")
if items[key][0] == ' ':
st.write('`None`')
else:
st.caption(f"{items[key][0]}")
# for tag as civitai, add civitai reference
if tag == 'civitai':
try:
st.write('**Civitai Reference**')
res = requests.get(f'https://civitai.com/images/{prompt_id.item()}')
# st.write(res.text)
soup = BeautifulSoup(res.text, 'html.parser')
image_section = soup.find('div', {'class': 'mantine-12rlksp'})
image_url = image_section.find('img')['src']
st.image(image_url, use_column_width=True)
except:
pass
return prompt_tags, tag, prompt_id, items
def app(self):
st.title('Model Visualization and Retrieval')
st.write('This is a gallery of images generated by the models')
prompt_tags, tag, prompt_id, items = self.sidebar()
# add safety check for some prompts
safety_check = True
unsafe_prompts = {}
# initialize unsafe prompts
for prompt_tag in prompt_tags:
unsafe_prompts[prompt_tag] = []
# manually add unsafe prompts
unsafe_prompts['civitai'] = [375790, 366222, 295008, 256477]
unsafe_prompts['people'] = [53]
unsafe_prompts['art'] = [23]
unsafe_prompts['abstract'] = [10, 12]
unsafe_prompts['food'] = [34]
if int(prompt_id.item()) in unsafe_prompts[tag]:
st.warning('This prompt may contain unsafe content. They might be offensive, depressing, or sexual.')
safety_check = st.checkbox('I understand that this prompt may contain unsafe content. Show these images anyway.', key=f'{prompt_id}')
if safety_check:
items, info, col_num = self.selection_panel(items)
# self.gallery_standard(items, col_num, info)
with st.form(key=f'{prompt_id}'):
# buttons = st.columns([1, 1, 1])
buttons_space = st.columns([1, 1, 1, 1])
gallery_space = st.empty()
with buttons_space[0]:
continue_btn = st.form_submit_button('Confirm Selection', use_container_width=True, type='primary')
if continue_btn:
self.submit_actions('Continue', prompt_id)
with buttons_space[1]:
select_btn = st.form_submit_button('Select All', use_container_width=True)
if select_btn:
self.submit_actions('Select', prompt_id)
with buttons_space[2]:
deselect_btn = st.form_submit_button('Deselect All', use_container_width=True)
if deselect_btn:
self.submit_actions('Deselect', prompt_id)
with buttons_space[3]:
refresh_btn = st.form_submit_button('Refresh', on_click=gallery_space.empty, use_container_width=True)
with gallery_space.container():
with st.spinner('Loading images...'):
self.gallery_standard(items, col_num, info)
def submit_actions(self, status, prompt_id):
if status == 'Select':
modelVersions = self.promptBook[self.promptBook['prompt_id'] == prompt_id]['modelVersion_id'].unique()
st.session_state.selected_dict[prompt_id] = modelVersions.tolist()
print(st.session_state.selected_dict, 'select')
elif status == 'Deselect':
st.session_state.selected_dict[prompt_id] = []
print(st.session_state.selected_dict, 'deselect')
# self.promptBook.loc[self.promptBook['prompt_id'] == prompt_id, 'checked'] = False
pass
elif status == 'Continue':
# switch_page("ranking")
pass
@st.cache_data
def load_hf_dataset():
    """Load the model roster, the prompt metadata and the local image dataset.

    Logs in to the Hugging Face Hub with the ``HF_TOKEN`` environment
    variable, downloads the 'train' split of the roster and metadata
    datasets, and reads the image dataset from ``./data/promptbook`` relative
    to the current working directory.

    Returns:
        ``(roster, promptBook, images_ds)`` where ``roster`` is the
        de-duplicated model table, ``promptBook`` is the metadata joined with
        the roster (plus 'weighted_score_sum' and 'row_idx' columns) and
        ``images_ds`` is the dataset loaded from disk.
    """
    roster_columns = ['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
                      'model_download_count']
    join_keys = ['model_id', 'modelVersion_id']

    # login to huggingface
    hf_token = os.environ.get("HF_TOKEN")
    login(token=hf_token)

    # load from huggingface / local disk
    roster_raw = load_dataset('NYUSHPRP/ModelCofferRoster', split='train')
    roster = pd.DataFrame(roster_raw)
    metadata_raw = load_dataset('NYUSHPRP/ModelCofferMetadata', split='train')
    promptBook = pd.DataFrame(metadata_raw)
    images_path = os.path.join(os.getcwd(), 'data', 'promptbook')
    images_ds = load_from_disk(images_path)

    # keep one row per distinct combination of the roster columns
    roster = roster[roster_columns].drop_duplicates().reset_index(drop=True)

    # make sure the column used for custom weighted scores exists
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # attach model names and download counts to every prompt row;
    # roster already holds exactly the roster columns at this point
    promptBook = promptBook.merge(roster, on=join_keys, how='left')

    # record each row's position after the merge
    promptBook.loc[:, 'row_idx'] = promptBook.index

    return roster, promptBook, images_ds
if __name__ == "__main__":
    # Page entry point: configure the page, then either ask the visitor to
    # log in or load the data and render the gallery.
    st.set_page_config(page_title="Model Coffer Gallery", page_icon="🖼️", layout="wide")

    logged_in = 'user_id' in st.session_state
    if logged_in:
        st.write('You have already logged in as ' + st.session_state.user_id[0])

        # the image dataset is kept in the session state
        roster, promptBook, images_ds = load_hf_dataset()
        st.session_state["images_ds"] = images_ds

        # initialize selected_dict (prompt id -> selected model version ids)
        if 'selected_dict' not in st.session_state:
            st.session_state['selected_dict'] = {}

        gallery = GalleryApp(promptBook=promptBook, images_ds=st.session_state.images_ds)
        gallery.app()
    else:
        st.warning('Please log in first.')
        if st.button('Go to Home Page'):
            switch_page("home")