titipata commited on
Commit
21fc817
·
verified ·
1 Parent(s): 7f180e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -42
app.py CHANGED
@@ -47,56 +47,102 @@ model.load_state_dict(torch.load("thai_digit_net.pth"))
47
  model.eval()
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def predict(img):
51
  """
52
- Predict function takes image and return top 5 predictions
53
  as a dictionary:
 
54
  {label: confidence, label: confidence, ...}
55
  """
56
- if img.get("composite") is not None:
57
- if img["composite"].sum() == 0:
58
- return {"No input sketch": 0.0}
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- img_data = img['composite']
61
- img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
62
- img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)
63
 
64
- # Make prediction
65
- with torch.no_grad():
66
- probs = model(img_tensor).softmax(dim=1).squeeze()
67
 
68
- probs, indices = torch.topk(probs, 5) # select top 5
69
- probs, indices = probs.tolist(), indices.tolist() # transform to list
70
- return {LABELS[i]: float(v) for i, v in zip(indices, probs)}
71
-
72
- js_func = """
73
- function refresh() {
74
- const url = new URL(window.location);
75
-
76
- if (url.searchParams.get('__theme') !== 'dark') {
77
- url.searchParams.set('__theme', 'dark');
78
- window.location.href = url.href;
79
- }
80
- }
81
- """
82
-
83
- with gr.Blocks(js=js_func) as demo:
84
- gr.Interface(
85
  fn=predict,
86
- inputs=gr.Sketchpad(
87
- label="Draw Here",
88
- brush=gr.Brush(default_size=14, default_color="#FFFFFF", colors=["#FFFFFF"]),
89
- image_mode="L",
90
- layers=False,
91
- eraser=None,
92
- width=400,
93
- height=350
94
- ),
95
- outputs=gr.Label(label="Guess"),
96
- title="Thai Digit Handwritten Classification",
97
- description="ทดลองวาดภาพตัวอักษรเลขไทยลงใน Sketchpad ด้านล่างเพื่อทำนายผลตัวเลข ตั้งแต่ ๐ (ศูนย์) ๑ (หนึ่ง) ๒ (สอง) ๓ (สาม) ๔ (สี่) ๕ (ห้า) ๖ (หก) ๗ (เจ็ด) ๘ (แปด) จนถึง ๙ (เก้า)",
98
- live=True
99
  )
100
 
101
- if __name__ == "__main__":
102
- demo.launch()
 
47
  model.eval()
48
 
49
 
50
+ import numpy as np
51
+ import torch
52
+ from pathlib import Path
53
+ import torch.nn as nn
54
+ import torch.nn.functional as F
55
+ from PIL import Image
56
+ from torchvision import transforms
57
+ import gradio as gr
58
+
59
+
60
+ transform = transforms.Compose([
61
+ transforms.Resize((28, 28)),
62
+ transforms.Grayscale(),
63
+ transforms.ToTensor()
64
+ ])
65
+ labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)", "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
66
+ LABELS = {i:k for i, k in enumerate(labels)} # dictionary of index and label
67
+
68
+
69
+ # Load model using DropoutThaiDigit instead
70
+ class DropoutThaiDigit(nn.Module):
71
+ def __init__(self):
72
+ super(DropoutThaiDigit, self).__init__()
73
+ self.fc1 = nn.Linear(28 * 28, 392)
74
+ self.fc2 = nn.Linear(392, 196)
75
+ self.fc3 = nn.Linear(196, 98)
76
+ self.fc4 = nn.Linear(98, 10)
77
+ self.dropout = nn.Dropout(0.1)
78
+
79
+ def forward(self, x):
80
+ x = x.view(-1, 28 * 28)
81
+ x = self.fc1(x)
82
+ x = F.relu(x)
83
+ x = self.dropout(x)
84
+ x = self.fc2(x)
85
+ x = F.relu(x)
86
+ x = self.dropout(x)
87
+ x = self.fc3(x)
88
+ x = F.relu(x)
89
+ x = self.dropout(x)
90
+ x = self.fc4(x)
91
+ return x
92
+
93
+
94
+ model = DropoutThaiDigit()
95
+ model.load_state_dict(torch.load("thai_digit_net.pth"))
96
+ model.eval()
97
+
98
+
99
  def predict(img):
100
  """
101
+ Predict function takes image editor data and returns top 5 predictions
102
  as a dictionary:
103
+
104
  {label: confidence, label: confidence, ...}
105
  """
106
+ if img is None:
107
+ return {}
108
+
109
+ # Handle if Sketchpad returns a dictionary
110
+ if isinstance(img, dict):
111
+ # Try common keys that might contain the image
112
+ img = img.get('image') or img.get('composite') or img.get('background')
113
+ if img is None:
114
+ return {}
115
+
116
+ img = 1 - transform(img) # do not need to use 1 - transform(img) because gradio already do it
117
+ probs = model(img).softmax(dim=1).ravel()
118
+ probs, indices = torch.topk(probs, 5) # select top 5
119
+ confidences = {LABELS[i]: float(prob) for i, prob in zip(indices.tolist(), probs.tolist())}
120
+ return confidences
121
 
 
 
 
122
 
123
+ with gr.Blocks(title="Thai Digit Handwritten Classification") as interface:
124
+ gr.Markdown("# Thai Digit Handwritten Classification")
125
+ gr.Markdown("Draw a Thai digit (๐-๙) in the box below:")
126
 
127
+ with gr.Row():
128
+ with gr.Column():
129
+ input_component = gr.Sketchpad(
130
+ label="Draw Here",
131
+ height=300,
132
+ width=300,
133
+ brush=gr.Brush(default_size=8, colors=["#000000"]),
134
+ type="pil",
135
+ canvas_size=(300, 300),
136
+ )
137
+
138
+ with gr.Column():
139
+ output_component = gr.Label(label="Prediction", num_top_classes=5)
140
+
141
+ # Set up the prediction
142
+ input_component.change(
 
143
  fn=predict,
144
+ inputs=input_component,
145
+ outputs=output_component
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
 
148
+ interface.launch()