Devin Xie commited on
Commit
219e1f9
·
1 Parent(s): e5dae6e

initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. .streamlit/config.toml +2 -0
  3. app.py +53 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ draw_venv/
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [theme]
2
+ backgroundColor="rgb(150, 150, 150)"
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
2
+ from PIL import Image
3
+ import streamlit as st
4
+ import torch
5
+ from streamlit_drawable_canvas import st_canvas
6
+
7
+ st.set_page_config(page_title="Draw Something!", layout="centered")
8
+
9
+ if "prediction" not in st.session_state:
10
+ st.session_state["prediction"] = "Draw something!"
11
+
12
+ st.markdown(f"<h1 style='text-align: center;'>{st.session_state['prediction']}</h1>", unsafe_allow_html=True)
13
+
14
+ processor = AutoImageProcessor.from_pretrained("kmewhort/resnet34-sketch-classifier")
15
+ model = AutoModelForImageClassification.from_pretrained("kmewhort/resnet34-sketch-classifier")
16
+
17
+ canvas = st_canvas(
18
+ stroke_width=5,
19
+ stroke_color="#000000",
20
+ background_color="#FFFFFF",
21
+ height=700,
22
+ width=700,
23
+ drawing_mode="freedraw",
24
+ )
25
+
26
+ def predict_drawing():
27
+ if canvas.image_data is not None:
28
+ drawing = canvas.image_data.astype("uint8")
29
+ image = Image.fromarray(drawing).convert("L")
30
+ image = image.convert("RGB")
31
+
32
+ inputs = processor(images=image, return_tensors="pt")
33
+
34
+ with torch.no_grad():
35
+ logits = model(**inputs).logits
36
+
37
+ predicted_class_idx = logits.argmax(-1).item()
38
+ st.session_state["prediction"] = model.config.id2label[predicted_class_idx]
39
+ else:
40
+ st.session_state["prediction"] = "Draw something!"
41
+
42
+ if canvas.image_data is not None:
43
+ predict_drawing()
44
+
45
+ css = '''
46
+ <style>
47
+ section.stMain {
48
+ overflow: hidden;
49
+ }
50
+ </style>
51
+ '''
52
+ st.markdown(css, unsafe_allow_html=True)
53
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Pillow==11.1.0
2
+ streamlit==1.42.0
3
+ streamlit_drawable_canvas==0.9.3
4
+ transformers==4.48.3
5
+ torch==2.6.0
6
+ torchvision==0.21.0