Spaces:
Running
Running
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from streamlit_elements import elements, mui, html, dashboard, nivo | |
from streamlit_extras.switch_page_button import switch_page | |
from pages.Gallery import load_hf_dataset | |
class RankingApp:
    """Streamlit page for ranking a user's selected images by drag-and-drop.

    The sidebar picks a prompt tag and a prompt. The images for that prompt are
    shown ``batch_size`` at a time in a draggable grid. The grid order is
    mirrored into ``st.session_state.ranking`` (image_id str -> 0-based rank).
    """

    def __init__(self, promptBook, images_endpoint, batch_size=4):
        """Store the data source and initialise per-session state.

        Args:
            promptBook: DataFrame read by this class through the columns
                'tag', 'prompt', 'prompt_id', 'negativePrompt' and 'image_id'.
            images_endpoint: URL prefix; each image is fetched from
                ``images_endpoint + str(image_id) + '.png'``.
            batch_size: number of images shown per batch; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.promptBook = promptBook
        self.images_endpoint = images_endpoint
        self.batch_size = batch_size
        # Batch count over the whole prompt book; app() recomputes it for the
        # prompt that is actually displayed.
        self.batch_num = self._count_batches(len(promptBook), batch_size)
        if 'counter' not in st.session_state:
            st.session_state.counter = 0

    @staticmethod
    def _count_batches(n_items, batch_size):
        """Return how many ``batch_size`` batches cover ``n_items`` (at least 1)."""
        # Ceiling division, floored at 1 so the result is always a safe divisor.
        return max(1, -(-n_items // batch_size))

    @staticmethod
    def _layout_order(updated_layout):
        """Map each grid item id (as str) to its 0-based row-major position.

        Items are ordered by row (``y``) first, then by column (``x``).
        """
        ordered = sorted(updated_layout, key=lambda item: (item['y'], item['x']))
        return {str(item['i']): position for position, item in enumerate(ordered)}

    def sidebar(self):
        """Render the prompt selectors and read-only prompt metadata in the sidebar.

        Returns:
            tuple: ``(prompt_tags, tag, prompt_id, items)``. These are the sorted
            unique tags, the selected tag, the first unique ``prompt_id`` of the
            selected prompt, and the promptBook rows for the selected prompt
            with a fresh 0-based index.
        """
        with st.sidebar:
            prompt_tags = np.sort(self.promptBook['tag'].unique())
            tag = st.selectbox('Select a prompt tag', prompt_tags)
            items = self.promptBook[self.promptBook['tag'] == tag].reset_index(drop=True)
            # Prompts are offered in reverse sorted order.
            prompts = np.sort(items['prompt'].unique())[::-1]
            selected_prompt = st.selectbox('Select a prompt', prompts)
            items = items[items['prompt'] == selected_prompt].reset_index(drop=True)
            prompt_id = items['prompt_id'].unique()[0]
            with st.form(key='prompt_form'):
                # Read-only display of the selected prompt's metadata.
                st.text_area('Prompt', selected_prompt, height=150, key='prompt', disabled=True)
                st.text_area('Negative Prompt', items['negativePrompt'].unique()[0], height=150, key='negative_prompt', disabled=True)
                # NOTE(review): the submit result is not consumed anywhere in this class.
                st.form_submit_button('Generate Images', type='primary', use_container_width=True)
        return prompt_tags, tag, prompt_id, items

    def draggable_images(self, items, layout='portrait'):
        """Render ``items`` as a draggable grid of image cards with rank chips.

        Args:
            items: DataFrame with an 'image_id' column. Any index is accepted
                (callers pass ``iloc`` slices); rows are used in order.
            layout: 'portrait' (4 columns of 1x2 cells) or 'landscape'
                (2 columns of 2x1.4 cells).

        Raises:
            ValueError: if ``layout`` is not 'portrait' or 'landscape'.
        """
        if layout not in ('portrait', 'landscape'):
            raise ValueError(f"layout must be 'portrait' or 'landscape', got {layout!r}")
        # Iterate values positionally so a non-zero-based slice index is harmless.
        image_ids = [str(image_id) for image_id in items['image_id']]

        # Reset to display order whenever the set of shown images changes (new
        # prompt or new batch); otherwise keep the order the user dragged into.
        ranking = st.session_state.get('ranking')
        if ranking is None or set(ranking) != set(image_ids):
            st.session_state.ranking = {image_id: i for i, image_id in enumerate(image_ids)}

        with elements('dashboard'):
            if layout == 'portrait':
                col_num = 4
                grid_layout = [
                    dashboard.Item(image_id, i % col_num, i // col_num, 1, 2, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]
            else:
                col_num = 2
                grid_layout = [
                    dashboard.Item(image_id, i % col_num * 2, i // col_num, 2, 1.4, isResizable=False)
                    for i, image_id in enumerate(image_ids)
                ]
            with dashboard.Grid(grid_layout, cols={'lg': 4, 'md': 4, 'sm': 4, 'xs': 4, 'xxs': 2}, onLayoutChange=self.handle_layout_change, margin=[18, 18], containerPadding=[0, 0]):
                for image_id in image_ids:
                    with mui.Card(key=image_id, variant="outlined"):
                        rank = st.session_state.ranking[image_id] + 1
                        mui.Chip(label=rank,
                                 color="primary" if rank == 1 else "warning" if rank == 2 else "info",
                                 size="small",
                                 sx={"position": "absolute", "left": "-0.3rem", "top": "-0.3rem"})
                        img_url = self.images_endpoint + image_id + '.png'
                        mui.CardMedia(
                            component="img",
                            image=img_url,
                            alt="There should be an image",
                            # "fit" is not a valid object-fit value; "contain"
                            # letterboxes the image on the black background.
                            sx={"height": "100%", "object-fit": "contain", 'bgcolor': 'black'},
                        )

    def handle_layout_change(self, updated_layout):
        """Update ``st.session_state.ranking`` from the grid order after a drag.

        Ids in the ranking that are absent from ``updated_layout`` keep their
        current rank instead of raising.
        """
        order = self._layout_order(updated_layout)
        ranking = st.session_state.ranking
        for image_id in ranking:
            if image_id in order:
                ranking[image_id] = order[image_id]

    def app(self):
        """Render the whole ranking page: title, sidebar, current batch and controls."""
        st.title('Personal Image Ranking')
        st.write('Here you can test out your selected images with any prompt you like.')
        *_, items = self.sidebar()
        # Batches are taken over the selected prompt's images, not the whole book.
        self.batch_num = self._count_batches(len(items), self.batch_size)
        # The counter persists across reruns and prompt switches; keep it in
        # range so the slice is non-empty and the progress value stays <= 1.
        counter = min(max(st.session_state.counter, 0), self.batch_num - 1)
        st.session_state.counter = counter
        sorting, control = st.columns((11, 1), gap='large')
        with sorting:
            st.progress((counter + 1) / self.batch_num, text=f"Batch {counter + 1} / {self.batch_num}")
            start = self.batch_size * counter
            self.draggable_images(items.iloc[start:start + self.batch_size], layout='portrait')
        with control:
            st.button(":arrow_right:")
            st.button(":slightly_frowning_face:")
if __name__ == "__main__":
    st.set_page_config(page_title="Personal Image Ranking", page_icon="🎖️️", layout="wide")
    if 'user_id' not in st.session_state:
        st.warning('Please log in first.')
        home_btn = st.button('Go to Home Page')
        if home_btn:
            switch_page("home")
    else:
        # Presumably prompt_id -> checked modelVersion_ids, written by the
        # Gallery page -- confirm. It is absent until the gallery is visited.
        selected_dict = st.session_state.get('selected_dict', {})
        # Unique model-version ids across all prompts, in first-seen order.
        selected_modelVersions = []
        for value in selected_dict.values():
            for v in value:
                if v not in selected_modelVersions:
                    selected_modelVersions.append(v)
        if not selected_modelVersions:
            st.info('You have not checked any image yet. Please go back to the gallery page and check some images.')
            gallery_btn = st.button('Go to Gallery')
            if gallery_btn:
                switch_page('gallery')
        else:
            roster, promptBook, images_ds = load_hf_dataset()
            # Keep the promptBook rows whose prompt_id is a key of selected_dict
            # and whose modelVersion_id is among that key's selected versions.
            selected_parts = [
                promptBook[(promptBook['prompt_id'] == key) & (promptBook['modelVersion_id'].isin(value))]
                for key, value in selected_dict.items()
            ]
            # A single concat replaces DataFrame.append in a loop: append was
            # removed in pandas 2.0, and repeated appends copy the frame each time.
            promptBook_selected = pd.concat(selected_parts, ignore_index=True)
            if promptBook_selected.empty:
                st.info('None of your checked images were found in the prompt book. Please go back to the gallery page and check some images.')
            else:
                images_endpoint = "https://modelcofferbucket.s3-accelerate.amazonaws.com/"
                app = RankingApp(promptBook_selected, images_endpoint, batch_size=4)
                app.app()