# File size: 847 Bytes
# 215a9c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import gradio as gr
import torch
from torchvision import transforms
from transformers import ViTForImageClassification

# Identifier handed to `from_pretrained` for the view-angle classifier
# (presumably a Hugging Face Hub repo id -- confirm).
model_path = "Inf009/view-angle"

# Load the classifier once at import time and switch it to evaluation mode
# so every request reuses the same in-memory model.
model = ViTForImageClassification.from_pretrained(model_path)
model.eval()

# Preprocessing applied to each input image before it reaches the model:
# resize to 224x224, centre-crop to 224, then convert to a tensor.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

def predict_view_angle(image):
    """Classify the camera view angle shown in an image.

    Args:
        image: A PIL image (the Gradio input component in this file is
            configured with ``type="pil"``).

    Returns:
        The entry of ``predict_tags`` whose class index has the highest
        model logit.

    Raises:
        ValueError: If ``image`` is ``None``, e.g. the form was submitted
            without an image.
    """
    if image is None:
        raise ValueError("No image provided.")

    # NOTE(review): the label order is assumed to match the model's class
    # indices 0..2 -- confirm against the training configuration.
    predict_tags = ["45度俯视", "俯视", "正视"]

    # Normalise RGBA / grayscale / palette images; the model presumably
    # expects 3-channel input (TODO confirm against the model config).
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixels = val_transforms(image)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(pixels.unsqueeze(0)).logits.squeeze(0)

    # Sigmoid is monotonic, so the top class can be read from the raw logits.
    # Skipping it also avoids float32 saturation (several scores rounding to
    # exactly 1.0), which would make the winner depend on index order.
    best_index = int(logits.argmax())
    return predict_tags[best_index]

    
# Build the web UI: one uploaded image in, the predicted label out.
# The `gr.inputs` / `gr.outputs` namespaces were deprecated in Gradio 3 and
# removed in Gradio 4, so prefer the top-level components and only fall back
# to the legacy namespaces on old releases that lack them.
if hasattr(gr, "Image") and hasattr(gr, "Textbox"):
    image_input = gr.Image(type="pil")
    label_output = gr.Textbox()
else:
    image_input = gr.inputs.Image(type="pil")
    label_output = gr.outputs.Textbox()

app = gr.Interface(fn=predict_view_angle, inputs=image_input, outputs=label_output)
app.launch()