Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from pathlib import Path | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from utils import SSLModule | |
| from io import BytesIO | |
| import matplotlib.pyplot as plt | |
| import os | |
# Load the pretrained model once at import time; it is shared by every request.
device = "cpu"

# Checkpoint location. This is a relative path, so it is resolved against the
# working directory the app is started from.
checkpoints_dir = Path("saved_checkpoints")
checkpoint = "SSLhuge_satellite.pth"
ckpt_path = checkpoints_dir.joinpath(checkpoint)

# SSLModule (project-local) presumably loads its weights from `ssl_path` -- confirm in utils.
model = SSLModule(ssl_path=str(ckpt_path))
model.to(device)
# Switch to inference mode (e.g. dropout / norm layers behave deterministically).
model = model.eval()
# Per-channel normalization applied to the [0, 1] input tensor before the model.
# The mean/std values are hard-coded; presumably they are the statistics the
# model was trained with -- TODO confirm against the training config.
norm = T.Normalize(mean=(0.420, 0.411, 0.296), std=(0.213, 0.156, 0.143)).to(device)
# Define a function to make predictions
def predict(image):
    """Estimate canopy height for one uploaded image and return a colorized map.

    Args:
        image: The value Gradio passes for the ``gr.Image`` input. With the
            default ``type="numpy"`` this is an array laid out as
            (height, width, channels). Only the first three channels are used,
            so an alpha channel is dropped. A 2-D (grayscale) array is
            replicated to three channels. Pixel values are assumed to be uint8
            (0-255) -- TODO confirm if other input types are ever configured.

    Returns:
        A PIL image showing the first output channel of ``model`` rendered
        with the "Greens" colormap. Because no vmin/vmax is given,
        ``plt.imsave`` scales colors to this prediction's own min/max, so
        shades are not comparable across different images.

    Raises:
        gr.Error: If no image was provided (Gradio passes ``None``).
    """
    if image is None:
        # Gradio passes None when the user submits without uploading; show a
        # readable message instead of an opaque torch TypeError.
        raise gr.Error("Please upload a satellite image.")

    # (H, W, C) array -> (C, H, W) float tensor scaled to [0, 1].
    image_t = torch.tensor(image)
    if image_t.ndim == 2:
        # Grayscale: replicate to 3 channels so permute and normalization work.
        image_t = image_t.unsqueeze(-1).expand(-1, -1, 3)
    image_t = image_t.permute(2, 0, 1)[:3].float().to(device) / 255

    with torch.no_grad():
        # Normalize, add a batch dimension, and run the model.
        pred = model(norm(image_t.unsqueeze(0)))

    # ReLU clamps negative predictions to 0. Indexing [0, 0] takes batch item 0,
    # channel 0 -- assumes the model returns (batch, channel, H, W); confirm.
    pred_np = pred.cpu().relu()[0, 0].numpy()

    buffer = BytesIO()
    # Pass the format explicitly: for a file-like object without a name,
    # imsave otherwise falls back to rcParams["savefig.format"], which is not
    # guaranteed to be PNG.
    plt.imsave(buffer, pred_np, cmap="Greens", format="png")
    buffer.seek(0)  # Rewind so PIL reads from the start of the buffer.
    return Image.open(buffer)
# Create the Gradio interface: one uploaded image in, one colorized height map out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload a Satellite Image"),
    outputs=gr.Image(label="Estimated Canopy Height"),
    title="Estimate 🌳 Canopy Height from Satellite Images 🛰️",
    description="""
<div style='display: flex; justify-content: center; align-items: center;'>
<img src='https://sustainability.fb.com/wp-content/uploads/2024/04/worldmap-2500.jpg?w=1536' style='max-width: 500px'/>
</div>
<p>This application uses a pre-trained model to estimate canopy height from satellite images. Upload an image and see the result! (You can upload a screenshot from Google Maps, for example).</p>
""",
    # Example inputs; paths are relative to the working directory.
    examples=[
        ["examples/image.png"],
        ["examples/image2.png"],
        ["examples/image3.png"],
    ],
    article="<p style='text-align: center'>Find more information <a href='https://sustainability.fb.com/blog/2024/04/22/using-artificial-intelligence-to-map-the-earths-forests/'>here</a>.</p>",
    # Gradio expects a string ("never" / "auto" / "manual"). A boolean False is
    # only accepted with a deprecation warning and is mapped to "never".
    allow_flagging="never",
)
# Launch the interface (kept at module level so running the file starts the app).
demo.launch()