GeoGuessrRobot / app.py
robocan's picture
Update app.py
4b52e6f verified
raw
history blame
3.64 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
from io import BytesIO
import folium
from PIL import Image
from torchvision import transforms,models
from sklearn.preprocessing import LabelEncoder,MinMaxScaler
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
# Hugging Face access token, read from the "token" environment variable
# (None when unset, which snapshot_download accepts for public repos).
token = os.environ.get("token")
# Download the model repository snapshot into ./SVD. The returned path is
# used below so every artifact is read from the same place.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_coordinate",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
device = 'cpu'
# Fitted preprocessing objects saved with joblib: `le` exposes `classes_`,
# and `MMS` is used via `inverse_transform` to map model outputs back to
# coordinates.
# NOTE(review): joblib.load unpickles, so these files must come from a
# trusted repository.
le = joblib.load(os.path.join(local_dir, "le.gz"))
MMS = joblib.load(os.path.join(local_dir, "MMS.gz"))
# Width of ModelPre's classification head. The "+ 1" must match the layer
# shapes stored in the checkpoint. Presumably it is an extra/reserved class;
# confirm before changing.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """ConvNeXt-Small trunk followed by a two-layer classification head.

    All top-level children of torchvision's ``convnext_small`` except the last
    one (its classifier) are kept. The result is then flattened into 768
    features, projected to 512 with a ReLU, and mapped to ``len_classes``
    logits. No parameters are frozen here.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Everything but the final child (the torchvision classifier head).
        trunk = list(backbone.children())[:-1]
        head = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        ]
        # A single Sequential named `embedding` keeps the state_dict keys
        # ("embedding.<i>...") that GeoGcord and the checkpoint rely on.
        self.embedding = torch.nn.Sequential(*trunk, *head)

    def forward(self, data):
        """Return class logits for a batch of images."""
        return self.embedding(data)
# Instantiate the classification network (its ConvNeXt trunk starts from
# torchvision's ImageNet weights). Only its layers matter: GeoGcord reuses
# everything except the final Linear, and the trained checkpoint is loaded
# into that GeoGcord afterwards, not into this object.
model = ModelPre()
# NOTE: freezing is disabled; nothing in this file turns off requires_grad.
# The earlier sketch was:
#for param in model.parameters():
# param.requires_grad = False
class GeoGcord(torch.nn.Module):
    """Coordinate regressor built on a ModelPre trunk.

    The trunk is every layer of ``ModelPre.embedding`` except its final
    classification Linear, so it ends with the 512-unit Linear + ReLU. An MLP
    (512 -> 256 -> 128 -> 2) follows, and its 2 outputs are later mapped back
    with ``MMS.inverse_transform``.
    """

    def __init__(self, base=None):
        """Build the regressor.

        Args:
            base: optional ModelPre whose layers are reused (shared, not
                copied). Defaults to the module-level ``model`` while that is
                still a ModelPre; otherwise a fresh ModelPre is built.
        """
        super().__init__()
        if base is None:
            # The global ``model`` is rebound to a GeoGcord right after the
            # first construction. Slicing that object would yield a wrong
            # network, so only reuse it while it is a ModelPre.
            candidate = globals().get("model")
            base = candidate if isinstance(candidate, ModelPre) else ModelPre()
        self.embedding = torch.nn.Sequential(
            *base.embedding[:-1],
            torch.nn.Linear(in_features=512, out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, data):
        """Return the 2 regression outputs (scaled coordinates) per image."""
        return self.embedding(data)
# Build the coordinate regressor and load the trained weights from the
# downloaded snapshot.
model = GeoGcord()
# NOTE(review): torch.load unpickles the checkpoint; consider
# weights_only=True if the file only holds tensors/dicts.
model_w = torch.load(
    os.path.join(local_dir, "GeoG.pth"), map_location=torch.device(device)
)
model.load_state_dict(model_w['model'])
# Modules start in training mode, and torch.inference_mode() does not change
# that. torchvision's ConvNeXt blocks use stochastic depth, which randomly
# drops residual branches while training, so switch to eval mode for
# deterministic predictions.
model.eval()
# ImageNet per-channel statistics, matching the ImageNet-initialised backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: image -> float tensor -> 224x224 -> normalised.
cmp = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224), antialias=True),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Predict function for the regression model
def predict(input_img):
    """Predict coordinates for a single image.

    Args:
        input_img: a PIL image (as supplied by the Gradio input) or any
            other input ``cmp`` accepts.

    Returns:
        1-D numpy array of length 2: the model's 2 outputs after
        ``MMS.inverse_transform`` (unpacked by callers as lat, lon).
    """
    # `cmp` normalises with 3-channel statistics, so RGBA/palette/grayscale
    # PIL images (e.g. PNGs with alpha) must be converted to RGB first.
    # Inputs without a `mode` attribute (e.g. numpy arrays) are left as-is.
    if getattr(input_img, "mode", "RGB") != "RGB":
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        batch = cmp(input_img).unsqueeze(0).to(device)
        output = model(batch)
        scaled = output.cpu().numpy()
    return MMS.inverse_transform(scaled).flatten()
# Function to generate HTML for map
def create_map_html(lat, lon, zoom_start=12, height=450):
    """Return an HTML snippet with an interactive folium map centred on a marker.

    The complete folium page (Leaflet CSS/JS plus map script) is embedded in an
    ``<iframe srcdoc=...>``. HTML injected via innerHTML does not execute its
    ``<script>`` tags, so a bare folium document would show no map; inside an
    iframe the scripts run normally.

    Args:
        lat: latitude of the map centre and marker.
        lon: longitude of the map centre and marker.
        zoom_start: initial zoom level (default 12, as before).
        height: iframe height in pixels.

    Returns:
        str: a self-contained ``<iframe>`` element.
    """
    import html

    # Plain floats avoid surprises with numpy scalar types in folium.
    lat, lon = float(lat), float(lon)
    fmap = folium.Map(location=[lat, lon], zoom_start=zoom_start)
    folium.Marker([lat, lon]).add_to(fmap)
    # Same document that Map.save() would write, obtained as a string.
    page = fmap.get_root().render()
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        f'style="width:100%;height:{int(height)}px;border:none;"></iframe>'
    )
# Build the HTML result: coordinate heading followed by the map.
def create_label_output(predictions):
    """Render predicted coordinates as an HTML fragment.

    Args:
        predictions: two-element sequence, unpacked as (lat, lon).

    Returns:
        str: a ``<div>`` holding an ``<h3>`` with both values to six decimals,
        followed by the map HTML from ``create_map_html``.
    """
    latitude, longitude = predictions
    map_fragment = create_map_html(latitude, longitude)
    return f"<div><h3>Predicted coordinates: ({latitude:.6f}, {longitude:.6f})</h3>{map_fragment}</div>"
# Predict and plot function (Gradio callback)
def predict_and_plot(input_img):
    """Run the model on an image and return the result as HTML.

    Args:
        input_img: the uploaded image, or None when no image was provided.

    Returns:
        str: coordinates plus map HTML, or a short notice when ``input_img``
        is None (instead of failing inside the preprocessing pipeline).
    """
    if input_img is None:
        return "<div><h3>Please upload an image first.</h3></div>"
    return create_label_output(predict(input_img))
# Gradio UI: one PIL image in, one HTML fragment (coordinates + map) out.
# Example paths are relative to the working directory.
_EXAMPLE_IMAGES = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]

gradio_app = Interface(
    fn=predict_and_plot,
    inputs=Image(type="pil", label="Upload an Image"),
    outputs=HTML(),
    examples=_EXAMPLE_IMAGES,
    title="Predict the Location of this Image",
)


def _main():
    """Start the Gradio server when this file is executed as a script."""
    gradio_app.launch()


if __name__ == "__main__":
    _main()