# anon5's picture
# Create app.py
# f84bf2d verified
# raw
# history blame
# 1.37 kB
import gradio as gr
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import torchvision.transforms as transforms
# Class names; predict() pairs them positionally with the model's outputs.
CLASSES = ['guro', 'pigs', 'proofs', 'protyk', 'safe', 'shit']
NUM_CLASSES = len(CLASSES)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the bare ResNet-18 architecture without fetching ImageNet weights:
# load_state_dict() below is strict, so every parameter and buffer is
# overwritten by the checkpoint and a pretrained download would be wasted.
model = models.resnet18()
# Replace the final fully connected layer with one output per class.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
# Inference only: put dropout / batch-norm layers into evaluation mode.
model.eval()
# Image preprocessing pipeline applied to every input before inference.
# Per-channel statistics handed to Normalize (three values -> three channels).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Prediction function
def predict(img):
    """Classify an image and return a probability for every class.

    Args:
        img: Image data in a form accepted by ``PIL.Image.fromarray``
            (presumably the NumPy array supplied by the Gradio "image"
            input -- confirm against the Gradio version in use).

    Returns:
        dict: Maps each name in ``CLASSES`` to its softmax probability,
        as a plain Python ``float``.
    """
    # Normalize expects exactly three channels, so force RGB; this keeps
    # grayscale or RGBA inputs from failing inside the transform.
    pil_img = Image.fromarray(img).convert('RGB')
    # Add the batch dimension the model expects and move to its device.
    batch = transform(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Single-image batch: take row 0 and bring it back to the CPU.
    probabilities = torch.softmax(logits, dim=1)[0].cpu()
    # tolist() yields built-in floats rather than numpy.float32 scalars.
    return dict(zip(CLASSES, probabilities.tolist()))
# Gradio interface: image input fed to predict(), label output; then launch.
demo = gr.Interface(fn=predict, inputs="image", outputs="label")
demo.launch()