|
import torch |
|
import gradio as gr |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
|
|
class MyModel(torch.nn.Module):
    """Placeholder network that currently returns its input unchanged.

    NOTE(review): no submodules or parameters are registered here, so a
    non-empty ``model.pth`` state dict will not match this module (strict
    ``load_state_dict`` reports unexpected keys). The real architecture
    presumably still needs to be filled in — confirm.
    """

    def __init__(self) -> None:
        super(MyModel, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` as-is (identity mapping)."""
        result = x
        return result
|
|
|
# Build the network and load its trained weights from model.pth
# (resolved relative to the current working directory).
model = MyModel()
# map_location="cpu": a checkpoint saved from GPU tensors would otherwise fail
# to load on a CPU-only host. weights_only=True: the file is a plain state
# dict, so restrict unpickling to tensors/containers instead of executing
# arbitrary pickled objects.
model.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True)
)
# Switch modules such as dropout / batch-norm to inference behaviour.
model.eval()
|
|
|
|
|
# Fixed spatial size (height, width) every input image is resized to.
_INPUT_SIZE = (224, 224)

# Preprocessing applied to each input image before it reaches the model:
# resize to _INPUT_SIZE, then convert to a float tensor (ToTensor scales
# 8-bit images to [0, 1] and lays them out as channels x height x width).
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)
|
|
|
|
|
def predict(image):
    """Run the model on a single image and return its raw output.

    Args:
        image: Image from the Gradio ``Image`` component, either as a PIL
            image or as a NumPy array (the component's default ``type``).

    Returns:
        The model output for a batch of one, converted to nested Python
        lists so the JSON output component can render it.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please provide an image.")
    if not isinstance(image, Image.Image):
        # gr.Image() defaults to type="numpy", but the torchvision
        # Resize/ToTensor pipeline expects a PIL image.
        image = Image.fromarray(image)
    # Normalise to 3 channels so grayscale / RGBA inputs do not yield a
    # 1- or 4-channel tensor (assumes the model expects RGB — confirm).
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        output = model(batch)
    # .cpu() keeps this working if the model is ever moved to a GPU.
    return output.cpu().tolist()
|
|
|
|
|
# Web UI: one image input, the model's raw output shown as JSON.
# type="pil" makes Gradio pass predict() a PIL image, which is what the
# torchvision transform pipeline expects (the component's default is a
# NumPy array, which Resize/ToTensor's PIL path rejects).
iface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="json")

iface.launch()
|
|