Spaces:
Sleeping
Sleeping
debugging
Browse files
app.py
CHANGED
@@ -9,39 +9,39 @@ import matplotlib.pyplot as plt
|
|
9 |
import os
|
10 |
|
11 |
# Load the model
|
12 |
-
checkpoints_dir = Path("saved_checkpoints")
|
13 |
-
if not checkpoints_dir.exists():
|
14 |
-
os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
|
15 |
-
checkpoint = "SSLhuge_satellite.pth"
|
16 |
-
device = "cpu"
|
17 |
-
ckpt_path = checkpoints_dir / checkpoint
|
18 |
-
model = SSLModule(ssl_path=str(ckpt_path))
|
19 |
-
model.to(device)
|
20 |
-
model = model.eval()
|
21 |
|
22 |
# Define the normalization transform
|
23 |
-
norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
|
24 |
-
norm = norm.to(device)
|
25 |
|
26 |
|
27 |
# Define a function to make predictions
|
28 |
def predict(image):
|
29 |
# Convert PIL Image to tensor
|
30 |
-
image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
|
31 |
-
# Normalize the image
|
32 |
-
with torch.no_grad():
|
33 |
-
pred = model(norm(image_t.unsqueeze(0)))
|
34 |
-
pred = pred.cpu().detach().relu()
|
35 |
-
# Convert tensor to numpy array
|
36 |
-
pred_np = pred[0, 0].numpy()
|
37 |
-
# Save the image to an in-memory buffer
|
38 |
-
buffer = BytesIO()
|
39 |
-
plt.imsave(buffer, pred_np, cmap="Greens")
|
40 |
-
buffer.seek(0) # Rewind the buffer to the beginning
|
41 |
-
# Read the image back from the buffer
|
42 |
-
image_from_buffer = Image.open(buffer)
|
43 |
-
return image_from_buffer
|
44 |
-
|
45 |
|
46 |
# create a Gradio interface
|
47 |
demo = gr.Interface(
|
|
|
9 |
import os
|
10 |
|
11 |
# Load the model
|
12 |
+
#checkpoints_dir = Path("saved_checkpoints")
|
13 |
+
#if not checkpoints_dir.exists():
|
14 |
+
# os.system("aws s3 --no-sign-request cp --recursive s3://dataforgood-fb-data/forests/v1/models/ .")
|
15 |
+
#checkpoint = "SSLhuge_satellite.pth"
|
16 |
+
#device = "cpu"
|
17 |
+
#ckpt_path = checkpoints_dir / checkpoint
|
18 |
+
#model = SSLModule(ssl_path=str(ckpt_path))
|
19 |
+
#model.to(device)
|
20 |
+
#model = model.eval()
|
21 |
|
22 |
# Define the normalization transform
|
23 |
+
#norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
|
24 |
+
#norm = norm.to(device)
|
25 |
|
26 |
|
27 |
# Define a function to make predictions
|
28 |
def predict(image):
|
29 |
# Convert PIL Image to tensor
|
30 |
+
# image_t = torch.tensor(image).permute(2, 0, 1)[:3].float().to(device) / 255
|
31 |
+
# # Normalize the image
|
32 |
+
# with torch.no_grad():
|
33 |
+
# pred = model(norm(image_t.unsqueeze(0)))
|
34 |
+
# pred = pred.cpu().detach().relu()
|
35 |
+
# # Convert tensor to numpy array
|
36 |
+
# pred_np = pred[0, 0].numpy()
|
37 |
+
# # Save the image to an in-memory buffer
|
38 |
+
# buffer = BytesIO()
|
39 |
+
# plt.imsave(buffer, pred_np, cmap="Greens")
|
40 |
+
# buffer.seek(0) # Rewind the buffer to the beginning
|
41 |
+
# # Read the image back from the buffer
|
42 |
+
# image_from_buffer = Image.open(buffer)
|
43 |
+
# return image_from_buffer
|
44 |
+
return image
|
45 |
|
46 |
# create a Gradio interface
|
47 |
demo = gr.Interface(
|