biomass / app.py
juansensio's picture
debugging
886f76c verified
raw
history blame
2.42 kB
import gradio as gr
from pathlib import Path
import torch
import torchvision.transforms as T
from PIL import Image
from utils import SSLModule
from io import BytesIO
import matplotlib.pyplot as plt
import os
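# Load the model (currently disabled while debugging: all loading code below is commented out, so no model is loaded)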
#checkpoints_dir = Path("saved_checkpoints")
#if not checkpoints_dir.exists():
# os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
#checkpoint = "SSLhuge_satellite.pth"
#device = "cpu"
#ckpt_path = checkpoints_dir / checkpoint
#model = SSLModule(ssl_path=str(ckpt_path))
#model.to(device)
#model = model.eval()
# Define the normalization transform
#norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
#norm = norm.to(device)
# Define a function to make predictions
def predict(image):
    """Return the uploaded image unchanged.

    Debugging stub: the canopy-height inference pipeline (tensor conversion,
    normalization, model forward pass, ReLU, and rendering through
    matplotlib's "Greens" colormap) is disabled, so the interface currently
    echoes its input straight back.

    Args:
        image: The value Gradio hands over from the ``gr.Image`` input
            component.

    Returns:
        The very same object that was passed in.
    """
    # Disabled inference pipeline, kept for reference until the module-level
    # model/normalization setup is re-enabled:
    #   image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
    #   with torch.no_grad():
    #       pred = model(norm(image_t.unsqueeze(0)))
    #   pred = pred.cpu().detach().relu()
    #   pred_np = pred[0, 0].numpy()
    #   buffer = BytesIO()
    #   plt.imsave(buffer, pred_np, cmap="Greens")
    #   buffer.seek(0)  # rewind before reading the PNG back
    #   return Image.open(buffer)
    passthrough = image
    return passthrough
# Create the Gradio interface that wires `predict` to an image-in/image-out UI.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(label="Upload a Satellite Image"),
    outputs=gr.Image(label="Estimated Canopy Height"),
    title="Estimate 🌳 Canopy Height from Satellite Image 🛰️",
    description="""
<div style='display: flex; justify-content: center; align-items: center;'>
<img src='https://sustainability.fb.com/wp-content/uploads/2024/04/worldmap-2500.jpg?w=1536' style='max-width: 500px'/>
</div>
<p>This application uses a pre-trained model to estimate canopy height from satellite images. Upload an image and see the result!</p>
""",
    # Example images are resolved relative to the working directory.
    examples=[
        ["examples/image.png"],
        ["examples/image2.png"],
        ["examples/image3.png"],
    ],
    article="<p style='text-align: center'>Find more information <a href='https://sustainability.fb.com/blog/2024/04/22/using-artificial-intelligence-to-map-the-earths-forests/'>here</a>.</p>",
    # Gradio expects a string here ("never" / "auto" / "manual"), not a bool;
    # "never" disables the Flag button, which is what `False` was meant to do.
    allow_flagging="never",
)

# Launch the server only when run as a script (e.g. `python app.py`), so the
# module can be imported without side effects.
if __name__ == "__main__":
    demo.launch()