File size: 999 Bytes
9912c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import hiera
# Class-name table: the first comma-separated field of each line of
# Imagenet.txt (no header row). Row i is used as the label for model output
# index i -- assumes the file lists classes in the model's ImageNet-1k order
# (TODO confirm).
df = pd.read_csv('Imagenet.txt', usecols=[0], header=None)

# Hiera-Base at 224px; the checkpoint name suggests MAE pretraining followed by
# ImageNet-1k fine-tuning.
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
# nn.Module instances start in training mode; switch to evaluation mode so any
# train-only behaviour (dropout, drop-path, ...) is disabled for inference.
model.eval()

input_size = 224
# Eval-time spatial preprocessing: resize the shorter side to 256 (scaled for
# the crop size) with bicubic interpolation, then take the central
# input_size x input_size crop.
transform_list = [
    transforms.Resize(
        int((256 / 224) * input_size),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.CenterCrop(input_size),
]
# Full pipeline: spatial steps above, then PIL image -> float tensor
# (values scaled to [0, 1] for 8-bit images), then per-channel normalisation
# with the ImageNet mean/std used by timm.
transform_norm = transforms.Compose(transform_list + [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
def recognize(img):
    """Classify an image with the Hiera model and return its class label.

    Args:
        img: PIL image supplied by the Gradio ``'pil'`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        The label from the first column of ``Imagenet.txt`` for the
        top-scoring class, or an empty string when no image was given.
    """
    if img is None:
        return ""
    # Pass the original image straight into the eval transform, which already
    # resizes the short side and center-crops. Pre-resizing to 224x224 would
    # squash the aspect ratio and resample the image twice.
    # ToTensor/Normalize expect 3 channels; uploads may be RGBA, L or P mode.
    img_norm = transform_norm(img.convert("RGB"))
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
        output = model(img_norm.unsqueeze(0))
    # assumes output is (1, num_classes) logits -- TODO confirm
    class_index = output.argmax(dim=-1).item()
    return df.iloc[class_index, 0]
# Web UI: the user uploads an image (delivered to `recognize` as a PIL image)
# and the predicted class name is shown as text. 'Banana.jpg' is offered as a
# clickable example input.
demo = gr.Interface(
    fn=recognize,
    inputs='pil',
    outputs='text',
    examples=[
        ['Banana.jpg'],
    ],
)
# Starts the Gradio server when this module is executed.
demo.launch()