|
import numpy as np |
|
import torch |
|
from pathlib import Path |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from torchvision import transforms |
|
import gradio as gr |
|
|
|
|
|
# Preprocessing applied to each drawing before it reaches the model:
#   1. resize to 28x28, the input size the network's first layer expects;
#   2. collapse to a single grayscale channel;
#   3. convert to a float tensor of shape (1, 28, 28) with values in [0, 1].
# The pipeline expects a PIL image as input (torchvision's Resize does not
# accept NumPy arrays).
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)
|
# Display names for the ten output classes, ordered by class index:
# the Thai numeral followed by its Thai spelling in parentheses.
labels = [
    "๐ (ศูนย์)",
    "๑ (หนึ่ง)",
    "๒ (สอง)",
    "๓ (สาม)",
    "๔ (สี่)",
    "๕ (ห้า)",
    "๖ (หก)",
    "๗ (เจ็ด)",
    "๘ (แปด)",
    "๙ (เก้า)",
]

# Class index -> display label, used to name the model's top-k outputs.
LABELS = dict(enumerate(labels))
|
|
|
|
|
|
|
class DropoutThaiDigit(nn.Module):
    """Fully connected classifier for 28x28 Thai digit images.

    The flattened 784-pixel input passes through three hidden layers whose
    width halves each time (784 -> 392 -> 196 -> 98). Each hidden layer is
    followed by ReLU and dropout (p=0.1). A final linear layer produces
    10 class logits.

    The attribute names ``fc1`` .. ``fc4`` are the keys used by
    ``load_state_dict``, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so that parameter initialisation
        # draws from the RNG in the same sequence and state_dict keys keep
        # their order.
        self.fc1 = nn.Linear(28 * 28, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, 10)
        # One parameter-free dropout module, reused after every hidden layer.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return raw (un-normalised) class logits of shape (batch, 10).

        ``x`` is flattened with ``view(-1, 28 * 28)``, so its element count
        must be a multiple of 784.
        """
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc4(hidden)
|
|
|
|
|
# Locate the trained weights. The original CWD-relative name is tried first,
# so existing setups behave exactly as before. If it is missing, fall back to
# the file sitting next to this script, so the app also starts when launched
# from another directory (only when __file__ exists, e.g. not in a REPL).
_CHECKPOINT_NAME = "thai_digit_net.pth"
_checkpoint_path = Path(_CHECKPOINT_NAME)
if not _checkpoint_path.exists() and "__file__" in globals():
    _checkpoint_path = Path(__file__).resolve().with_name(_CHECKPOINT_NAME)

model = DropoutThaiDigit()
# The model is never moved off the CPU in this script. map_location="cpu"
# therefore lets a checkpoint saved during a GPU training run load on a
# CPU-only host instead of failing to deserialise CUDA tensors.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(_checkpoint_path, map_location="cpu"))
# Switch to evaluation mode so the dropout layers are inactive at inference.
model.eval()
|
|
|
|
|
def predict(img):
    """Classify a hand-drawn Thai digit and return the five most likely labels.

    Args:
        img: The drawing from the Sketchpad input: a NumPy array or a PIL
            image, or ``None``. (Gradio 3.x's ``Sketchpad`` delivers a NumPy
            array by default; confirm for the installed Gradio version.)

    Returns:
        ``None`` if ``img`` is ``None``. Otherwise, a dict of the top-5
        predictions ordered from most to least likely::

            {label: confidence, label: confidence, ...}

        Each label comes from ``LABELS`` and each confidence is a softmax
        probability (a Python float).
    """
    if img is None:
        return None

    # torchvision's Resize (the first step of ``transform``) only accepts PIL
    # images or tensors. Raw canvas arrays therefore have to become PIL
    # images first.
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    tensor = transform(img)

    # Pure inference: skip autograd graph construction.
    with torch.no_grad():
        probs = model(tensor).softmax(dim=1).ravel()

    top_probs, top_indices = torch.topk(probs, 5)
    return {
        LABELS[i]: p
        for i, p in zip(top_indices.tolist(), top_probs.tolist())
    }
|
|
|
|
|
# Gradio UI: a Sketchpad drawing input wired to ``predict``. The returned
# {label: probability} dict is displayed by a Label output. With live=True
# the interface re-runs ``predict`` whenever the input changes.
_canvas = gr.Sketchpad(label="Draw Here", brush_radius=5, shape=(120, 120))
_guess = gr.Label(label="Guess")

demo = gr.Interface(
    fn=predict,
    inputs=_canvas,
    outputs=_guess,
    title="Thai Digit Handwritten Classification",
    description="ทดลองวาดภาพตัวอักษรเลขไทยลงใน Sketchpad ด้านล่างเพื่อทำนายผลตัวเลข ตั้งแต่ ๐ (ศูนย์) ๑ (หนึ่ง) ๒ (สอง) ๓ (สาม) ๔ (สี่) ๕ (ห้า) ๖ (หก) ๗ (เจ็ด) ๘ (แปด) จนถึง ๙ (เก้า)",
    live=True,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()
|
|