juansensio commited on
Commit
886f76c
·
verified ·
1 Parent(s): 2b79972
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -9,39 +9,39 @@ import matplotlib.pyplot as plt
9
  import os
10
 
11
  # Load the model
12
- checkpoints_dir = Path("saved_checkpoints")
13
- if not checkpoints_dir.exists():
14
- os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
15
- checkpoint = "SSLhuge_satellite.pth"
16
- device = "cpu"
17
- ckpt_path = checkpoints_dir / checkpoint
18
- model = SSLModule(ssl_path=str(ckpt_path))
19
- model.to(device)
20
- model = model.eval()
21
 
22
  # Define the normalization transform
23
- norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
24
- norm = norm.to(device)
25
 
26
 
27
  # Define a function to make predictions
28
  def predict(image):
29
  # Convert PIL Image to tensor
30
- image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
31
- # Normalize the image
32
- with torch.no_grad():
33
- pred = model(norm(image_t.unsqueeze(0)))
34
- pred = pred.cpu().detach().relu()
35
- # Convert tensor to numpy array
36
- pred_np = pred[0, 0].numpy()
37
- # Save the image to an in-memory buffer
38
- buffer = BytesIO()
39
- plt.imsave(buffer, pred_np, cmap="Greens")
40
- buffer.seek(0) # Rewind the buffer to the beginning
41
- # Read the image back from the buffer
42
- image_from_buffer = Image.open(buffer)
43
- return image_from_buffer
44
-
45
 
46
  # create a Gradio interface
47
  demo = gr.Interface(
 
9
  import os
10
 
11
  # Load the model
12
+ #checkpoints_dir = Path("saved_checkpoints")
13
+ #if not checkpoints_dir.exists():
14
+ # os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
15
+ #checkpoint = "SSLhuge_satellite.pth"
16
+ #device = "cpu"
17
+ #ckpt_path = checkpoints_dir / checkpoint
18
+ #model = SSLModule(ssl_path=str(ckpt_path))
19
+ #model.to(device)
20
+ #model = model.eval()
21
 
22
  # Define the normalization transform
23
+ #norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
24
+ #norm = norm.to(device)
25
 
26
 
27
  # Define a function to make predictions
28
  def predict(image):
29
  # Convert PIL Image to tensor
30
+ # image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
31
+ # # Normalize the image
32
+ # with torch.no_grad():
33
+ # pred = model(norm(image_t.unsqueeze(0)))
34
+ # pred = pred.cpu().detach().relu()
35
+ # # Convert tensor to numpy array
36
+ # pred_np = pred[0, 0].numpy()
37
+ # # Save the image to an in-memory buffer
38
+ # buffer = BytesIO()
39
+ # plt.imsave(buffer, pred_np, cmap="Greens")
40
+ # buffer.seek(0) # Rewind the buffer to the beginning
41
+ # # Read the image back from the buffer
42
+ # image_from_buffer = Image.open(buffer)
43
+ # return image_from_buffer
44
+ return image
45
 
46
  # create a Gradio interface
47
  demo = gr.Interface(