# Hugging Face Spaces app (page status at capture time: Sleeping).
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from PIL import Image | |
import streamlit as st | |
import torch | |
from streamlit_drawable_canvas import st_canvas | |
st.set_page_config(page_title="Draw Something!", layout="centered")

# Seed the headline text on the first run of a session; later runs reuse
# whatever predict_drawing() last stored.
if "prediction" not in st.session_state:
    st.session_state["prediction"] = "Draw something!"

st.markdown(
    f"<h1 style='text-align: center;'>{st.session_state['prediction']}</h1>",
    unsafe_allow_html=True,
)

MODEL_ID = "kmewhort/resnet34-sketch-classifier"


@st.cache_resource
def _load_sketch_classifier(model_id: str = MODEL_ID):
    """Load and return ``(processor, model)`` for the given Hub model id.

    Streamlit re-executes this script from the top on every interaction
    (including canvas updates). Without caching, the weights would be
    re-loaded on every rerun. ``st.cache_resource`` keeps one instance that
    is shared across reruns and sessions. The model is only used for
    inference under ``torch.no_grad()``, so sharing it is read-only.

    NOTE(review): ``st.cache_resource`` needs Streamlit >= 1.18.
    """
    return (
        AutoImageProcessor.from_pretrained(model_id),
        AutoModelForImageClassification.from_pretrained(model_id),
    )


processor, model = _load_sketch_classifier()
# 700x700 freehand drawing surface: black 5px strokes on a white background.
# The returned object exposes the current drawing as `canvas.image_data`.
canvas = st_canvas(
    drawing_mode="freedraw",
    stroke_color="#000000",
    stroke_width=5,
    background_color="#FFFFFF",
    width=700,
    height=700,
)
def predict_drawing(): | |
if canvas.image_data is not None: | |
drawing = canvas.image_data.astype("uint8") | |
image = Image.fromarray(drawing).convert("L") | |
image = image.convert("RGB") | |
inputs = processor(images=image, return_tensors="pt") | |
with torch.no_grad(): | |
logits = model(**inputs).logits | |
predicted_class_idx = logits.argmax(-1).item() | |
st.session_state["prediction"] = model.config.id2label[predicted_class_idx] | |
else: | |
st.session_state["prediction"] = "Draw something!" | |
# The headline near the top of the script was already rendered from the
# *previous* value of st.session_state["prediction"] before this runs, so it
# lags one canvas update behind.
# When the label changes, rerun once so the headline shows the new value.
# On that rerun the canvas data (and so the label) is presumably unchanged,
# so no further rerun is triggered.
# NOTE(review): st.rerun needs Streamlit >= 1.27.
if canvas.image_data is not None:
    previous_prediction = st.session_state["prediction"]
    predict_drawing()
    if st.session_state["prediction"] != previous_prediction:
        st.rerun()
# Inject page CSS that disables overflow (scrolling) on Streamlit's main
# section, presumably so the tall 700px canvas does not make the page scroll.
# NOTE(review): `section.stMain` is a Streamlit-internal class name and may
# change between Streamlit versions.
css = (
    "\n"
    "<style>\n"
    "section.stMain {\n"
    "overflow: hidden;\n"
    "}\n"
    "</style>\n"
)
st.markdown(css, unsafe_allow_html=True)