Spaces:
Running
Running
File size: 3,973 Bytes
9b889da afe2deb 8a4658f afe2deb 4b52e6f afe2deb bde2d24 f07227d 4b52e6f dbf177b 9b889da dbf177b f07227d db673be dbf177b db673be 9b889da 955fc23 2a50088 f07227d 4184b6d 955fc23 be60ccb 955fc23 be60ccb dccd8f9 955fc23 f07227d 955fc23 f07227d 955fc23 f07227d 955fc23 f07227d 955fc23 6d49cf1 f07227d 6d49cf1 955fc23 f07227d 6d49cf1 24d6812 4b52e6f f07227d afe2deb 4b52e6f 24d6812 8a4658f f07227d 8a4658f afe2deb f07227d 24d6812 6d49cf1 9678900 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
from io import BytesIO

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
# NOTE: this import rebinds `Image` (previously PIL's) to gradio's component;
# it is deliberately kept after the PIL import so name resolution is unchanged.
from gradio import Interface, Image, Label, HTML
from huggingface_hub import snapshot_download
# Hugging Face access token, read from the environment (None when unset).
token = os.environ.get("token")

# Download the model repository snapshot into the local "SVD" folder.
# The returned folder path is reused below instead of repeating the literal.
local_dir = snapshot_download(
    repo_id="robocan/GeoG_coordinate",
    repo_type="model",
    local_dir="SVD",
    token=token,
)

# All inference in this script runs on the CPU.
device = 'cpu'

# NOTE(security): joblib.load unpickles arbitrary Python objects; these files
# must only ever come from a trusted repository.
# Presumably a fitted LabelEncoder and a fitted MinMaxScaler (judging by the
# imports and the attributes used later) -- confirm against the training code.
le = joblib.load(os.path.join(local_dir, "le.gz"))
MMS = joblib.load(os.path.join(local_dir, "MMS.gz"))

# Width of the classifier head in ModelPre: one more than the number of
# encoded classes.
# NOTE(review): the reason for the extra "+ 1" output is not visible here.
len_classes = len(le.classes_) + 1
class ModelPre(torch.nn.Module):
    """Classification network: ConvNeXt-Small trunk plus a two-layer head.

    The trunk is every child module of torchvision's ``convnext_small``
    except its last one, initialised from the IMAGENET1K_V1 weights.  The
    head flattens the trunk output and maps 768 -> 512 -> ``len_classes``.
    All layers live in a single ``self.embedding`` Sequential, which
    ``GeoGcord`` later slices to reuse everything but the final Linear.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_small(
            weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
        )
        # Drop the backbone's last child (presumably its own classifier).
        trunk_layers = list(backbone.children())[:-1]
        head_layers = [
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=768, out_features=512),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=512, out_features=len_classes),
        ]
        # Note: no parameters are frozen here; every layer stays trainable.
        self.embedding = torch.nn.Sequential(*trunk_layers, *head_layers)

    def forward(self, data):
        """Run ``data`` through the full stack and return the head output."""
        return self.embedding(data)
# Instantiate the classification network. GeoGcord (defined below) reads this
# module-level `model` and reuses all of its layers except the final Linear.
model = ModelPre()
# Parameter freezing is currently disabled (kept here for reference):
# for param in model.parameters():
#     param.requires_grad = False
class GeoGcord(torch.nn.Module):
    """Regression network producing two outputs per image.

    It reuses every layer of a base network's first child module except that
    child's last layer, then appends a 512 -> 256 -> 128 -> 2 MLP.  The reused
    layers are shared with the base network, not copied.

    Args:
        base: Network whose first child is an indexable ``Sequential`` ending
            in a layer with 512 input features.  Defaults to the module-level
            ``model``.  Pass it explicitly when constructing more than one
            instance, because this script later rebinds ``model`` to a
            ``GeoGcord`` instance, which is no longer a valid base.
    """

    def __init__(self, base=None):
        super().__init__()
        if base is None:
            # Backward-compatible default: the module-level network.
            base = model
        # First child of the base is its `embedding` Sequential; slicing off
        # the last element removes the base's final output layer.
        shared_layers = list(base.children())[0][:-1]
        # Note: no parameters are frozen here; every layer stays trainable.
        self.embedding = torch.nn.Sequential(
            *shared_layers,
            torch.nn.Linear(in_features=512, out_features=256),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=256, out_features=128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, data):
        """Return the two-value regression output for ``data``."""
        return self.embedding(data)
# Build the regression network and restore its trained weights.
model = GeoGcord()
# NOTE(security): torch.load unpickles the checkpoint; only load files from a
# trusted source.
model_w = torch.load("SVD/GeoG.pth", map_location=torch.device(device))
model.load_state_dict(model_w['model'])
# Switch to evaluation mode. torch.inference_mode() (used in predict) only
# disables autograd; train-time behaviour such as dropout or the
# stochastic-depth layers in torchvision's ConvNeXt stays active until eval()
# is called, which would make predictions non-deterministic.
model.eval()

# Preprocessing applied to each input image: convert to a tensor, resize to
# 224x224 and normalise each of the three channels with the mean/std below.
cmp = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(input_img):
    """Run the regression model on one image and return the rescaled output.

    Args:
        input_img: The image to analyse (the Gradio input is configured to
            deliver a PIL image).  Images whose ``mode`` is not ``"RGB"``
            (e.g. RGBA or greyscale) are converted to RGB first, because the
            preprocessing pipeline normalises exactly three channels.

    Returns:
        A flat NumPy array: the model output passed through
        ``MMS.inverse_transform``.  Callers unpack it as ``lat, lon``.

    Raises:
        TypeError: If ``input_img`` is None (no image was supplied).
    """
    if input_img is None:
        # Fail early with a clear message instead of deep inside torchvision.
        raise TypeError("No image provided; upload an image before predicting.")

    # Duck-typed check: the name `Image` in this module refers to gradio's
    # component, so PIL's class is not available for isinstance().
    mode = getattr(input_img, "mode", None)
    if mode is not None and mode != "RGB":
        input_img = input_img.convert("RGB")

    with torch.inference_mode():
        batch = cmp(input_img).unsqueeze(0)  # add the batch dimension
        res = model(batch.to(device))
        # Undo the scaling applied to the regression targets.
        prediction = MMS.inverse_transform(res.cpu().numpy()).flatten()
    return prediction
def create_map_figure(lat, lon):
    """Build a Plotly map figure with a single marker at the given position.

    Args:
        lat: Latitude of the marker; also used as the map centre.
        lon: Longitude of the marker; also used as the map centre.

    Returns:
        A ``go.Figure`` using the "open-street-map" style, centred on the
        marker at zoom level 10, whose hover text shows the coordinates.
    """
    hover_label = f'Lat: {lat}, Lon: {lon}'
    point_marker = go.scattermapbox.Marker(size=14)
    point_trace = go.Scattermapbox(
        lat=[lat],
        lon=[lon],
        mode='markers',
        marker=point_marker,
        text=[hover_label],
        hoverinfo='text',
    )
    fig = go.Figure(point_trace)

    # Centre the view on the predicted point, north-up and untilted.
    map_center = go.layout.mapbox.Center(lat=lat, lon=lon)
    map_view = dict(bearing=0, center=map_center, pitch=0, zoom=10)
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=map_view,
    )
    return fig
def create_label_output(predictions):
    """Turn a two-element prediction into a map figure.

    Despite its name, this returns the figure built by ``create_map_figure``,
    not a label.

    Args:
        predictions: Iterable of exactly two values, unpacked as
            ``(lat, lon)``.

    Returns:
        The figure produced by ``create_map_figure(lat, lon)``.
    """
    latitude, longitude = predictions
    return create_map_figure(latitude, longitude)
def predict_and_plot(input_img):
    """Predict coordinates for ``input_img`` and return them as a map figure.

    This is the callback wired to the "Predict" button: it chains
    ``predict`` and ``create_label_output``.
    """
    return create_label_output(predict(input_img))
# Gradio app definition: one column holding the image input, the map output
# and the button that triggers a prediction.
with gr.Blocks() as gradio_app:
    with gr.Column():
        input_image = gr.Image(label="Upload an Image", type="pil")
        output_map = gr.Plot(label="Predicted Location on Map")
        btn_predict = gr.Button("Predict")
    # Event listeners must be registered inside the Blocks context.
    btn_predict.click(predict_and_plot, inputs=input_image, outputs=output_map)

# Start the web server (runs at import time, as in the original script).
gradio_app.launch()