juansensio commited on
Commit
13d5a13
·
verified ·
1 Parent(s): b37a44d

create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ import torch
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ from utils import SSLModule
7
+ from io import BytesIO
8
+ import matplotlib.pyplot as plt
9
+ import os
10
+
11
+ # Load the model
12
+ checkpoints_dir = Path("saved_checkpoints")
13
+ if not checkpoints_dir.exists():
14
+ os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
15
+ checkpoint = "SSLhuge_satellite.pth"
16
+ device = "cpu"
17
+ ckpt_path = checkpoints_dir / checkpoint
18
+ model = SSLModule(ssl_path=str(ckpt_path))
19
+ model.to(device)
20
+ model = model.eval()
21
+
22
+ # Define the normalization transform
23
+ norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
24
+ norm = norm.to(device)
25
+
26
+
27
+ # Define a function to make predictions
28
+ def predict(image):
29
+ # Convert PIL Image to tensor
30
+ image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
31
+ # Normalize the image
32
+ with torch.no_grad():
33
+ pred = model(norm(image_t.unsqueeze(0)))
34
+ pred = pred.cpu().detach().relu()
35
+ # Convert tensor to numpy array
36
+ pred_np = pred[0, 0].numpy()
37
+ # Save the image to an in-memory buffer
38
+ buffer = BytesIO()
39
+ plt.imsave(buffer, pred_np, cmap="Greens")
40
+ buffer.seek(0) # Rewind the buffer to the beginning
41
+ # Read the image back from the buffer
42
+ image_from_buffer = Image.open(buffer)
43
+ return image_from_buffer
44
+
45
+
46
+ # create a Gradio interface
47
+ demo = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(label="Upload a Satellite Image"),
50
+ outputs=gr.Image(label="Estimated Canopy Height"),
51
+ title="Estimate 🌳 Canopy Height from Satellite Image 🛰️",
52
+ description="""
53
+ <div style='display: flex; justify-content: center; align-items: center;'>
54
+ <img src='https://sustainability.fb.com/wp-content/uploads/2024/04/worldmap-2500.jpg?w=1536' style='max-width: 500px'/>
55
+ </div>
56
+ <p>This application uses a pre-trained model to estimate canopy height from satellite images. Upload an image and see the result!</p>
57
+ """,
58
+ examples=[
59
+ ["examples/image.png"],
60
+ ["examples/image2.png"],
61
+ ["examples/image3.png"],
62
+ ],
63
+ article="<p style='text-align: center'>Find more information <a href='https://sustainability.fb.com/blog/2024/04/22/using-artificial-intelligence-to-map-the-earths-forests/'>here</a>.</p>",
64
+ allow_flagging=False,
65
+ )
66
+
67
+ # launch the interface
68
+ demo.launch()