Spaces:
Runtime error
Runtime error
Commit
·
e942bb1
1
Parent(s):
5657ecb
Added model architecture selection
Browse files- app.py +9 -3
- images_mocker.py +0 -31
- session_state.py +0 -86
app.py
CHANGED
@@ -26,8 +26,8 @@ def load_image_from_url(url: str) -> Image.Image:
|
|
26 |
return Image.open(requests.get(url, stream=True).raw)
|
27 |
|
28 |
@st.cache
|
29 |
-
def load_model() -> ClipModel:
|
30 |
-
return ClipModel()
|
31 |
|
32 |
def init_state():
|
33 |
if "images" not in st.session_state:
|
@@ -38,6 +38,8 @@ def init_state():
|
|
38 |
st.session_state.predictions = None
|
39 |
if "default_text_input" not in st.session_state:
|
40 |
st.session_state.default_text_input = None
|
|
|
|
|
41 |
|
42 |
|
43 |
def limit_number_images():
|
@@ -278,7 +280,7 @@ if __name__ == "__main__":
|
|
278 |
task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
|
279 |
st.markdown("<br>", unsafe_allow_html=True)
|
280 |
init_state()
|
281 |
-
model = load_model()
|
282 |
if task_name == "Image classification":
|
283 |
Sections.image_uploader(accept_multiple_files=False)
|
284 |
if st.session_state.images is None:
|
@@ -311,6 +313,10 @@ if __name__ == "__main__":
|
|
311 |
limit_number_prompts()
|
312 |
Sections.multiple_images_input_preview()
|
313 |
Sections.classification_output(model)
|
|
|
|
|
|
|
|
|
314 |
|
315 |
st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
|
316 |
"", unsafe_allow_html=True)
|
|
|
26 |
return Image.open(requests.get(url, stream=True).raw)
|
27 |
|
28 |
@st.cache
|
29 |
+
def load_model(model_architecture: str) -> ClipModel:
|
30 |
+
return ClipModel(model_architecture)
|
31 |
|
32 |
def init_state():
|
33 |
if "images" not in st.session_state:
|
|
|
38 |
st.session_state.predictions = None
|
39 |
if "default_text_input" not in st.session_state:
|
40 |
st.session_state.default_text_input = None
|
41 |
+
if "model_architecture" not in st.session_state:
|
42 |
+
st.session_state.model_architecture = "RN50"
|
43 |
|
44 |
|
45 |
def limit_number_images():
|
|
|
280 |
task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
|
281 |
st.markdown("<br>", unsafe_allow_html=True)
|
282 |
init_state()
|
283 |
+
model = load_model(st.session_state.model_architecture)
|
284 |
if task_name == "Image classification":
|
285 |
Sections.image_uploader(accept_multiple_files=False)
|
286 |
if st.session_state.images is None:
|
|
|
313 |
limit_number_prompts()
|
314 |
Sections.multiple_images_input_preview()
|
315 |
Sections.classification_output(model)
|
316 |
+
|
317 |
+
with st.expander("Advanced settings"):
|
318 |
+
st.session_state.model_architecture = st.selectbox("Model architecture", options=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32',
|
319 |
+
'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], index=0)
|
320 |
|
321 |
st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
|
322 |
"", unsafe_allow_html=True)
|
images_mocker.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
from typing import List
|
2 |
-
import uuid
|
3 |
-
from mock import patch
|
4 |
-
|
5 |
-
|
6 |
-
class ImagesMocker:
|
7 |
-
"""HACK ALERT: I needed a way to call the booste API without storing the images first
|
8 |
-
(as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
|
9 |
-
|
10 |
-
def __init__(self):
|
11 |
-
self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
|
12 |
-
self.path_patch = patch('os.path.exists', lambda x: True)
|
13 |
-
self.image_id2image_lookup = {}
|
14 |
-
|
15 |
-
def start_mocking(self):
|
16 |
-
self.pil_patch.start()
|
17 |
-
self.path_patch.start()
|
18 |
-
|
19 |
-
def stop_mocking(self):
|
20 |
-
self.pil_patch.stop()
|
21 |
-
self.path_patch.stop()
|
22 |
-
|
23 |
-
def image_id2image(self, image_id: str):
|
24 |
-
return self.image_id2image_lookup[image_id]
|
25 |
-
|
26 |
-
def calculate_image_id2image_lookup(self, images: List):
|
27 |
-
self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
|
28 |
-
|
29 |
-
@property
|
30 |
-
def image_ids(self):
|
31 |
-
return list(self.image_id2image_lookup.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session_state.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
# From https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
|
2 |
-
from streamlit.hashing import _CodeHasher
|
3 |
-
|
4 |
-
try:
|
5 |
-
# Before Streamlit 0.65
|
6 |
-
from streamlit.ReportThread import get_report_ctx
|
7 |
-
from streamlit.server.Server import Server
|
8 |
-
except ModuleNotFoundError:
|
9 |
-
# After Streamlit 0.65
|
10 |
-
from streamlit.report_thread import get_report_ctx
|
11 |
-
from streamlit.server.server import Server
|
12 |
-
|
13 |
-
|
14 |
-
class SessionState:
|
15 |
-
def __init__(self, session, hash_funcs):
|
16 |
-
"""Initialize SessionState instance."""
|
17 |
-
self.__dict__["_state"] = {
|
18 |
-
"data": {},
|
19 |
-
"hash": None,
|
20 |
-
"hasher": _CodeHasher(hash_funcs),
|
21 |
-
"is_rerun": False,
|
22 |
-
"session": session,
|
23 |
-
}
|
24 |
-
|
25 |
-
def __call__(self, **kwargs):
|
26 |
-
"""Initialize state data once."""
|
27 |
-
for item, value in kwargs.items():
|
28 |
-
if item not in self._state["data"]:
|
29 |
-
self._state["data"][item] = value
|
30 |
-
|
31 |
-
def __getitem__(self, item):
|
32 |
-
"""Return a saved state value, None if item is undefined."""
|
33 |
-
return self._state["data"].get(item, None)
|
34 |
-
|
35 |
-
def __getattr__(self, item):
|
36 |
-
"""Return a saved state value, None if item is undefined."""
|
37 |
-
return self._state["data"].get(item, None)
|
38 |
-
|
39 |
-
def __setitem__(self, item, value):
|
40 |
-
"""Set state value."""
|
41 |
-
self._state["data"][item] = value
|
42 |
-
|
43 |
-
def __setattr__(self, item, value):
|
44 |
-
"""Set state value."""
|
45 |
-
self._state["data"][item] = value
|
46 |
-
|
47 |
-
def clear(self):
|
48 |
-
"""Clear session state and request a rerun."""
|
49 |
-
self._state["data"].clear()
|
50 |
-
self._state["session"].request_rerun()
|
51 |
-
|
52 |
-
def sync(self):
|
53 |
-
"""Rerun the app with all state values up to date from the beginning to fix rollbacks."""
|
54 |
-
|
55 |
-
# Ensure to rerun only once to avoid infinite loops
|
56 |
-
# caused by a constantly changing state value at each run.
|
57 |
-
#
|
58 |
-
# Example: state.value += 1
|
59 |
-
if self._state["is_rerun"]:
|
60 |
-
self._state["is_rerun"] = False
|
61 |
-
|
62 |
-
elif self._state["hash"] is not None:
|
63 |
-
if self._state["hash"] != self._state["hasher"].to_bytes(self._state["data"], None):
|
64 |
-
self._state["is_rerun"] = True
|
65 |
-
self._state["session"].request_rerun()
|
66 |
-
|
67 |
-
self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None)
|
68 |
-
|
69 |
-
|
70 |
-
def get_session():
|
71 |
-
session_id = get_report_ctx().session_id
|
72 |
-
session_info = Server.get_current()._get_session_info(session_id)
|
73 |
-
|
74 |
-
if session_info is None:
|
75 |
-
raise RuntimeError("Couldn't get your Streamlit Session object.")
|
76 |
-
|
77 |
-
return session_info.session
|
78 |
-
|
79 |
-
|
80 |
-
def get_state(hash_funcs=None):
|
81 |
-
session = get_session()
|
82 |
-
|
83 |
-
if not hasattr(session, "_custom_session_state"):
|
84 |
-
session._custom_session_state = SessionState(session, hash_funcs)
|
85 |
-
|
86 |
-
return session._custom_session_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|