"""Streamlit "Draw Something!" app.

The user sketches on a canvas; on every rerun the current drawing is
classified with a Hugging Face sketch classifier and the predicted label is
shown above the canvas.
"""
from PIL import Image
import streamlit as st
import torch
from streamlit_drawable_canvas import st_canvas
from transformers import AutoImageProcessor, AutoModelForImageClassification

MODEL_ID = "kmewhort/resnet34-sketch-classifier"
PLACEHOLDER_TEXT = "Draw something!"
CANVAS_SIZE = 700

# Must be the first Streamlit command executed by the script.
st.set_page_config(page_title="Draw Something!", layout="centered")


@st.cache_resource
def _load_model(model_id: str = MODEL_ID):
    """Return ``(processor, model)`` for *model_id*, loaded once per server.

    Streamlit re-executes this whole script on every canvas interaction;
    caching avoids rebuilding both objects from disk on each rerun.
    """
    return (
        AutoImageProcessor.from_pretrained(model_id),
        AutoModelForImageClassification.from_pretrained(model_id),
    )


def _to_model_image(image_data) -> Image.Image:
    """Convert raw canvas pixels into the 3-channel grayscale image fed to the model.

    NOTE(review): ``image_data`` is assumed to be the (height, width, 4)
    RGBA array returned by ``st_canvas`` — confirm for the installed version.
    """
    rgba = Image.fromarray(image_data.astype("uint8")).convert("RGBA")
    # Flatten onto white so any transparent (undrawn) pixels read as paper
    # instead of black; when the canvas is already opaque this is a no-op.
    flattened = Image.alpha_composite(Image.new("RGBA", rgba.size, "white"), rgba)
    # Grayscale, then replicate to three channels for the RGB-input model.
    return flattened.convert("L").convert("RGB")


def _is_blank(image: Image.Image) -> bool:
    """Return True when every pixel of *image* is pure white (nothing drawn)."""
    return image.convert("L").getextrema() == (255, 255)


def _classify(image: Image.Image) -> str:
    """Return the model's top-1 label for *image*."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(-1).item()]


def predict_drawing():
    """Classify the current canvas and store the label in session state.

    Sets ``st.session_state["prediction"]`` to the predicted label, or to
    the placeholder text when the canvas has no data yet or is still blank.
    """
    if canvas.image_data is None:
        st.session_state["prediction"] = PLACEHOLDER_TEXT
        return
    image = _to_model_image(canvas.image_data)
    st.session_state["prediction"] = (
        PLACEHOLDER_TEXT if _is_blank(image) else _classify(image)
    )


if "prediction" not in st.session_state:
    st.session_state["prediction"] = PLACEHOLDER_TEXT

# Reserve the heading position above the canvas. It is filled only after
# this run's prediction is computed, so the label never lags one rerun
# behind the drawing.
prediction_slot = st.empty()

processor, model = _load_model()

canvas = st_canvas(
    stroke_width=5,
    stroke_color="#000000",
    background_color="#FFFFFF",
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode="freedraw",
)

predict_drawing()

# NOTE(review): the HTML markup that originally wrapped the label appears to
# have been lost (only blank lines remain around it); restore it here.
prediction_slot.markdown(
    f"\n\n{st.session_state['prediction']}\n\n", unsafe_allow_html=True
)

# NOTE(review): this stylesheet is effectively empty — its contents appear to
# have been lost together with the markup above.
css = ''' '''
st.markdown(css, unsafe_allow_html=True)