import os
import torch
import joblib
import gradio as gr
import plotly.graph_objects as go
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from huggingface_hub import snapshot_download
# Retrieve the Hugging Face access token from the environment variables
token = os.environ.get("token")

# Download the model repository snapshot to a local directory
local_dir = snapshot_download(
    repo_id="robocan/GeoG_coordinate",
    repo_type="model",
    local_dir="SVD",
    token=token,
)
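# The snapshot is expected to provide the artifacts loaded below:
# le.gz (a fitted LabelEncoder), MMS.gz (a fitted MinMaxScaler for the
# coordinate targets) and GeoG.pth (the trained model weights).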
device = 'cpu'

# Label encoder and coordinate scaler fitted during training
le = joblib.load("SVD/le.gz")
MMS = joblib.load("SVD/MMS.gz")
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            *list(models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )

    def forward(self, data):
        return self.embedding(data)
# Instantiate the classification model so its layers can be reused below
model = ModelPre()

# Optionally freeze all backbone parameters:
# for param in model.parameters():
#     param.requires_grad = False
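# Note: ModelPre's head ends in a len_classes-wide classification layer (the
# classes come from the loaded LabelEncoder). GeoGcord below reuses every
# layer except that final classifier as the feature extractor for a
# two-value latitude/longitude regression head.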
class GeoGcord(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            # Reuse ModelPre's layers, dropping its final classification layer
            *list(model.children())[0][:-1],
            torch.nn.Linear(in_features=512, out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, data):
        return self.embedding(data)
# Load the trained regression weights and switch to evaluation mode
model = GeoGcord()
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
model.load_state_dict(model_w['model'])
model.eval()
# Preprocessing: 224x224 resize and ImageNet normalization
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
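# Note: ToTensor converts the PIL image to a float tensor in [0, 1] before
# the resize, and the mean/std values are the standard ImageNet statistics.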
# Predict scaled coordinates and map them back to latitude/longitude
def predict(input_img):
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
        # res is the 2-value regression output; undo the MinMax scaling
        prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
        return prediction
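# Example usage outside the Gradio app (a sketch, assuming one of the bundled
# example images sits next to this script):
#     lat, lon = predict(Image.open("GB.PNG").convert("RGB"))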
# Build a Plotly map figure centered on the predicted coordinates
def create_map_figure(lat, lon):
    fig = go.Figure(go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        mode='markers',
        marker=go.scattermapbox.Marker(size=14),
        text=[f'Lat: {lat}, Lon: {lon}'],
        hoverinfo='text',
    ))
    fig.update_layout(
        # OpenStreetMap tiles do not require a Mapbox access token
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=lat, lon=lon),
            pitch=0,
            zoom=10,
        ),
    )
    return fig
# Unpack the prediction into latitude/longitude and build the map
def create_label_output(predictions):
    lat, lon = predictions
    return create_map_figure(lat, lon)

# End-to-end handler for the Gradio button
def predict_and_plot(input_img):
    predictions = predict(input_img)
    return create_label_output(predictions)
# Gradio app definition
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)
        examples = ["GB.PNG", "IT.PNG", "NL.PNG", "NZ.PNG"]
        gr.Examples(examples=examples, inputs=input_image)

gradio_app.launch()