Spaces:
Runtime error
Runtime error
File size: 5,580 Bytes
b0cb25e ed1918f b0cb25e ed1918f b0cb25e ed1918f b0cb25e ed1918f b0cb25e ed1918f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
from typing import Optional, List

from PIL import Image
import streamlit as st
import booste

from session_state import SessionState, get_state
# The Booste API key is read from the BOOSTE_API_KEY environment variable when
# that variable is set and non-empty, so a deployment can supply its own key
# without editing this file.
# The bundled fallback below exists because, when this was written, Streamlit
# sharing did not yet allow hiding environment variables.
# Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
_FALLBACK_BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
BOOSTE_API_KEY = os.environ.get("BOOSTE_API_KEY") or _FALLBACK_BOOSTE_API_KEY
class Sections:
    """Streamlit UI sections of the CLIP playground page.

    Every section is a static method that renders widgets through ``st``.
    Sections that take ``state`` read or write the selected image
    (``state.image``), the prompts (``state.prompts``) and the prefix that was
    prepended to them (``state.prompt_prefix``) on that session-state object.

    NOTE(review): ``st.beta_expander`` / ``st.beta_columns`` are the
    beta-prefixed Streamlit names; confirm the pinned Streamlit version still
    provides them.
    """

    @staticmethod
    def header():
        """Render the page title, subtitle and the "What is CLIP?" expander."""
        st.markdown("# CLIP playground")
        st.markdown("### Try OpenAI's CLIP model in your browser")
        # Blank markdown blocks act as vertical spacers.
        st.markdown(" ")
        st.markdown(" ")
        with st.beta_expander("What is CLIP?"):
            st.markdown("Nice CLIP explaination")
        st.markdown(" ")
        st.markdown(" ")

    @staticmethod
    def image_uploader(accept_multiple_files: bool) -> Optional[List[str]]:
        """Render the image upload widget and return its current value.

        Args:
            accept_multiple_files: Forwarded unchanged to ``st.file_uploader``.

        Returns:
            The value produced by ``st.file_uploader``. Previously the value
            was assigned to a local and dropped, so callers always got
            ``None``.
            NOTE(review): the declared return annotation is kept for interface
            stability; confirm it matches what ``st.file_uploader`` yields.
        """
        return st.file_uploader(
            "Upload image",
            type=[".png", ".jpg", ".jpeg"],
            accept_multiple_files=accept_multiple_files,
        )

    @staticmethod
    def image_picker(state: SessionState):
        """Show three default images side by side, each with a select button.

        Pressing "Select image N" stores that image's URL in ``state.image``.

        Args:
            state: Session-state object that receives the chosen image URL.
        """
        default_images = (
            "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
            "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg",
            "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg",
        )
        columns = st.beta_columns(len(default_images))
        for number, (column, image_url) in enumerate(zip(columns, default_images), start=1):
            with column:
                st.image(image_url, use_column_width=True)
                if st.button(f"Select image {number}"):
                    state.image = image_url

    @staticmethod
    def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
        """Read a ``;``-separated list of class names and store the prompts.

        When the text input is non-empty, ``state.prompts`` is set to each
        non-empty name with ``prompt_prefix`` prepended and
        ``state.prompt_prefix`` is set to ``prompt_prefix``. When the input is
        empty, ``state`` is left untouched.

        Args:
            state: Session-state object that receives the prompts.
            input_label: Label shown above the text input.
            prompt_prefix: Text prepended to every class name.
        """
        raw_classes = st.text_input(input_label)
        if not raw_classes:
            return
        # Strip the whitespace that surrounds each ';' so "banana; boat" does
        # not yield a prompt with a doubled space after the prefix, and so that
        # one-character labels are treated the same wherever they appear.
        class_names = (class_name.strip() for class_name in raw_classes.split(";"))
        state.prompts = [prompt_prefix + class_name for class_name in class_names if class_name]
        state.prompt_prefix = prompt_prefix

    @staticmethod
    def input_preview(state: SessionState):
        """Preview the selected image and the labels entered so far.

        The left column shows ``state.image`` or a warning when none is
        selected; the right column lists ``state.prompts`` with the stored
        prefix removed, or a warning when there are none.

        Args:
            state: Session-state object holding the image and the prompts.
        """
        image_column, labels_column = st.beta_columns([2, 1])
        with image_column:
            st.markdown("Image to classify")
            if state.image is not None:
                st.image(state.image, use_column_width=True)
            else:
                st.warning("Select an image")
        with labels_column:
            st.markdown("Labels to choose from")
            # The labels are written to `state.prompts` by `prompts_input`;
            # nothing in this file writes `state.processed_classes`, which the
            # previous condition tested.
            if state.prompts:
                prefix_length = len(state.prompt_prefix)
                for prompt in state.prompts:
                    st.write(prompt[prefix_length:])
            else:
                st.warning("Enter the classes to classify from")

    @staticmethod
    def classification_output(state: SessionState):
        """Render the "Predict" button and, once pressed, the CLIP results.

        On press, ``booste.clip`` is called with ``state.prompts`` and
        ``state.image``. The response is iterated as a mapping from prompt to
        a mapping whose first value holds ``"probabilityRelativeToPrompts"``.
        Prompts are listed from highest to lowest probability, followed by the
        raw response.

        Args:
            state: Session-state object holding the image and the prompts.
        """
        # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
        if not st.button("Predict"):
            return
        # Do not call the remote API with a missing image or without prompts.
        if state.image is None or not state.prompts:
            st.warning("Select an image and enter at least one class before predicting")
            return
        with st.spinner("Predicting..."):
            clip_response = booste.clip(BOOSTE_API_KEY,
                                        prompts=state.prompts,
                                        images=[state.image],
                                        pretty_print=True)
        st.markdown("### Results")
        prefix_length = len(state.prompt_prefix)
        simplified_clip_results = sorted(
            (
                (prompt[prefix_length:],
                 next(iter(results.values()))["probabilityRelativeToPrompts"])
                for prompt, results in clip_response.items()
            ),
            key=lambda result: result[1],
            reverse=True,
        )
        for prompt, probability in simplified_clip_results:
            percentage_prob = int(probability * 100)
            # The percentage used to be computed but never rendered.
            st.markdown(f"### {prompt}: {percentage_prob}%")
        st.write(clip_response)
# Page driver: the sidebar radio chooses which task layout is rendered.
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
session_state = get_state()

if task_name in ("Image classification", "Prompt ranking"):
    # Both single-image tasks share one layout and differ only in the
    # prompt input's label and prefix.
    Sections.header()
    Sections.image_uploader(accept_multiple_files=False)
    st.markdown("or choose one from")
    Sections.image_picker(session_state)
    if task_name == "Image classification":
        input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
        Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
    else:
        input_label = (
            "Enter the prompts to choose from separated by a semi-colon. "
            "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
        )
        Sections.prompts_input(session_state, input_label)
    Sections.input_preview(session_state)
    Sections.classification_output(session_state)
elif task_name == "Image ranking":
    # Image ranking currently renders only the upload and picker sections.
    Sections.header()
    Sections.image_uploader(accept_multiple_files=True)
    st.markdown("or use random dataset")
    Sections.image_picker(session_state)

# Runs for every task, after the selected layout has been rendered.
session_state.sync()
|