Spaces:
Runtime error
Runtime error
Javi
commited on
Commit
·
b0cb25e
1
Parent(s):
de38ce1
Both classification and prompt ranking working
Browse files- streamlit_app.py +110 -67
streamlit_app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from PIL import Image
|
2 |
import streamlit as st
|
3 |
import booste
|
@@ -9,76 +11,117 @@ from session_state import SessionState, get_state
|
|
9 |
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
10 |
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
if task_name == "Image classification":
|
21 |
-
|
22 |
-
|
23 |
-
accept_multiple_files=False)
|
24 |
st.markdown("or choose one from")
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
col1, col2 = st.beta_columns([2, 1])
|
48 |
-
with col1:
|
49 |
-
st.markdown("Image to classify")
|
50 |
-
if session_state.image is not None:
|
51 |
-
st.image(session_state.image, use_column_width=True)
|
52 |
-
else:
|
53 |
-
st.warning("Select an image")
|
54 |
-
|
55 |
-
with col2:
|
56 |
-
st.markdown("Classes to choose from")
|
57 |
-
if session_state.processed_classes is not None:
|
58 |
-
for class_name in session_state.processed_classes:
|
59 |
-
st.write(class_name)
|
60 |
-
else:
|
61 |
-
st.warning("Enter the classes to classify from")
|
62 |
-
|
63 |
-
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
64 |
-
if st.button("Predict"):
|
65 |
-
with st.spinner("Predicting..."):
|
66 |
-
clip_response = booste.clip(BOOSTE_API_KEY,
|
67 |
-
prompts=input_prompts,
|
68 |
-
images=[session_state.image],
|
69 |
-
pretty_print=True)
|
70 |
-
st.markdown("### Results")
|
71 |
-
simplified_clip_results = [(prompt[len('A picture of a '):],
|
72 |
-
list(results.values())[0]["probabilityRelativeToPrompts"])
|
73 |
-
for prompt, results in clip_response.items()]
|
74 |
-
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
75 |
-
max_class_name_length = max(len(class_name) for class_name, _ in simplified_clip_results)
|
76 |
-
|
77 |
-
for prompt, probability in simplified_clip_results:
|
78 |
-
progress_bar = "".join([":large_blue_circle:"] * int(probability * 10) +
|
79 |
-
[":black_circle:"] * int((1 - probability) * 10))
|
80 |
-
st.markdown(f"### {prompt}: {progress_bar} {probability:.3f}")
|
81 |
-
st.write(clip_response)
|
82 |
|
83 |
session_state.sync()
|
84 |
|
|
|
1 |
+
from typing import Optional, List
|
2 |
+
|
3 |
from PIL import Image
|
4 |
import streamlit as st
|
5 |
import booste
|
|
|
11 |
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
12 |
|
13 |
|
14 |
+
class Sections:
|
15 |
+
@staticmethod
|
16 |
+
def header():
|
17 |
+
st.markdown("# CLIP playground")
|
18 |
+
st.markdown("### Try OpenAI's CLIP model in your browser")
|
19 |
+
st.markdown(" ");
|
20 |
+
st.markdown(" ")
|
21 |
+
with st.beta_expander("What is CLIP?"):
|
22 |
+
st.markdown("Nice CLIP explaination")
|
23 |
+
st.markdown(" ");
|
24 |
+
st.markdown(" ")
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def image_uploader(accept_multiple_files: bool) -> Optional[List[str]]:
|
28 |
+
uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
|
29 |
+
accept_multiple_files=accept_multiple_files)
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def image_picker(state: SessionState):
|
33 |
+
col1, col2, col3 = st.beta_columns(3)
|
34 |
+
with col1:
|
35 |
+
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
36 |
+
st.image(default_image_1, use_column_width=True)
|
37 |
+
if st.button("Select image 1"):
|
38 |
+
state.image = default_image_1
|
39 |
+
with col2:
|
40 |
+
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
41 |
+
st.image(default_image_2, use_column_width=True)
|
42 |
+
if st.button("Select image 2"):
|
43 |
+
state.image = default_image_2
|
44 |
+
with col3:
|
45 |
+
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
46 |
+
st.image(default_image_3, use_column_width=True)
|
47 |
+
if st.button("Select image 3"):
|
48 |
+
state.image = default_image_3
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
|
52 |
+
raw_classes = st.text_input(input_label)
|
53 |
+
if raw_classes:
|
54 |
+
state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
|
55 |
+
state.prompt_prefix = prompt_prefix
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def input_preview(state: SessionState):
|
59 |
+
col1, col2 = st.beta_columns([2, 1])
|
60 |
+
with col1:
|
61 |
+
st.markdown("Image to classify")
|
62 |
+
if state.image is not None:
|
63 |
+
st.image(state.image, use_column_width=True)
|
64 |
+
else:
|
65 |
+
st.warning("Select an image")
|
66 |
+
|
67 |
+
with col2:
|
68 |
+
st.markdown("Labels to choose from")
|
69 |
+
if state.processed_classes is not None:
|
70 |
+
for prompt in state.prompts:
|
71 |
+
st.write(prompt[len(state.prompt_prefix):])
|
72 |
+
else:
|
73 |
+
st.warning("Enter the classes to classify from")
|
74 |
|
75 |
+
@staticmethod
|
76 |
+
def classification_output(state: SessionState):
|
77 |
+
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
78 |
+
if st.button("Predict"):
|
79 |
+
with st.spinner("Predicting..."):
|
80 |
+
clip_response = booste.clip(BOOSTE_API_KEY,
|
81 |
+
prompts=state.prompts,
|
82 |
+
images=[state.image],
|
83 |
+
pretty_print=True)
|
84 |
+
st.markdown("### Results")
|
85 |
+
simplified_clip_results = [(prompt[len(state.prompt_prefix):],
|
86 |
+
list(results.values())[0]["probabilityRelativeToPrompts"])
|
87 |
+
for prompt, results in clip_response.items()]
|
88 |
+
simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
|
89 |
+
|
90 |
+
for prompt, probability in simplified_clip_results:
|
91 |
+
percentage_prob = int(probability * 100)
|
92 |
+
st.markdown(
|
93 |
+
f"###      {prompt}")
|
94 |
+
st.write(clip_response)
|
95 |
+
|
96 |
+
|
97 |
+
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
98 |
+
session_state = get_state()
|
99 |
if task_name == "Image classification":
|
100 |
+
Sections.header()
|
101 |
+
Sections.image_uploader(accept_multiple_files=False)
|
|
|
102 |
st.markdown("or choose one from")
|
103 |
+
Sections.image_picker(session_state)
|
104 |
+
input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
|
105 |
+
Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
|
106 |
+
Sections.input_preview(session_state)
|
107 |
+
Sections.classification_output(session_state)
|
108 |
+
elif task_name == "Prompt ranking":
|
109 |
+
Sections.header()
|
110 |
+
Sections.image_uploader(accept_multiple_files=False)
|
111 |
+
st.markdown("or choose one from")
|
112 |
+
Sections.image_picker(session_state)
|
113 |
+
input_label = "Enter the prompts to choose from separated by a semi-colon. " \
|
114 |
+
"(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
|
115 |
+
Sections.prompts_input(session_state, input_label)
|
116 |
+
Sections.input_preview(session_state)
|
117 |
+
Sections.classification_output(session_state)
|
118 |
+
elif task_name == "Image ranking":
|
119 |
+
Sections.header()
|
120 |
+
Sections.image_uploader(accept_multiple_files=True)
|
121 |
+
st.markdown("or use random dataset")
|
122 |
+
Sections.image_picker(session_state)
|
123 |
+
|
124 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
session_state.sync()
|
127 |
|