File size: 12,346 Bytes
081801e
b0cb25e
af5047d
b0cb25e
081801e
af5047d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1918f
af5047d
 
 
 
ed1918f
 
 
 
 
 
081801e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5047d
081801e
 
 
ed1918f
b0cb25e
 
 
 
 
 
 
 
 
 
 
 
 
af5047d
 
 
 
 
 
 
 
 
 
 
b0cb25e
 
 
 
 
 
 
 
081801e
b0cb25e
 
 
 
081801e
b0cb25e
 
 
 
081801e
 
 
 
 
 
 
 
 
 
 
 
 
 
b0cb25e
 
 
 
 
 
 
 
 
081801e
b0cb25e
 
 
081801e
 
b0cb25e
 
 
 
 
081801e
b0cb25e
 
 
 
ed1918f
081801e
 
 
 
 
 
 
 
 
 
af5047d
 
081801e
 
 
af5047d
081801e
 
 
 
 
 
 
 
b0cb25e
 
 
 
 
af5047d
 
 
 
 
 
 
 
 
 
 
 
 
b0cb25e
081801e
 
 
 
 
 
 
 
 
 
 
 
 
 
af5047d
 
 
 
 
 
 
081801e
 
 
 
 
 
b0cb25e
 
 
ed1918f
af5047d
b0cb25e
af5047d
 
 
 
b0cb25e
 
081801e
b0cb25e
 
af5047d
b0cb25e
af5047d
 
 
 
b0cb25e
 
 
081801e
b0cb25e
 
af5047d
b0cb25e
af5047d
 
 
 
081801e
 
 
af5047d
ed1918f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os
import random
import uuid
from typing import Optional, List

import streamlit as st
from mock import patch

class ImagesMocker:
    """HACK ALERT: I needed a way to call the booste API without storing the images first
     (as that is not allowed in streamlit sharing). If you have a better idea on how to do this let me know!

    While mocking is active, ``PIL.Image.open`` is patched to return the in-memory image
    registered under the given id (see ``calculate_image_id2image_lookup``) instead of
    reading a file, and ``os.path.exists`` is patched to always return ``True``. That lets
    code expecting file paths be handed the generated ids instead.

    Both patches are process-global while active, so always pair ``start_mocking`` with
    ``stop_mocking`` (e.g. in ``try``/``finally``) or use the instance as a context manager::

        with images_mocker:
            ...
    """

    def __init__(self):
        # Bound methods rather than one-argument lambdas, so the stand-ins also accept the
        # optional arguments of the functions they replace (e.g. Image.open's ``mode``).
        self.pil_patch = patch('PIL.Image.open', self._open_image)
        self.path_patch = patch('os.path.exists', self._path_exists)
        # Maps generated "<uuid4>.png" ids to the in-memory images they stand for.
        self.image_id2image_lookup = {}

    def __enter__(self):
        self.start_mocking()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop_mocking()
        # Do not swallow exceptions raised inside the ``with`` block.
        return False

    def start_mocking(self):
        """Install the ``PIL.Image.open`` and ``os.path.exists`` patches."""
        self.pil_patch.start()
        self.path_patch.start()

    def stop_mocking(self):
        """Remove the patches installed by ``start_mocking``."""
        self.pil_patch.stop()
        self.path_patch.stop()

    def _open_image(self, fp, *args, **kwargs):
        """Stand-in for ``PIL.Image.open``: return the image registered under ``fp``.

        Any extra positional/keyword arguments are accepted and ignored.
        Raises KeyError if ``fp`` is not a registered id.
        """
        return self.image_id2image(fp)

    @staticmethod
    def _path_exists(path):
        """Stand-in for ``os.path.exists``: report every path as existing."""
        return True

    def image_id2image(self, image_id: str):
        """Return the image registered under ``image_id`` (KeyError if unknown)."""
        return self.image_id2image_lookup[image_id]

    def calculate_image_id2image_lookup(self, images: List):
        """Replace the lookup with fresh, unique ``"<uuid4>.png"`` ids for ``images``."""
        self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}

    @property
    def image_ids(self):
        """Registered ids, in the same order as the images passed to the last lookup build."""
        return list(self.image_id2image_lookup.keys())


# Shared module-level instance.
images_mocker = ImagesMocker()
import booste

from PIL import Image
from session_state import SessionState, get_state

# Unfortunately Streamlit sharing does not allow hiding environment variables yet, so a
# shared demo key is kept here as a fallback. Set the BOOSTE_API_KEY environment variable
# to use your own key instead (an empty value falls back to the demo key).
# Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
BOOSTE_API_KEY = os.environ.get("BOOSTE_API_KEY") or "3818ba84-3526-4029-9dc8-ef3038697ea2"

# Pool of Pixabay image links that select_random_dataset samples from.
IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
                "https://cdn.pixabay.com/photo/2019/09/06/04/25/beach-4455433_960_720.jpg",
                "https://cdn.pixabay.com/photo/2019/10/19/12/21/hot-air-balloons-4561264_960_720.jpg",
                "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
                "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/12/28/22/48/buddha-5868759_960_720.jpg",
                "https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/11/04/15/29/coffee-beans-5712780_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/03/24/20/42/namibia-4965457_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/08/27/07/31/restaurant-5521372_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/08/28/06/13/building-5523630_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/08/24/21/41/couple-5515141_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/01/31/07/10/billboards-4807268_960_720.jpg",
                "https://cdn.pixabay.com/photo/2017/07/31/20/48/shell-2560930_960_720.jpg",
                "https://cdn.pixabay.com/photo/2020/08/13/01/29/koala-5483931_960_720.jpg",
                "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
                ]

@st.cache  # Cache this so that it doesn't change every time something changes in the page
def select_random_dataset():
    """Return 10 distinct image links drawn at random from ``IMAGES_LINKS``.

    Cached with ``st.cache`` so the same selection is reused across page reruns.
    """
    dataset_size = 10
    sampled_links = random.sample(IMAGES_LINKS, k=dataset_size)
    return sampled_links


class Sections:
    """Streamlit page sections for the CLIP playground.

    Every section is a static method that renders widgets through ``st``. Sections taking
    ``state`` read from and/or write to the session-state attributes ``images``,
    ``prompts``, ``prompt_prefix`` and ``dataset``.
    """

    # Example images offered by ``image_picker`` (one per column, in this order).
    _DEFAULT_IMAGES = (
        "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
        "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
        "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
    )

    @staticmethod
    def header():
        """Render the page title, subtitle and the "What is CLIP?" expander."""
        st.markdown("# CLIP playground")
        st.markdown("### Try OpenAI's CLIP model in your browser")
        st.markdown(" ")
        st.markdown(" ")
        with st.beta_expander("What is CLIP?"):
            st.markdown("Nice CLIP explanation")
        st.markdown(" ")
        st.markdown(" ")

    @staticmethod
    def image_uploader(state: 'SessionState', accept_multiple_files: bool):
        """Let the user upload image(s) and store them as PIL images in ``state.images``.

        ``state.images`` is only overwritten when at least one file was uploaded, so an
        image or dataset picked earlier is kept while the uploader is empty.
        """
        uploaded = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                    accept_multiple_files=accept_multiple_files)
        if uploaded is None:
            return
        files = uploaded if accept_multiple_files else [uploaded]
        # In multi-file mode "nothing uploaded" may come back as an empty list, not None.
        if not files:
            return
        state.images = [Image.open(uploaded_file) for uploaded_file in files]

    @staticmethod
    def image_picker(state: 'SessionState'):
        """Show the default example images, each with a button that selects it."""
        columns = st.beta_columns(len(Sections._DEFAULT_IMAGES))
        for image_number, (col, image_link) in enumerate(zip(columns, Sections._DEFAULT_IMAGES), start=1):
            with col:
                st.image(image_link, use_column_width=True)
                if st.button(f"Select image {image_number}"):
                    state.images = [image_link]

    @staticmethod
    def dataset_picker(state: 'SessionState'):
        """Preview the cached random dataset (two images per column) and offer to select it."""
        columns = st.beta_columns(5)
        state.dataset = select_random_dataset()
        # Column i shows dataset entries 2*i and 2*i + 1, in order.
        for column_idx, col in enumerate(columns):
            for image_link in state.dataset[2 * column_idx:2 * column_idx + 2]:
                col.image(image_link)
        if st.button("Select random dataset"):
            state.images = state.dataset

    @staticmethod
    def _parse_prompts(raw_classes: str, prompt_prefix: str = '') -> List[str]:
        """Split ``;``-separated text into prompts, stripping whitespace and dropping blanks.

        Each remaining entry is prefixed with ``prompt_prefix``.

        >>> Sections._parse_prompts("banana; boat;", "A picture of a ")
        ['A picture of a banana', 'A picture of a boat']
        """
        return [prompt_prefix + class_name.strip()
                for class_name in raw_classes.split(";")
                if class_name.strip()]

    @staticmethod
    def _without_prefix(prompt: str, prompt_prefix: Optional[str]) -> str:
        """Return ``prompt`` with the first ``len(prompt_prefix)`` characters removed."""
        return prompt[len(prompt_prefix or ''):]

    @staticmethod
    def prompts_input(state: 'SessionState', input_label: str, prompt_prefix: str = ''):
        """Read ``;``-separated classes/prompts and store them in ``state.prompts``.

        The prefix is stored in ``state.prompt_prefix`` so it can be stripped again for display.
        Nothing is changed while the text input is empty.
        """
        raw_classes = st.text_input(input_label)
        if raw_classes:
            state.prompts = Sections._parse_prompts(raw_classes, prompt_prefix)
            state.prompt_prefix = prompt_prefix

    @staticmethod
    def single_image_input_preview(state: 'SessionState'):
        """Show the selected image next to the labels it will be classified against."""
        image_col, labels_col = st.beta_columns([2, 1])
        with image_col:
            st.markdown("Image to classify")
            if state.images:
                st.image(state.images[0], use_column_width=True)
            else:
                st.warning("Select an image")

        with labels_col:
            st.markdown("Labels to choose from")
            if state.prompts:
                for prompt in state.prompts:
                    st.write(Sections._without_prefix(prompt, state.prompt_prefix))
            else:
                st.warning("Enter the classes to classify from")

    @staticmethod
    def multiple_images_input_preview(state: 'SessionState'):
        """Show the selected images split over two columns, and the query prompt in a third."""
        st.markdown("Images to classify")
        col1, col2, col3 = st.beta_columns(3)
        images = state.images
        if images is not None:
            half = len(images) / 2
            for idx, image in enumerate(images):
                target_col = col1 if idx < half else col2
                target_col.image(image, use_column_width=True)
            if len(images) < 2:
                col2.warning("At least 2 images required")
        else:
            col1.warning("Select an image")

        with col3:
            st.markdown("Query prompt")
            if state.prompts is not None:
                for prompt in state.prompts:
                    st.write(Sections._without_prefix(prompt, state.prompt_prefix))
            else:
                st.warning("Enter the prompt to classify")

    @staticmethod
    def classification_output(state: 'SessionState'):
        """Render the "Predict" button and, when clicked, query CLIP and show the results.

        With a single image the prompts are ranked for that image. With several images,
        they are ranked against the single prompt and the top 5 are shown.
        """
        # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
        if not st.button("Predict"):
            return
        # Validate before calling the remote API instead of failing on its response.
        if not state.images or not state.prompts:
            st.warning("Select at least one image and enter at least one prompt before predicting")
            return
        ranking_images = len(state.images) > 1
        if ranking_images and len(state.prompts) != 1:
            st.warning("Enter exactly one prompt to rank the images by")
            return
        with st.spinner("Predicting..."):
            clip_response = Sections._call_clip(state)
            st.markdown("### Results")
            if ranking_images:
                Sections._show_image_ranking(state, clip_response)
            else:
                Sections._show_prompt_ranking(state, clip_response)

    @staticmethod
    def _call_clip(state: 'SessionState'):
        """Call ``booste.clip`` with the selected prompts and images and return its response.

        String images are passed straight through. Other images (uploaded PIL images) are
        registered with ``images_mocker``, and their generated ids are sent instead.
        """
        if isinstance(state.images[0], str):
            print("Regular call!")
            return booste.clip(BOOSTE_API_KEY, prompts=state.prompts, images=state.images)

        print("Hacky call!")
        images_mocker.calculate_image_id2image_lookup(state.images)
        images_mocker.start_mocking()
        try:
            return booste.clip(BOOSTE_API_KEY, prompts=state.prompts, images=images_mocker.image_ids)
        finally:
            # Always undo the process-global PIL / os.path patches, even if the call fails.
            images_mocker.stop_mocking()

    @staticmethod
    def _show_prompt_ranking(state: 'SessionState', clip_response):
        """Show each prompt (prefix stripped) with its probability for the single image, best first.

        As consumed here, ``clip_response`` maps prompt -> {image: {"probabilityRelativeToPrompts": ...}}.
        """
        ranked = sorted(
            ((Sections._without_prefix(prompt, state.prompt_prefix),
              next(iter(results.values()))["probabilityRelativeToPrompts"])
             for prompt, results in clip_response.items()),
            key=lambda item: item[1], reverse=True)
        for prompt, probability in ranked:
            percentage_prob = int(probability * 100)
            st.markdown(
                f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")

    @staticmethod
    def _show_image_ranking(state: 'SessionState', clip_response, top_k: int = 5):
        """Show the ``top_k`` images most related to the single prompt, best first.

        As consumed here, ``clip_response`` maps prompt -> {image: {"probabilityRelativeToImages": ...}}.
        Mocked ids are mapped back to the uploaded images for display.
        """
        st.markdown(f"### {state.prompts[0]}")
        per_image_results = next(iter(clip_response.values()))
        uses_mocked_ids = not isinstance(state.images[0], str)
        ranked = []
        for image, results in per_image_results.items():
            if uses_mocked_ids:
                image = images_mocker.image_id2image(image)
            ranked.append((image, results["probabilityRelativeToImages"]))
        ranked.sort(key=lambda item: item[1], reverse=True)
        for image, probability in ranked[:top_k]:
            col1, col2 = st.beta_columns([1, 3])
            col1.image(image, use_column_width=True)
            percentage_prob = int(probability * 100)
            col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")


task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])

# All tasks share the same session state, page header and prediction section.
session_state = get_state()
Sections.header()

if task_name == "Image ranking":
    # Rank several images against a single query prompt.
    Sections.image_uploader(session_state, accept_multiple_files=True)
    if session_state.images is None:
        st.markdown("or use this random dataset")
        Sections.dataset_picker(session_state)
    Sections.prompts_input(session_state, "Enter the prompt to query the images by")
    Sections.multiple_images_input_preview(session_state)
else:
    # "Image classification" and "Prompt ranking" both score one image against several prompts.
    Sections.image_uploader(session_state, accept_multiple_files=False)
    if session_state.images is None:
        st.markdown("or choose one from")
        Sections.image_picker(session_state)
    if task_name == "Image classification":
        input_label = "Enter the classes to choose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
        Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
    else:
        input_label = "Enter the prompts to choose from separated by a semi-colon. " \
                      "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
        Sections.prompts_input(session_state, input_label)
    Sections.single_image_input_preview(session_state)

Sections.classification_output(session_state)

# Sync the session state at the end of every run (see the session_state module).
session_state.sync()