jarif commited on
Commit
1c1ac31
·
verified ·
1 Parent(s): 5da88a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -49
app.py CHANGED
@@ -1,49 +1,48 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
- import gradio as gr
6
-
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
-
9
- model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
10
- n_classes = 10
11
- model.fc = nn.Linear(model.fc.in_features, n_classes)
12
- model = model.to(device)
13
-
14
- model.load_state_dict(torch.load("NumtaDB_Classifier_Model.pth", map_location=device))
15
- model.eval()
16
-
17
- transform = transforms.Compose([
18
- transforms.Resize((299, 299)),
19
- transforms.Grayscale(num_output_channels=3),
20
- transforms.ToTensor(),
21
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22
- ])
23
-
24
- label_name = ["Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Nine", "Ten"]
25
-
26
- def predict(image):
27
- if not isinstance(image, Image.Image):
28
- image = Image.fromarray(image)
29
-
30
- image_tensor = transform(image).unsqueeze(0).to(device)
31
-
32
- with torch.no_grad():
33
- outputs = model(image_tensor)
34
- probs = torch.softmax(outputs, dim=1)
35
-
36
- predictions = {label_name[i]: float(probs[0][i]) for i in range(len(label_name))}
37
-
38
- return predictions
39
-
40
- iface = gr.Interface(
41
- fn=predict,
42
- inputs=gr.Image(label="Upload Image"),
43
- outputs=gr.Label(num_top_classes=len(label_name)),
44
- title="BanglaDigitPro: Advanced Bengali Numeral Recognition",
45
- description="Upload an image of a handwritten Bangla digit to classify it.",
46
- examples=[["example_1.png"], ["example_2.png"], ["example_3.png"], ["example_4.png"], ["example_5.png"]]
47
- )
48
-
49
- iface.launch(share=True)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
10
+ n_classes = 10
11
+ model.fc = nn.Linear(model.fc.in_features, n_classes)
12
+ model = model.to(device)
13
+
14
+ model.load_state_dict(torch.load("NumtaDB_Classifier_Model.pth", map_location=device))
15
+ model.eval()
16
+
17
+ transform = transforms.Compose([
18
+ transforms.Resize((299, 299)),
19
+ transforms.Grayscale(num_output_channels=3),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ label_name = ["Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Nine", "Ten"]
25
+
26
+ def predict(image):
27
+ if not isinstance(image, Image.Image):
28
+ image = Image.fromarray(image)
29
+
30
+ image_tensor = transform(image).unsqueeze(0).to(device)
31
+
32
+ with torch.no_grad():
33
+ outputs = model(image_tensor)
34
+ probs = torch.softmax(outputs, dim=1)
35
+
36
+ predictions = {label_name[i]: float(probs[0][i]) for i in range(len(label_name))}
37
+
38
+ return predictions
39
+
40
+ iface = gr.Interface(
41
+ fn=predict,
42
+ inputs=gr.Image(label="Upload Image"),
43
+ outputs=gr.Label(num_top_classes=len(label_name)),
44
+ title="BanglaDigitPro: Advanced Bengali Numeral Recognition",
45
+ description="Upload an image of a handwritten Bangla digit to classify it.",
46
+ examples=[["example_1.png"], ["example_2.png"], ["example_3.png"], ["example_4.png"], ["example_5.png"]])
47
+
48
+ iface.launch(share=True)