|
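**Normalization statistics.** The model predicts standardized coordinates, and the snippets below map them back to latitude/longitude. To do that, they expect these values as the Python variables `lat_mean`, `lat_std`, `lon_mean` and `lon_std`:

`lat_mean = 39.951614360789364`

`lat_std = 0.0007384844437841076`

`lon_mean = -75.19140262762761`

`lon_std = 0.0007284591160342192`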
|
|
|
**To load the model:**
|
``` |
|
# Download the complete (pickled) model from the Hugging Face Hub and load it.
from huggingface_hub import hf_hub_download
import torch

repo_id = "thestalkers/ImageToGPSproject_base_resnet18_v2"
filename = "resnet_gps_regressor_complete.pth"

model_path = hf_hub_download(repo_id=repo_id, filename=filename)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint stores the whole nn.Module (not just a state_dict), so it has
# to be unpickled. PyTorch >= 2.6 defaults to weights_only=True, which rejects
# such files, hence weights_only=False. Unpickling can execute arbitrary code:
# only do this for checkpoints from a source you trust.
# NOTE(review): unpickling a full module presumably also requires the model's
# class definition to be importable in this session -- confirm against the
# training code.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine (and vice versa).
model_test = torch.load(model_path, map_location=device, weights_only=False)
model_test.to(device)
model_test.eval()  # Set the model to evaluation mode
|
``` |
|
|
|
**To load a Hugging Face dataset:**
|
``` |
|
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

# Normalization statistics of the training targets. GPSImageDataset receives
# them to standardize the GPS labels, and the inference code uses them to map
# predictions and labels back to latitude/longitude.
lat_mean = 39.951614360789364
lat_std = 0.0007384844437841076
lon_mean = -75.19140262762761
lon_std = 0.0007284591160342192

dataset_test = load_dataset("gydou/released_img", split="train")

# Resize to 224x224, convert to a tensor and normalize with the standard
# ImageNet channel mean/std. This must match the preprocessing used when the
# model was trained.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# NOTE: GPSImageDataset is not defined in this card. It must be defined (or
# imported) before running this cell -- presumably the same Dataset class used
# for training. The inference loop expects it to yield
# (image_tensor, standardized [lat, lon]) pairs.
test_dataset = GPSImageDataset(
    hf_dataset=dataset_test,
    transform=inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std,
)

# shuffle=False keeps predictions in the same order as the dataset rows.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
``` |
|
|
|
**To run inference and compute error metrics:**
|
``` |
|
import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Put the model on the same device as the input batches and in evaluation mode.
# Otherwise a CPU-resident model fed CUDA tensors raises a device-mismatch error.
model_test.to(device)
model_test.eval()

# Denormalization constants, built once rather than on every batch. float64 is
# used because float32 can only resolve ~4e-6 to 8e-6 degrees at magnitudes of
# ~40 and ~75 degrees, which would add rounding noise to the metrics.
coord_std = torch.tensor([lat_std, lon_std], dtype=torch.float64)
coord_mean = torch.tensor([lat_mean, lon_mean], dtype=torch.float64)

# Initialize lists to store predictions and actual values
all_preds = []
all_actuals = []

with torch.no_grad():
    for images, gps_coords in test_dataloader:
        outputs = model_test(images.to(device))

        # Denormalize predictions and ground truth back to latitude/longitude.
        # The labels are only needed on the CPU, so they never go to `device`.
        preds = outputs.cpu().double() * coord_std + coord_mean
        actuals = gps_coords.cpu().double() * coord_std + coord_mean

        all_preds.append(preds)
        all_actuals.append(actuals)

if not all_preds:
    raise ValueError("test_dataloader yielded no batches; nothing to evaluate")

# Concatenate all batches
all_preds = torch.cat(all_preds).numpy()
all_actuals = torch.cat(all_actuals).numpy()

# Error metrics are in the units of lat_mean/lon_mean (degrees), averaged over
# the latitude and longitude columns.
mae = mean_absolute_error(all_actuals, all_preds)
# `mean_squared_error(..., squared=False)` was deprecated in scikit-learn 1.4
# and removed in 1.6. Taking the square root of each column's MSE and then
# averaging is exactly what `squared=False` computed, and works on every version.
rmse = np.sqrt(
    mean_squared_error(all_actuals, all_preds, multioutput="raw_values")
).mean()

print(f'Mean Absolute Error: {mae}')
print(f'Root Mean Squared Error: {rmse}')
|
``` |