File size: 5,580 Bytes
b0cb25e
 
ed1918f
 
 
 
 
 
 
 
 
 
 
b0cb25e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1918f
b0cb25e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1918f
b0cb25e
 
ed1918f
b0cb25e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed1918f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Optional, List

from PIL import Image
import streamlit as st
import booste

from session_state import SessionState, get_state

# Unfortunately, Streamlit sharing does not allow hiding environment variables yet,
# so the key is committed here in plain text.
# Do not copy this API key; go to https://www.booste.io/ and get your own, it is free!
# NOTE(review): a committed credential is readable by anyone who can read this file;
# consider moving it to a secrets mechanism once the hosting platform offers one.
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"


class Sections:
    """Static builders for the Streamlit UI sections of the CLIP playground.

    Every method renders widgets through the module-level ``st``. Methods
    that take a ``state`` argument read or write the attributes ``image``,
    ``prompts`` and ``prompt_prefix`` on it. The state object is expected to
    yield ``None`` for attributes that were never assigned (the image check
    in ``input_preview`` already relied on this).
    """

    # Sample images offered by ``image_picker``, one per column, in order.
    _DEFAULT_IMAGES = (
        "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
        "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
        "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
    )

    @staticmethod
    def header():
        """Render the page title, the subtitle and the "What is CLIP?" expander."""
        st.markdown("# CLIP playground")
        st.markdown("### Try OpenAI's CLIP model in your browser")
        # Blank markdown elements are used as vertical spacing.
        st.markdown(" ")
        st.markdown(" ")
        with st.beta_expander("What is CLIP?"):
            st.markdown("Nice CLIP explaination")
        st.markdown(" ")
        st.markdown(" ")

    @staticmethod
    def image_uploader(accept_multiple_files: bool) -> Optional[List[str]]:
        """Render the image upload widget and return its current value.

        Args:
            accept_multiple_files: Forwarded to ``st.file_uploader``; whether
                several files may be uploaded at once.

        Returns:
            Whatever ``st.file_uploader`` returns. Previously the value was
            computed and then discarded, so the method always returned
            ``None`` despite its annotation.
        """
        # NOTE(review): the return annotation is kept from the original
        # signature; the widget presumably yields uploaded-file objects rather
        # than strings - confirm against the Streamlit version in use.
        return st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                accept_multiple_files=accept_multiple_files)

    @staticmethod
    def image_picker(state: "SessionState"):
        """Show the sample images side by side, each with a select button.

        Clicking a button stores that image's URL in ``state.image``.
        """
        columns = st.beta_columns(len(Sections._DEFAULT_IMAGES))
        numbered = enumerate(zip(columns, Sections._DEFAULT_IMAGES), start=1)
        for number, (column, image_url) in numbered:
            with column:
                st.image(image_url, use_column_width=True)
                if st.button(f"Select image {number}"):
                    state.image = image_url

    @staticmethod
    def _split_prompts(raw_classes: str, prompt_prefix: str = '') -> List[str]:
        """Split a ``;``-separated string into a list of prefixed prompts.

        Whitespace around each entry is stripped (so ``"banana; boat"`` does
        not produce a prompt with a doubled space) and entries that are empty
        after stripping are dropped.
        """
        labels = (chunk.strip() for chunk in raw_classes.split(";"))
        return [prompt_prefix + label for label in labels if label]

    @staticmethod
    def _rank_results(clip_response, prompt_prefix: str = ''):
        """Reduce a CLIP response to ``(label, probability)`` pairs, best first.

        ``clip_response`` is read as a mapping of prompt to a mapping whose
        first value holds a ``"probabilityRelativeToPrompts"`` entry. The
        leading ``len(prompt_prefix)`` characters are cut from every prompt so
        only the user-entered label remains.
        """
        cut = len(prompt_prefix)
        ranked = [
            (prompt[cut:], next(iter(per_image.values()))["probabilityRelativeToPrompts"])
            for prompt, per_image in clip_response.items()
        ]
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked

    @staticmethod
    def prompts_input(state: "SessionState", input_label: str, prompt_prefix: str = ''):
        """Render the prompt text box and store the parsed prompts on ``state``.

        Args:
            state: Receives ``prompts`` and ``prompt_prefix`` when the text
                box is non-empty; left untouched otherwise.
            input_label: Label shown on the text input.
            prompt_prefix: Text prepended to every entered class name.
        """
        raw_classes = st.text_input(input_label)
        if raw_classes:
            state.prompts = Sections._split_prompts(raw_classes, prompt_prefix)
            state.prompt_prefix = prompt_prefix

    @staticmethod
    def input_preview(state: "SessionState"):
        """Show the selected image and the entered labels in two columns.

        A warning is shown in place of whichever input is still missing.
        """
        image_column, labels_column = st.beta_columns([2, 1])
        with image_column:
            st.markdown("Image to classify")
            if state.image is not None:
                st.image(state.image, use_column_width=True)
            else:
                st.warning("Select an image")

        with labels_column:
            st.markdown("Labels to choose from")
            # The prompts live in ``state.prompts``; the attribute checked
            # before (``processed_classes``) is never assigned in this file,
            # so the labels could never be displayed.
            if state.prompts:
                prefix_length = len(state.prompt_prefix or '')
                for prompt in state.prompts:
                    st.write(prompt[prefix_length:])
            else:
                st.warning("Enter the classes to classify from")

    @staticmethod
    def classification_output(state: "SessionState"):
        """Render the "Predict" button and, once clicked, the ranked results.

        The remote CLIP call is only made when both an image and at least one
        prompt are available; otherwise a warning is shown instead.
        """
        # A possible way to customize the button layout:
        # https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
        if not st.button("Predict"):
            return
        if state.image is None:
            st.warning("Select an image")
            return
        if not state.prompts:
            st.warning("Enter the classes to classify from")
            return

        with st.spinner("Predicting..."):
            clip_response = booste.clip(BOOSTE_API_KEY,
                                        prompts=state.prompts,
                                        images=[state.image],
                                        pretty_print=True)
            st.markdown("### Results")
            ranked = Sections._rank_results(clip_response, state.prompt_prefix or '')
            for label, probability in ranked:
                percentage_prob = int(probability * 100)
                st.markdown(
                    f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {label}")
            # The raw response is also written out in full below the bars.
            st.write(clip_response)


# Sidebar selector deciding which of the three demo pages is rendered.
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
session_state = get_state()

# The two single-image tasks share one page layout and differ only in the
# text-input label and in the prefix prepended to each entered prompt.
single_image_tasks = {
    "Image classification": (
        "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)",
        'A picture of a ',
    ),
    "Prompt ranking": (
        "Enter the prompts to choose from separated by a semi-colon. "
        "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)",
        '',
    ),
}

if task_name in single_image_tasks:
    input_label, task_prefix = single_image_tasks[task_name]
    Sections.header()
    Sections.image_uploader(accept_multiple_files=False)
    st.markdown("or choose one from")
    Sections.image_picker(session_state)
    Sections.prompts_input(session_state, input_label, prompt_prefix=task_prefix)
    Sections.input_preview(session_state)
    Sections.classification_output(session_state)
elif task_name == "Image ranking":
    # This page only offers image selection; no prompt input or prediction.
    Sections.header()
    Sections.image_uploader(accept_multiple_files=True)
    st.markdown("or use random dataset")
    Sections.image_picker(session_state)

# Called once at the end of every script run, whichever page was rendered.
session_state.sync()