"""Gradio demo: classify an uploaded image with the microsoft/resnet-50 checkpoint."""

import gradio as gr
import torch
from transformers import AutoFeatureExtractor, ResNetForImageClassification

# Number of highest-probability classes reported back to the UI.
TOP_K = 3

# Loaded once at import time so every request reuses the same weights.
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")


def segment(image):
    """Return the most likely labels for ``image`` with their probabilities.

    Despite its name, this function performs image classification, not
    segmentation; the name is kept because the interface below refers to it.

    Args:
        image: The image handed over by the Gradio ``"image"`` input; it is
            passed unchanged to ``feature_extractor``.

    Returns:
        A dict mapping the ``TOP_K`` most probable class labels to their
        softmax probabilities (as Python floats), which is the format the
        Gradio ``"label"`` output consumes.
    """
    inputs = feature_extractor(image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # The logits carry a leading batch axis (one row per input image). Select
    # the single row and normalise across the class axis; normalising across
    # the batch axis would turn every score into 1.0.
    probabilities = torch.nn.functional.softmax(logits[0], dim=-1)

    # Report the best-scoring classes instead of whichever ids come first.
    k = min(TOP_K, probabilities.numel())
    best = torch.topk(probabilities, k)

    return {
        model.config.id2label[class_id]: score
        for score, class_id in zip(best.values.tolist(), best.indices.tolist())
    }


gr.Interface(fn=segment, inputs="image", outputs="label").launch()