import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import joblib
from io import BytesIO
from PIL import Image
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE: this rebinds ``Image`` to gradio's component, shadowing PIL's Image.
from gradio import Interface, Image, Label, HTML
# The UI below uses the ``gr.`` and ``go.`` namespaces; without these two
# imports the script fails with NameError before the app can start.
import gradio as gr
import plotly.graph_objects as go
from huggingface_hub import snapshot_download

# Hugging Face access token, read from the "token" environment variable.
token = os.environ.get("token")

# Download the model repository snapshot into ./SVD; the returned path is used
# for every artifact below instead of repeating the hard-coded directory.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_coordinate",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

device = "cpu"

# NOTE(security): joblib.load unpickles arbitrary Python objects; only load
# these files from a repository you control/trust.
le = joblib.load(os.path.join(local_dir, "le.gz"))
MMS = joblib.load(os.path.join(local_dir, "MMS.gz"))
# Output width of ModelPre's classification head. GeoGcord drops that layer,
# so here it only needs to produce a valid layer shape.
len_classes = len(le.classes_) + 1


class ModelPre(torch.nn.Module):
    """ConvNeXt-Small backbone followed by a two-layer classification head.

    Args:
        pretrained: If True (default, the original behaviour), initialise the
            backbone from torchvision's ImageNet-1K weights, which may be
            downloaded on first use. Pass False when every weight will be
            overwritten from a checkpoint anyway.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ConvNeXt_Small_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.convnext_small(weights=weights)
        self.embedding = torch.nn.Sequential(
            # Everything except torchvision ConvNeXt's last child (its
            # classifier head); the pooled 768-d features feed our own head.
            *list(backbone.children())[:-1],
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        )

    def forward(self, data):
        """Run the backbone and head on a batch of images.

        ``data`` is presumably a normalised (N, 3, H, W) float batch as
        produced by ``cmp`` below — TODO confirm for other callers.
        """
        return self.embedding(data)


class GeoGcord(torch.nn.Module):
    """Coordinate regressor: a ModelPre without its final layer plus an MLP.

    The MLP maps the 512-d features to 2 outputs (scaled coordinates that
    ``MMS.inverse_transform`` maps back to the original scale).

    Args:
        base: ModelPre whose layers (all but its final Linear) are reused.
            When omitted, a fresh pretrained ModelPre is built. Passing it
            explicitly avoids depending on a module-level global that may
            later be rebound.
    """

    def __init__(self, base=None):
        super().__init__()
        if base is None:
            base = ModelPre()
        self.embedding = torch.nn.Sequential(
            # Slicing keeps the layer order/indices, so state_dict keys match
            # the saved checkpoint.
            *base.embedding[:-1],
            torch.nn.Linear(in_features=512, out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, data):
        """Return the two scaled regression outputs for each image in ``data``."""
        return self.embedding(data)


# load_state_dict is strict by default, so the checkpoint restores every
# parameter and buffer; downloading ImageNet weights first would be wasted.
model = GeoGcord(ModelPre(pretrained=False))
model_w = torch.load(
    os.path.join(local_dir, "GeoG.pth"), map_location=torch.device(device)
)
model.load_state_dict(model_w["model"])
model.to(device)
# Inference mode for the layers: torchvision's ConvNeXt blocks use stochastic
# depth, which randomly drops residual branches while in training mode and
# would make every prediction nondeterministic.
model.eval()

# Preprocessing: PIL image -> tensor, resize to 224x224, ImageNet mean/std.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def predict(input_img):
    """Predict the coordinate pair for a single image.

    Args:
        input_img: Image from the Gradio input component (``type="pil"``).

    Returns:
        1-D numpy array of the two model outputs mapped back to the original
        scale with ``MMS.inverse_transform``. Assumed to be (latitude,
        longitude), which is how ``create_label_output`` unpacks it.

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please upload an image before predicting.")
    # Normalize expects exactly 3 channels; RGBA, greyscale or palette
    # uploads would otherwise fail inside the transform.
    if hasattr(input_img, "convert"):
        input_img = input_img.convert("RGB")
    with torch.inference_mode():
        img = cmp(input_img).unsqueeze(0).to(device)
        res = model(img)
    return MMS.inverse_transform(res.cpu().numpy()).flatten()


def create_map_figure(lat, lon):
    """Build a Plotly OpenStreetMap figure with one marker at (lat, lon).

    The map is centred on the marker at zoom level 10.
    """
    fig = go.Figure(go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        mode='markers',
        marker=go.scattermapbox.Marker(size=14),
        text=[f'Lat: {lat}, Lon: {lon}'],
        hoverinfo='text',
    ))

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            center=go.layout.mapbox.Center(lat=lat, lon=lon),
            pitch=0,
            zoom=10,
        ),
    )
    return fig


def create_label_output(predictions):
    """Turn a two-element (lat, lon) prediction into a map figure."""
    lat, lon = predictions
    return create_map_figure(lat, lon)


def predict_and_plot(input_img):
    """Gradio callback: predict coordinates for the image and plot them."""
    predictions = predict(input_img)
    return create_label_output(predictions)


# Gradio app: image upload, a map plot for the result, and a Predict button.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")

        btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)

gradio_app.launch()