baixintech_zhangyiming_prod commited on
Commit
215a9c6
·
1 Parent(s): 9ae7c96
Files changed (2) hide show
  1. requirements.txt +7 -0
  2. view_angle_app.py +25 -0
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ huggingface-hub
5
+ opencv-python
6
+ timm>=0.6.12
7
+ transformers
view_angle_app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from torchvision import transforms
3
+ from transformers import ViTForImageClassification
4
+
5
+ model_path = "Inf009/view-angle"
6
+ model = ViTForImageClassification.from_pretrained(model_path)
7
+ model.eval()
8
+ val_transforms = transforms.Compose(
9
+ [
10
+ transforms.Resize((224, 224)),
11
+ transforms.CenterCrop(224),
12
+ transforms.ToTensor(),
13
+ ]
14
+ )
15
+
16
+ def predict_view_angle(image):
17
+ image = val_transforms(image)
18
+ outputs = model(image.unsqueeze(0)).logits.squeeze(0).sigmoid().detach().numpy()
19
+ indices = sorted(range(len(outputs)), key=lambda x: outputs[x], reverse=True)
20
+ predict_tags = ["45度俯视", "俯视", "正视"]
21
+ return predict_tags[indices[0]]
22
+
23
+
24
+ app = gr.Interface(fn=predict_view_angle, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Textbox())
25
+ app.launch()