blumenstiel commited on
Commit
6e9b62d
Β·
1 Parent(s): 680d6f0

Add demo code

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +38 -0
  3. README.md +6 -6
  4. app.py +178 -0
  5. requirements.txt +8 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM ubuntu:22.04
2
+
3
+
4
+ RUN apt-get update && apt-get install --no-install-recommends -y \
5
+ build-essential \
6
+ python3.9 \
7
+ python3-pip \
8
+ git \
9
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
10
+
11
+ WORKDIR /code
12
+
13
+ COPY ./requirements.txt /code/requirements.txt
14
+
15
+ # Set up a new user named "user" with user ID 1000
16
+ RUN useradd -m -u 1000 user
17
+ # Switch to the "user" user
18
+ USER user
19
+ # Set home to the user's home directory
20
+ ENV HOME=/home/user \
21
+ PATH=/home/user/.local/bin:$PATH \
22
+ PYTHONPATH=$HOME/app \
23
+ PYTHONUNBUFFERED=1 \
24
+ GRADIO_ALLOW_FLAGGING=never \
25
+ GRADIO_NUM_PORTS=1 \
26
+ GRADIO_SERVER_NAME=0.0.0.0 \
27
+ GRADIO_THEME=huggingface \
28
+ SYSTEM=spaces
29
+
30
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
31
+
32
+ # Set the working directory to the user's home directory
33
+ WORKDIR $HOME/app
34
+
35
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
36
+ COPY --chown=user . $HOME/app
37
+
38
+ CMD ["python3", "app.py"]
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
- title: Prithvi EO 2.0 Sen1Floods11
3
- emoji: 🏒
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Prithvi EO 2.0 300M TL – Sen1Floods11 flood segmentation
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Prithvi EO 2.0 Sen1Floods11 Demo
3
+ emoji: 🌊
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Prithvi EO 2.0 Sen1Floods11 flood segmentation demo
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import yaml
5
+ import numpy as np
6
+ import gradio as gr
7
+ from pathlib import Path
8
+ from einops import rearrange
9
+ from functools import partial
10
+ from huggingface_hub import hf_hub_download
11
+ from terratorch.cli_tools import LightningInferenceModel
12
+
13
+ # pull files from hub
14
+ token = os.environ.get("HF_TOKEN", None)
15
+ config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
16
+ filename="config.yaml", token=token)
17
+ checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
18
+ filename='Prithvi-EO-V2-300M-TL-Sen1Floods11.pt', token=token)
19
+ model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
20
+ filename='inference.py', token=token)
21
+ os.system(f'cp {model_inference} .')
22
+
23
+ from inference import process_channel_group, _convert_np_uint8, load_example, run_model
24
+
25
+ def extract_rgb_imgs(input_img, pred_img, channels):
26
+ """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
27
+ Args:
28
+ input_img: input torch.Tensor with shape (C, H, W).
29
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
30
+ pred_img: mask torch.Tensor with shape (C, T, H, W).
31
+ channels: list of indices representing RGB channels.
32
+ mean: list of mean values for each band.
33
+ std: list of std values for each band.
34
+ output_dir: directory where to save outputs.
35
+ meta_data: list of dicts with geotiff meta info.
36
+ """
37
+ rgb_orig_list = []
38
+ rgb_mask_list = []
39
+ rgb_pred_list = []
40
+
41
+ for t in range(input_img.shape[1]):
42
+ rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
43
+ new_img=rec_img[:, t, :, :],
44
+ channels=channels,
45
+ mean=mean,
46
+ std=std)
47
+
48
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
49
+
50
+ # extract images
51
+ rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
52
+ rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
53
+ rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
54
+
55
+ # Add white dummy image values for missing timestamps
56
+ dummy = np.ones((20, 20), dtype=np.uint8) * 255
57
+ num_dummies = 4 - len(rgb_orig_list)
58
+ if num_dummies:
59
+ rgb_orig_list.extend([dummy] * num_dummies)
60
+ rgb_mask_list.extend([dummy] * num_dummies)
61
+ rgb_pred_list.extend([dummy] * num_dummies)
62
+
63
+ outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
64
+
65
+ return outputs
66
+
67
+
68
+ def predict_on_images(data_file: str | Path, config_path: str, checkpoint: str):
69
+ try:
70
+ data_file = data_file.name
71
+ print('Path extracted from example')
72
+ except:
73
+ print('Files submitted through UI')
74
+
75
+ # Get parameters --------
76
+ print('This is the printout', data_file)
77
+
78
+ with open(config_path, "r") as f:
79
+ config_dict = yaml.safe_load(f)
80
+
81
+ # Load model ---------------------------------------------------------------------------------
82
+
83
+ lightning_model = LightningInferenceModel.from_config(config_path, checkpoint)
84
+ img_size = 256 # Size of Sen1Floods11
85
+
86
+ # Loading data ---------------------------------------------------------------------------------
87
+
88
+ input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=[data_file])
89
+
90
+ if input_data.shape[1] == 6:
91
+ pass
92
+ elif input_data.shape[1] == 13:
93
+ input_data = input_data[:, [1,2,3,8,11,12], ...]
94
+ else:
95
+ raise Exception(f'Input data has {input_data.shape[1]} channels. Expect either 6 Prithvi channels or 13 S2L1C channels.')
96
+
97
+ if input_data.mean() > 1:
98
+ input_data = input_data / 10000 # Convert to range 0-1
99
+
100
+ # Running model --------------------------------------------------------------------------------
101
+
102
+ lightning_model.model.eval()
103
+
104
+ channels = [config_dict['data']['init_args']['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]] # BGR -> RGB
105
+
106
+ pred = run_model(input_data, temporal_coords, location_coords,
107
+ lightning_model.model, lightning_model.datamodule, img_size)
108
+
109
+ if input_data.mean() < 1:
110
+ input_data = input_data * 10000 # Scale to 0-10000
111
+
112
+ # Extract RGB images for display
113
+ rgb_orig = process_channel_group(
114
+ orig_img=torch.Tensor(input_data[0, :, 0, ...]),
115
+ channels=channels,
116
+ )
117
+ out_rgb_orig = _convert_np_uint8(rgb_orig).transpose(1, 2, 0)
118
+ out_pred_rgb = _convert_np_uint8(pred).repeat(3, axis=0).transpose(1, 2, 0)
119
+
120
+ pred[pred == 0.] = np.nan
121
+ img_pred = rgb_orig * 0.6 + pred * 0.4
122
+ img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
123
+
124
+ out_img_pred = _convert_np_uint8(img_pred).transpose(1, 2, 0)
125
+
126
+ outputs = [out_rgb_orig] + [out_pred_rgb] + [out_img_pred]
127
+
128
+ print("Done!")
129
+
130
+ return outputs
131
+
132
+
133
+ run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)
134
+
135
+ with gr.Blocks() as demo:
136
+ gr.Markdown(value='# Prithvi-EO-2.0 Sen1Floods11 Demo')
137
+ gr.Markdown(value='''
138
+ Prithvi-EO-2.0 is the second generation EO foundation model developed by the IBM and NASA team.
139
+ This demo showcases the fine-tuned Prithvi-EO-2.0-300M-TL model to detect water using Sentinel 2 imagery from on the [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11). More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11).\n
140
+
141
+ The user needs to provide a Sentinel-2 L1C image with either all the 13 bands or the six Prithvi bands (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2). The demo code selects the required bands.
142
+ We recommend submitting images of 500 to ~1000 pixels for faster processing time. Images bigger than 256x256 are processed using a sliding window approach which can lead to artefacts between patches.\n
143
+ Optionally, the location information is extracted from the tif files while the temporal information can be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
144
+ Some example images are provided at the end of this page.
145
+ ''')
146
+ with gr.Row():
147
+ with gr.Column():
148
+ inp_file = gr.File(elem_id='file')
149
+ # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
150
+ btn = gr.Button("Submit")
151
+ with gr.Row():
152
+ gr.Markdown(value='## Input image')
153
+ gr.Markdown(value='## Prediction*')
154
+ gr.Markdown(value='## Overlay')
155
+
156
+ with gr.Row():
157
+ original = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
158
+ predicted = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
159
+ overlay = gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False)
160
+
161
+ gr.Markdown(value='\* White = flood; Black = no flood')
162
+
163
+ btn.click(fn=run_inference,
164
+ inputs=inp_file,
165
+ outputs=[original] + [predicted] + [overlay])
166
+
167
+ with gr.Row():
168
+ gr.Examples(examples=[
169
+ os.path.join(os.path.dirname(__file__), "examples/India_900498_S2Hand.tif"),
170
+ os.path.join(os.path.dirname(__file__), "examples/Spain_7370579_S2Hand.tif"),
171
+ os.path.join(os.path.dirname(__file__), "examples/USA_430764_S2Hand.tif")],
172
+ inputs=inp_file,
173
+ outputs=[original] + [predicted] + [overlay],
174
+ fn=run_inference,
175
+ cache_examples=True
176
+ )
177
+
178
+ demo.launch() # share=True, ssr_mode=False
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ rasterio
5
+ einops
6
+ huggingface_hub
7
+ gradio
8
+ git+https://github.com/IBM/terratorch.git