**Normalization statistics**

The GPS targets are normalized as `(value - mean) / std`. These constants are needed to build the dataset and to denormalize the model's predictions:

```python
lat_mean = 39.951614360789364
lat_std = 0.0007384844437841076
lon_mean = -75.19140262762761
lon_std = 0.0007284591160342192
```

**To load the model:**

```python
import torch
from huggingface_hub import hf_hub_download

repo_id = "thestalkers/ImageToGPSproject_base_resnet18_v2"
filename = "resnet_gps_regressor_complete.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=filename)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# The checkpoint is a complete pickled model (not a state_dict), so it must be
# loaded with weights_only=False; PyTorch >= 2.6 defaults to weights_only=True
# and would refuse it. Unpickling can execute arbitrary code, so only load
# checkpoints from sources you trust.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model_test = torch.load(model_path, map_location=device, weights_only=False)
model_test = model_test.to(device)  # inference feeds it inputs on `device`
model_test.eval()  # Set the model to evaluation mode
```

**Load a Hugging Face dataset:**

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

dataset_test = load_dataset("gydou/released_img", split="train")

inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# NOTE: GPSImageDataset is not defined in this card; you must provide it
# (e.g. the dataset class used for training). The inference loop below
# expects each item to be an (image_tensor, gps_tensor) pair. gps_tensor is
# [lat, lon], normalized with the statistics above.
test_dataset = GPSImageDataset(
    hf_dataset=dataset_test,
    transform=inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std,
)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```

**Perform inference:**

```python
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Uses `model_test` and `device` from the loading step.

# Denormalization constants, built once and in float64. At these magnitudes
# float32 can only represent coordinates in steps of ~4e-6 deg (lat) and
# ~8e-6 deg (lon), which would quantize the reported errors.
scale = torch.tensor([lat_std, lon_std], dtype=torch.float64)
offset = torch.tensor([lat_mean, lon_mean], dtype=torch.float64)

# Initialize lists to store predictions and actual values
all_preds = []
all_actuals = []

with torch.no_grad():
    for images, gps_coords in test_dataloader:
        # Only the images need to be on the model's device; the targets are
        # used on the CPU below.
        outputs = model_test(images.to(device))

        # Denormalize predictions and actual values
        preds = outputs.cpu().double() * scale + offset
        actuals = gps_coords.cpu().double() * scale + offset

        all_preds.append(preds)
        all_actuals.append(actuals)

# Concatenate all batches
all_preds = torch.cat(all_preds).numpy()
all_actuals = torch.cat(all_actuals).numpy()

# Compute error metrics, in degrees, averaged over the lat and lon columns.
mae = mean_absolute_error(all_actuals, all_preds)
# `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6. The
# old call returned the mean of the per-column RMSEs. The expression below
# computes that value explicitly and works on any scikit-learn version.
rmse = (mean_squared_error(all_actuals, all_preds, multioutput="raw_values") ** 0.5).mean()

print(f'Mean Absolute Error: {mae}')
print(f'Root Mean Squared Error: {rmse}')
```