# GeoGuessrRobot / app.py
# Last commit by robocan: "Update app.py" (bde2d24, verified) - 2.42 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
from PIL import Image
from torchvision import transforms,models
from sklearn.preprocessing import LabelEncoder
from gradio import Interface, Image, Label
from huggingface_hub import Repository
# Access token for the model repository, read from the "token" environment
# variable (None when the variable is not set).
token = os.getenv("token")
# Clone the model repository "robocan/GeoG_City" into the local ./SVD folder;
# the label encoder and checkpoint are loaded from there further down.
repo = Repository(
    clone_from="robocan/GeoG_City",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
# Pull the latest commits of the remote into the local checkout.
# NOTE(review): `Repository` is deprecated in recent huggingface_hub releases in
# favour of the HTTP-based helpers (e.g. snapshot_download) - consider migrating.
repo.git_pull()
device = 'cpu'
# Fitted label encoder used to turn predicted class indices back into location
# names. joblib.load unpickles the file, so it must only ever come from a trusted
# source (here: the SVD/ checkout of the model repository).
le = joblib.load("SVD/le.gz")
# Number of outputs of the classifier head. ModelPre sizes its final Linear layer
# with this value, so it has to agree with the checkpoint loaded into it.
# NOTE(review): the head has one output more than the encoder has classes; the
# purpose of that extra output is not visible here - confirm.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small trunk followed by a small two-layer classification head.

    All layers live in a single ``nn.Sequential`` stored as ``self.embedding``.
    ``nn.Sequential`` names its children by position, so the state_dict keys
    (``embedding.<index>...``) depend on the exact layer order below; it must
    not change, or loading a saved checkpoint will fail.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Keep every child of the ConvNeXt model except the last one
        # (presumably torchvision's own classifier head).
        trunk = list(backbone.children())[:-1]
        # New head: flatten the 768-wide features, one hidden layer, then one
        # logit per output class (len_classes is defined at module level).
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        ]
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return the raw (un-normalised) class logits for the batch ``data``."""
        return self.embedding(data)
# Load the saved checkpoint onto the CPU device and restore its 'model' entry
# into a fresh ModelPre network.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; this only
# works if the checkpoint holds plain tensors/containers - confirm for GeoG.pth.
model = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
modelm = ModelPre()
modelm.load_state_dict(model['model'])
# nn.Module instances start in training mode, and torch.inference_mode() (used
# in predict) does not change that. Switch to eval mode so training-only layers
# (dropout / stochastic depth, if present) are disabled and predictions are
# deterministic. Also keep the weights on the same device predict() sends the
# input tensors to.
modelm.to(device)
modelm.eval()
# Preprocessing for every uploaded image: convert it to a CHW float tensor,
# resize to the fixed 224x224 network input, then normalise each channel with
# the standard ImageNet mean / std.
cmp = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
def predict(input_img):
    """Predict the most likely locations for one image.

    Args:
        input_img: the image to classify (a PIL image when called from the
            Gradio interface, which is configured with ``type="pil"``).

    Returns:
        dict mapping up to 10 location labels (decoded with ``le``) to their
        softmax probability as a Python float, ordered from most to least
        likely. Fewer than 10 entries are returned when the encoder knows
        fewer than 10 labels.
    """
    # Force 3 channels so RGBA / grayscale images match the 3-element
    # mean/std used by cmp's Normalize step.
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)  # add a batch dimension of 1
        logits = modelm(img.to(device))
        probabilities = torch.softmax(logits, dim=1).cpu().numpy().flatten()
    # The head has len(le.classes_) + 1 outputs (see len_classes), but the
    # encoder can only decode indices below len(le.classes_). Ranking the
    # extra output would make inverse_transform raise ValueError whenever it
    # landed in the top 10, so it is excluded from the ranking.
    n_labels = len(le.classes_)
    top_indices = np.argsort(probabilities[:n_labels])[::-1][:10]
    top_labels = le.inverse_transform(top_indices)
    return {
        label: float(probabilities[idx])
        for label, idx in zip(top_labels, top_indices)
    }
def create_label_output(predictions):
    """Return ``predictions`` unchanged, as the value for the Gradio Label output.

    The very same object that was passed in is returned (no copy is made).
    """
    formatted = predictions
    return formatted
def predict_and_plot(input_img):
    """Gradio callback: run ``predict`` on ``input_img`` and format the result.

    Despite the name, nothing is plotted. The return value is whatever
    ``create_label_output`` produces from the prediction dict.
    """
    return create_label_output(predict(input_img))
# Gradio UI: one image upload in, the top-10 location labels out.
# `Image` and `Label` here are the gradio components; the gradio import of
# `Image` shadows the PIL one imported earlier.
gradio_app = Interface(
    title="Predict the Location of this Image",
    fn=predict_and_plot,
    inputs=Image(type="pil", label="Upload an Image"),
    outputs=Label(num_top_classes=10),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    gradio_app.launch()