Commit · 40497e6
1 Parent(s): 0454d20
Base functionality working again

Files changed:
- app.py (+45 −47)
- clip_model.py (+16 −3)
app.py CHANGED

@@ -230,8 +230,7 @@ class Sections:
 
     @staticmethod
     def classification_output(model: ClipModel):
-
-        if st.button("Predict") and is_valid_prediction_state():  # PREDICT
+        if st.button("Predict") and is_valid_prediction_state():
             with st.spinner("Predicting..."):
 
                 st.markdown("### Results")
@@ -247,7 +246,6 @@ class Sections:
             st.markdown(f"### {st.session_state.prompts[0]}")
 
             scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
-            st.json(scores)
             scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
             sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
 
@@ -272,47 +270,47 @@ class Sections:
     #             " It can be whatever you can think of",
     #             unsafe_allow_html=True)
 
-
-Sections.header()
-col1, col2 = st.columns([1, 2])
-col1.markdown(" "); col1.markdown(" ")
-col1.markdown("#### Task selection")
-task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
-st.markdown("<br>", unsafe_allow_html=True)
-init_state()
-model = load_model()
-if task_name == "Image classification":
-    Sections.image_uploader(accept_multiple_files=False)
-    if st.session_state.images is None:
-        st.markdown("or choose one from")
-        Sections.image_picker(default_text_input="banana; boat; bird")
-    input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
-    Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
-    limit_number_images()
-    Sections.single_image_input_preview()
-    Sections.classification_output(model)
-elif task_name == "Prompt ranking":
-    Sections.image_uploader(accept_multiple_files=False)
-    if st.session_state.images is None:
-        st.markdown("or choose one from")
-        Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
-                                                 "A beautiful creature;"
-                                                 " Something that grows in tropical regions")
-    input_label = "Enter the prompts to choose from separated by a semi-colon. " \
-                  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
-    Sections.prompts_input(input_label)
-    limit_number_images()
-    Sections.single_image_input_preview()
-    Sections.classification_output(model)
-elif task_name == "Image ranking":
-    Sections.image_uploader(accept_multiple_files=True)
-    if st.session_state.images is None or len(st.session_state.images) < 2:
-        st.markdown("or use this random dataset")
-        Sections.dataset_picker()
-    Sections.prompts_input("Enter the prompt to query the images by")
-    limit_number_prompts()
-    Sections.multiple_images_input_preview()
-    Sections.classification_output(model)
-
-st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
-            "", unsafe_allow_html=True)
+
+if __name__ == "__main__":
+    Sections.header()
+    col1, col2 = st.columns([1, 2])
+    col1.markdown(" "); col1.markdown(" ")
+    col1.markdown("#### Task selection")
+    task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
+    st.markdown("<br>", unsafe_allow_html=True)
+    init_state()
+    model = load_model()
+    if task_name == "Image classification":
+        Sections.image_uploader(accept_multiple_files=False)
+        if st.session_state.images is None:
+            st.markdown("or choose one from")
+            Sections.image_picker(default_text_input="banana; boat; bird")
+        input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
+        Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
+        limit_number_images()
+        Sections.single_image_input_preview()
+        Sections.classification_output(model)
+    elif task_name == "Prompt ranking":
+        Sections.image_uploader(accept_multiple_files=False)
+        if st.session_state.images is None:
+            st.markdown("or choose one from")
+            Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
+                                                     "A beautiful creature;"
+                                                     " Something that grows in tropical regions")
+        input_label = "Enter the prompts to choose from separated by a semi-colon. " \
+                      "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
+        Sections.prompts_input(input_label)
+        limit_number_images()
+        Sections.single_image_input_preview()
+        Sections.classification_output(model)
+    elif task_name == "Image ranking":
+        Sections.image_uploader(accept_multiple_files=True)
+        if st.session_state.images is None or len(st.session_state.images) < 2:
+            st.markdown("or use this random dataset")
+            Sections.dataset_picker()
+        Sections.prompts_input("Enter the prompt to query the images by")
+        limit_number_prompts()
+        Sections.multiple_images_input_preview()
+        Sections.classification_output(model)
+
+    st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
+                "", unsafe_allow_html=True)
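Note on the app.py change: moving the top-level Streamlit calls under an `if __name__ == "__main__":` guard appears to be what enables the new smoke test in clip_model.py, since `from app import load_default_dataset` no longer executes the whole UI at import time. A minimal sketch of the idea (a hypothetical consumer module, not part of this commit):

# smoke_test.py -- hypothetical, for illustration only
# Importing app is now side-effect free: the guard keeps the Streamlit UI
# code from running, so only helpers like load_default_dataset are defined.
from app import load_default_dataset

images = load_default_dataset()
print(f"Loaded {len(images)} sample images")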
clip_model.py CHANGED

@@ -2,6 +2,8 @@ import clip
 from PIL.Image import Image
 import torch
 
+
+
 class ClipModel:
     def __init__(self, model_name: str = 'RN50') -> None:
         """
@@ -42,7 +44,7 @@ class ClipModel:
         preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
         tokenized_prompts = clip.tokenize(prompt)
         with torch.inference_mode():
-            image_features = self._model.encode_image(
+            image_features = torch.cat([self._model.encode_image(preprocessed_image) for preprocessed_image in preprocessed_images])
             text_features = self._model.encode_text(tokenized_prompts)
 
             # normalized features
@@ -51,8 +53,19 @@ class ClipModel:
 
         # cosine similarity as logits
        logit_scale = self._model.logit_scale.exp()
-        logits_per_image = logit_scale *
+        logits_per_image = logit_scale * text_features @ image_features.t()
 
         probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        return probs
+        return probs
+
+
+if __name__ == "__main__":
+    from app import load_default_dataset
+
+    model = ClipModel()
+    images = load_default_dataset()
+    prompts = ['Hello', 'How are you', 'Goodbye']
+    prompts_scores = model.compute_prompts_probabilities(images[0], prompts)
+    images_scores = model.compute_images_probabilities(images, prompts[0])
+    print(f"Prompts scores: {prompts_scores}")
+    print(f"Images scores: {images_scores}")
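Note on the clip_model.py change: `compute_images_probabilities` now encodes each preprocessed image separately and stacks the results with `torch.cat`, so `image_features` has one row per image; the logits become `logit_scale * text_features @ image_features.t()`, and the softmax runs across images. A standalone sketch of that scoring math with random stand-in embeddings (the dimensions and scale value below are illustrative assumptions, not values from the commit):

import torch

torch.manual_seed(0)
n_images, dim = 3, 512                       # assumed embedding width, for illustration
image_features = torch.randn(n_images, dim)  # stand-in for stacked encode_image outputs
text_features = torch.randn(1, dim)          # stand-in for encode_text of one prompt

# CLIP normalizes features before taking the cosine similarity
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

logit_scale = torch.tensor(100.0)            # stand-in for model.logit_scale.exp()
logits_per_image = logit_scale * text_features @ image_features.t()  # shape (1, n_images)

probs = list(logits_per_image.softmax(dim=-1).numpy()[0])
print(probs)  # one probability per image; values sum to 1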