File size: 1,540 Bytes
219e1f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import streamlit as st
import torch
from streamlit_drawable_canvas import st_canvas

st.set_page_config(page_title="Draw Something!", layout="centered")

if "prediction" not in st.session_state:
    st.session_state["prediction"] = "Draw something!"

# Reserve the title's position above the canvas, but fill it in only after
# this run's prediction has been computed (bottom of this section). Rendering
# the title directly here would always show the label from the *previous*
# script run, because the canvas result only becomes available further down.
title_slot = st.empty()

MODEL_NAME = "kmewhort/resnet34-sketch-classifier"


@st.cache_resource
def _load_classifier():
    """Load the image processor and classification model once, shared across reruns.

    Streamlit re-executes this script on every interaction (including each
    canvas update), so loading at top level without caching would re-load the
    pretrained weights on every stroke.

    Returns:
        A ``(processor, model)`` tuple.
    """
    return (
        AutoImageProcessor.from_pretrained(MODEL_NAME),
        AutoModelForImageClassification.from_pretrained(MODEL_NAME),
    )


processor, model = _load_classifier()

canvas = st_canvas(
    stroke_width=5,
    stroke_color="#000000",
    background_color="#FFFFFF",
    height=700,
    width=700,
    drawing_mode="freedraw",
)


def _canvas_has_strokes(result):
    """Return True when the canvas result holds pixel data with at least one drawn object.

    ``image_data`` is present even for an empty (or freshly cleared) canvas,
    so the object list in ``json_data`` is used to tell a blank canvas apart.
    If no object list is available, fall back to classifying whatever pixels
    exist.
    """
    if result.image_data is None:
        return False
    json_data = result.json_data
    if not json_data:
        return True
    return bool(json_data.get("objects"))


def _to_rgb_on_white(image_data):
    """Flatten the canvas pixel array onto a white background and return an RGB image.

    NOTE(review): st_canvas pixel data is presumably (H, W, 4) RGBA. If the
    background pixels come back transparent, dropping alpha via
    ``convert("L")`` would turn them black (same colour as the strokes).
    Compositing onto opaque white first avoids that. If the background is
    already opaque, compositing leaves the pixels unchanged.
    """
    rgba = Image.fromarray(image_data.astype("uint8")).convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    flattened = Image.alpha_composite(white, rgba)
    # Collapse to grayscale, then expand back to three channels for the processor.
    return flattened.convert("L").convert("RGB")


def predict_drawing():
    """Classify the current canvas drawing and store the label in session state.

    Sets ``st.session_state["prediction"]`` to the model's top-1 label for the
    drawing, or to "Draw something!" when nothing has been drawn yet.
    """
    if not _canvas_has_strokes(canvas):
        st.session_state["prediction"] = "Draw something!"
        return

    image = _to_rgb_on_white(canvas.image_data)
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    st.session_state["prediction"] = model.config.id2label[predicted_class_idx]


predict_drawing()

# Now that this run's prediction is known, fill the title slot reserved above.
title_slot.markdown(
    f"<h1 style='text-align: center;'>{st.session_state['prediction']}</h1>",
    unsafe_allow_html=True,
)

# Inject a page-level stylesheet that sets `overflow: hidden` on the main
# content section, so anything that overflows it (e.g. the fixed-size canvas)
# is clipped instead of producing scrollbars.
css = "\n".join(
    [
        "",
        "<style>",
        "section.stMain {",
        "    overflow: hidden;",
        "}",
        "</style>",
        "",
    ]
)
st.markdown(body=css, unsafe_allow_html=True)