"""Gradio demo that classifies the camera view angle of an uploaded image."""

import gradio as gr
import gradio.components as grc
import torch
from torchvision import transforms
from transformers import ViTForImageClassification

# Identifier passed to ``from_pretrained`` to load the fine-tuned checkpoint.
model_path = "Inf009/view-angle"

# Human-readable class names, indexed by the model's output position.
# NOTE(review): the order is assumed to match the checkpoint's label ids —
# confirm against the training configuration.
VIEW_ANGLE_TAGS = ("45度俯视", "俯视", "正视")

model = ViTForImageClassification.from_pretrained(model_path)
model.eval()  # evaluation mode: disables dropout and similar training-only layers

# Preprocessing applied to every uploaded image before inference.
# CenterCrop(224) is a no-op after the fixed-size Resize; it is kept so the
# pipeline stays identical to the original.
# NOTE(review): no mean/std normalization is applied — presumably this matches
# the preprocessing used during training; confirm.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


def predict_view_angle(image):
    """Return the most likely view-angle tag for *image*.

    Args:
        image: A PIL image as delivered by the Gradio ``Image(type="pil")``
            input, or ``None`` when the user submitted without an image.

    Returns:
        The entry of ``VIEW_ANGLE_TAGS`` whose logit is the largest.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the form is submitted empty; surface a
        # readable message instead of failing inside the transform pipeline.
        raise gr.Error("No image provided. Please upload an image first.")

    # Uploads may be RGBA, grayscale or palette images; force three channels
    # so the tensor always has the same channel count.
    pixels = val_transforms(image.convert("RGB"))

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(pixels.unsqueeze(0)).logits.squeeze(0)

    # Sigmoid is monotonic, so the arg-max of the raw logits selects the same
    # class as sorting the sigmoid scores; ties resolve to the first index.
    return VIEW_ANGLE_TAGS[int(logits.argmax())]


app = gr.Interface(
    fn=predict_view_angle,
    inputs=grc.Image(type="pil"),
    outputs=grc.Textbox(),
)

if __name__ == "__main__":
    app.launch()