import os
from io import BytesIO

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
import gradio as gr
import plotly.graph_objects as go
from PIL import Image
import PIL.Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE(review): this rebinds ``Image`` to gradio's component, shadowing the PIL
# ``Image`` imported above; use the fully qualified ``PIL.Image`` for PIL.
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download

# Hugging Face access token for the model repository (env var "token").
token = os.environ.get("token")

# Download the repository snapshot (encoder, scaler and weights) into ./SVD.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_coordinate",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = 'cpu'

# NOTE(review): joblib.load unpickles these files; only load them from a
# trusted repository.
le = joblib.load("SVD/le.gz")
MMS = joblib.load("SVD/MMS.gz")
# Output width of ModelPre's classification layer. The "+ 1" presumably matches
# the checkpoint this model was trained from — TODO confirm.
len_classes = len(le.classes_) + 1


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small classifier whose layers are reused by ``GeoGcord``.

    Layout of ``self.embedding``: ConvNeXt ``features`` and ``avgpool``, then
    Flatten -> Linear(768, 512) -> ReLU -> Linear(512, len_classes).

    Args:
        pretrained: load torchvision's ImageNet weights into the ConvNeXt part
            (this downloads them). Pass False when every parameter is about to
            be overwritten by a checkpoint anyway.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = (
            models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
        )
        convnext = models.convnext_small(weights=weights)
        self.embedding = torch.nn.Sequential(
            # All children except torchvision's own classifier head.
            *list(convnext.children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )

    def forward(self, data):
        return self.embedding(data)


class GeoGcord(torch.nn.Module):
    """Two-output regressor built on the layers of a ``ModelPre``.

    Reuses (shares, does not copy) every layer of ``backbone.embedding`` except
    its final classification Linear, then appends the head
    Linear(512, 256) -> ReLU -> Linear(256, 128) -> ReLU -> Linear(128, 2).

    Args:
        backbone: the ``ModelPre`` to reuse. When None, a fresh ``ModelPre()``
            is built. The backbone is passed explicitly rather than read from a
            module global, so constructing this class never depends on what a
            global name happens to be bound to at the time.
    """

    def __init__(self, backbone=None):
        super().__init__()
        if backbone is None:
            backbone = ModelPre()
        self.embedding = torch.nn.Sequential(
            *backbone.embedding[:-1],
            torch.nn.Linear(in_features=512, out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, data):
        return self.embedding(data)


# Build the regressor and load its trained weights. The ConvNeXt part skips the
# ImageNet download because load_state_dict (strict by default) overwrites
# every parameter.
model = GeoGcord(backbone=ModelPre(pretrained=False))
# NOTE(review): torch.load unpickles the checkpoint; only load it from a
# trusted repository.
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
model.load_state_dict(model_w['model'])
model.to(device)
# eval() is required: torch.inference_mode() only disables autograd. In
# training mode the StochasticDepth layers inside ConvNeXt randomly drop
# residual branches, which makes every prediction random.
model.eval()

# Preprocessing: tensor conversion, 224x224 resize, ImageNet mean/std normalize.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(input_img):
    """Run the regressor on one image and undo the target scaling.

    Args:
        input_img: a PIL image (or any input ``transforms.ToTensor`` accepts).

    Returns:
        A flat numpy array with the model's two outputs after
        ``MMS.inverse_transform``; the rest of this file unpacks it as
        ``lat, lon``.

    Raises:
        gr.Error: if no image was supplied (e.g. Predict clicked before upload).
    """
    if input_img is None:
        raise gr.Error("Please upload an image first.")
    # The network expects 3 channels. RGBA, grayscale or palette images would
    # fail in Normalize (3-value mean/std) or in the first convolution.
    if isinstance(input_img, PIL.Image.Image) and input_img.mode != "RGB":
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0)
        res = model(img.to(device))
    # MMS presumably is a MinMaxScaler fitted on [lat, lon] targets — TODO confirm.
    return MMS.inverse_transform(res.cpu().numpy()).flatten()


def create_map_figure(lat, lon):
    """Return a Plotly OpenStreetMap figure with one marker at (lat, lon)."""
    fig = go.Figure(go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        mode='markers',
        marker=go.scattermapbox.Marker(size=14),
        text=[f'Lat: {lat}, Lon: {lon}'],
        hoverinfo='text',
    ))
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=lat, lon=lon),
            pitch=0,
            zoom=10,
        ),
    )
    return fig


def create_label_output(predictions):
    """Unpack a two-element prediction as (lat, lon) and plot it on a map."""
    lat, lon = predictions
    return create_map_figure(lat, lon)


def predict_and_plot(input_img):
    """Gradio callback: predict coordinates for the image and return the map."""
    predictions = predict(input_img)
    return create_label_output(predictions)


# Gradio UI: image upload, predict button, and the resulting map.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
        btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)

gradio_app.launch()