anon5 commited on
Commit
f84bf2d
·
verified ·
1 Parent(s): 0a97144

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ import torchvision.models as models
7
+ import torchvision.transforms as transforms
8
+
9
+ CLASSES = ['guro', 'pigs', 'proofs', 'protyk', 'safe', 'shit']
10
+ NUM_CLASSES = len(CLASSES)
11
+
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ model = models.resnet18(pretrained=True)
15
+ model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
16
+ model.load_state_dict(torch.load('best_model.pth'))
17
+ model.to(device)
18
+ model.eval()
19
+
20
+ # Определение трансформаций для изображений
21
+ transform = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+
28
+ # Функция для предсказания
29
+ def predict(img):
30
+ img = Image.fromarray(img)
31
+ img = transform(img)
32
+
33
+ with torch.no_grad():
34
+ outputs = model(img.unsqueeze(0).to(device))
35
+ probabilities = torch.softmax(outputs, dim=1).to('cpu')
36
+ labels = [CLASSES[i] for i in range(len(CLASSES))]
37
+ result = [dict(zip(labels, probabilities.numpy()[0])), dict(zip(labels, probabilities.numpy()[0]))]
38
+
39
+ return result[0]
40
+
41
+ # Интерфейс Gradio
42
+ gr.Interface(fn=predict, inputs="image", outputs="label").launch()
43
+