File size: 3,073 Bytes
ed1918f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os

from PIL import Image
import streamlit as st
import booste

from session_state import SessionState, get_state

# Unfortunately Streamlit sharing does not allow hiding environment variables yet,
# so a fallback key is committed here. Do not copy this API key: go to
# https://www.booste.io/ and get your own (it is free), then export it as the
# BOOSTE_API_KEY environment variable to override this default.
BOOSTE_API_KEY = os.environ.get("BOOSTE_API_KEY", "3818ba84-3526-4029-9dc8-ef3038697ea2")


# Streamlit script body: the whole file reruns on every widget interaction.
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])

# Per-session state from the local ``session_state`` helper. It is created for
# every task (not only classification) so the ``session_state.sync()`` call at
# the bottom of the script is valid whichever task is selected.
session_state = get_state()

st.markdown("# CLIP playground")
st.markdown("### Try OpenAI's CLIP model in your browser")
st.markdown(" ")
st.markdown(" ")
with st.beta_expander("What is CLIP?"):
    st.markdown("Nice CLIP explanation")
st.markdown(" ")
st.markdown(" ")

if task_name == "Image classification":
    # TODO: the uploaded file is not used yet; only the example images below
    # (URLs) are stored in session_state.image and sent to the model.
    uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                      accept_multiple_files=False)
    st.markdown("or choose one from")
    default_images = [
        "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
        "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
        "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
    ]
    # One column per example image, each with its own "Select image N" button.
    example_columns = st.beta_columns(len(default_images))
    for number, (column, image_url) in enumerate(zip(example_columns, default_images), start=1):
        with column:
            st.image(image_url, use_column_width=True)
            if st.button(f"Select image {number}"):
                session_state.image = image_url

    raw_classes = st.text_input("Enter the classes to choose from separated by a comma."
                                " (e.g. `banana, sailing boat, honesty, apple`)")
    # Trim whitespace around each name ("a, b" -> ["a", "b"]) and drop empty
    # entries (e.g. from a trailing comma). An empty box clears the classes so
    # stale ones are neither displayed nor sent to the model.
    class_names = [name.strip() for name in raw_classes.split(",") if name.strip()]
    session_state.processed_classes = class_names or None

    col1, col2 = st.beta_columns([2, 1])
    with col1:
        st.markdown("Image to classify")
        if session_state.image is not None:
            st.image(session_state.image, use_column_width=True)
        else:
            st.warning("Select an image")

    with col2:
        st.markdown("Classes to choose from")
        if session_state.processed_classes is not None:
            for class_name in session_state.processed_classes:
                st.write(class_name)
        else:
            st.warning("Enter the classes to classify from")

    # Possible way to customize this button: https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
    if st.button("Predict"):
        if session_state.image is None or session_state.processed_classes is None:
            # Both inputs are required; calling the API without them would send
            # [None] as the image or fail on missing prompts.
            st.error("Select an image and enter at least one class before predicting.")
        else:
            input_prompts = ["A picture of a " + class_name for class_name in session_state.processed_classes]
            with st.spinner("Predicting..."):
                clip_response = booste.clip(BOOSTE_API_KEY,
                                            prompts=input_prompts,
                                            images=[session_state.image],
                                            pretty_print=True)
                st.write(clip_response)
else:
    # The other tasks are listed in the sidebar but have no UI yet.
    st.info(f"{task_name} is not implemented yet.")

session_state.sync()