import os
import torch
import yaml
import numpy as np
import gradio as gr
from einops import rearrange
from functools import partial
from huggingface_hub import hf_hub_download
# pull files from hub
token = os.environ.get("HF_TOKEN", None)
config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                              filename="config.json", token=token)
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                             filename='Prithvi_EO_V2_300M_TL.pt', token=token)
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                            filename='prithvi_mae.py', token=token)
model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
                                  filename='inference.py', token=token)
os.system(f'cp {model_def} .')
os.system(f'cp {model_inference} .')
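# The downloaded model definition and inference helpers now live in the working
# directory, so they can be imported as regular modules.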
from prithvi_mae import PrithviMAE
from inference import process_channel_group, _convert_np_uint8, load_example, run_model


def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
    """ Wrapper function to extract RGB images (original, masked, reconstructed) per timestamp.

    Args:
        input_img: input torch.Tensor with shape (C, T, H, W).
        rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
        mask_img: mask torch.Tensor with shape (C, T, H, W).
        channels: list of indices representing RGB channels.
        mean: list of mean values for each band.
        std: list of std values for each band.

    Returns:
        list of original, masked, and reconstructed RGB images (uint8 arrays),
        padded with white dummy images to four timestamps each.
    """
    rgb_orig_list = []
    rgb_mask_list = []
    rgb_pred_list = []

    for t in range(input_img.shape[1]):
        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
                                                   new_img=rec_img[:, t, :, :],
                                                   channels=channels,
                                                   mean=mean,
                                                   std=std)
        rgb_mask = mask_img[channels, t, :, :] * rgb_orig

        # extract images
        rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
        rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
        rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))

    # Add white dummy image values for missing timestamps
    dummy = np.ones((20, 20), dtype=np.uint8) * 255
    num_dummies = 4 - len(rgb_orig_list)
    if num_dummies:
        rgb_orig_list.extend([dummy] * num_dummies)
        rgb_mask_list.extend([dummy] * num_dummies)
        rgb_pred_list.extend([dummy] * num_dummies)
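    # The output order must match the Gradio output components: all originals first,
    # then all masked images, then all reconstructions (four of each).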
    outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list

    return outputs


def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
    try:
        data_files = [x.name for x in data_files]
        print('Path extracted from example')
    except AttributeError:
        print('Files submitted through UI')

    # Get parameters --------
    print('This is the printout', data_files)

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)['pretrained_cfg']
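    # Note: config.json is plain JSON, which yaml.safe_load parses because JSON is a subset of YAML.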
    batch_size = 8
    bands = config['bands']
    num_frames = len(data_files)
    mean = config['mean']
    std = config['std']
    coords_encoding = config['coords_encoding']
    img_size = config['img_size']
    mask_ratio = mask_ratio or config['mask_ratio']

    assert num_frames <= 4, "Demo only supports up to four timestamps"

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(f"Using {device} device.\n")

    # Loading data ---------------------------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(file_paths=data_files, mean=mean, std=std)

    if len(temporal_coords) != num_frames and 'time' in coords_encoding:
        coords_encoding.pop('time')
    if not len(location_coords) and 'location' in coords_encoding:
        coords_encoding.pop('location')
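    # Drop the time/location encodings when the corresponding metadata could not be
    # extracted from the inputs, so the model is built without those embeddings.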

    # Create model and load checkpoint -------------------------------------------------------------

    config.update(
        num_frames=num_frames,
        coords_encoding=coords_encoding,
    )
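    # num_frames is overridden to match the number of uploaded timestamps.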

    model = PrithviMAE(**config)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> Model has {total_params:,} parameters.\n")

    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device, weights_only=False)
    # discard fixed pos_embedding weight
    for k in list(state_dict.keys()):
        if 'pos_embed' in k:
            del state_dict[k]
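    # The fixed sin-cos positional embeddings depend on num_frames and img_size, which may
    # differ from pretraining; the model's freshly initialized ones are used instead,
    # which is why the checkpoint is loaded with strict=False.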
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded checkpoint from {checkpoint}")

    # Running model --------------------------------------------------------------------------------

    model.eval()
    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # BGR -> RGB

    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - original_h % img_size) % img_size
    pad_w = (img_size - original_w % img_size) % img_size
    input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')

    # Build sliding window
    batch = torch.tensor(input_data, device='cpu')
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
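    # unfold extracts non-overlapping img_size x img_size tiles; h1 and w1 are the number of
    # tiles along each spatial axis, and the rearrange flattens them into a batch of windows.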

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    # Run model
    rec_imgs = []
    mask_imgs = []
    for x in windows:
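        # All windows come from the same scene, so the temporal and location metadata
        # are simply repeated for every window in the batch.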
        temp_coords = torch.Tensor([temporal_coords] * len(x))
        loc_coords = torch.Tensor([location_coords[0]] * len(x))
        rec_img, mask_img = run_model(model, x, temp_coords, loc_coords, mask_ratio, device)
        rec_imgs.append(rec_img)
        mask_imgs.append(mask_img)

    rec_imgs = torch.concat(rec_imgs, dim=0)
    mask_imgs = torch.concat(mask_imgs, dim=0)

    # Build images from patches
    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
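    # The inverse rearrange stitches the per-window outputs back into full-size (1, C, T, H, W) images.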

    # Cut padded images back to original size
    rec_imgs_full = rec_imgs[..., :original_h, :original_w]
    mask_imgs_full = mask_imgs[..., :original_h, :original_w]
    batch_full = batch[..., :original_h, :original_w]

    # Build RGB images
    for d in meta_data:
        d.update(count=3, dtype='uint8', compress='lzw', nodata=0)

    outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
                               channels, mean, std)

    print("Done!")

    return outputs


run_inference = partial(predict_on_images, config_path=config_path, checkpoint=checkpoint)

with gr.Blocks() as demo:
    gr.Markdown(value='# Prithvi-EO-2.0 image reconstruction demo')
    gr.Markdown(value='''
Prithvi-EO-2.0 is the second-generation EO foundation model developed by the IBM and NASA team.
The temporal ViT was trained on 4.2M Harmonized Landsat Sentinel-2 (HLS) samples with four timestamps each, using the masked autoencoder (MAE) learning strategy.
The model includes spatial and temporal attention across multiple patches and timestamps.
Additionally, temporal and location information is added to the model input via embeddings.
More information about the model is available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL).\n
This demo showcases image reconstruction over one to four timestamps.
The model randomly masks out a proportion of the input patches and reconstructs them based on the unmasked portions of the images.
The reconstructed images are merged with the visible unmasked patches.
We recommend submitting images of 224 to ~1000 pixels per side for faster processing.
Images bigger than 224x224 are processed using a sliding window approach, which can lead to artefacts between patches.\n
The user needs to provide HLS GeoTIFF images including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2.
Optionally, location information is extracted from the TIFF files, and temporal information can be provided in the filename in the format `<date>T<time>` or `<year><julian day>T<time>` (HLS format).
Some example images are provided at the end of this page.
''')

    with gr.Row():
        with gr.Column():
            inp_files = gr.Files(elem_id='files')
            # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
            btn = gr.Button("Submit")

    with gr.Row():
        gr.Markdown(value='## Input time series')
        gr.Markdown(value='## Masked images')
        gr.Markdown(value='## Reconstructed images*')

    original = []
    masked = []
    predicted = []
    timestamps = []

    for t in range(4):
        timestamps.append(gr.Column(visible=t == 0))

        with timestamps[t]:
            # with gr.Row():
            #     gr.Markdown(value=f"Timestamp {t+1}")
            with gr.Row():
                original.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                masked.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
                predicted.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))

    gr.Markdown(value='\\* The reconstructed images include the ground truth unmasked patches.')

    btn.click(fn=run_inference,
              inputs=inp_files,
              outputs=original + masked + predicted)

    with gr.Row():
        gr.Examples(examples=[[[
            os.path.join(os.path.dirname(__file__), "examples/Alaska_HLS.S30.T06VUN.2020305T212629.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Alaska_HLS.S30.T06VUN.2021044T212601.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Alaska_HLS.S30.T06VUN.2021067T213531.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Alaska_HLS.S30.T06VUN.2021067T213531.v2.0_cropped.tif")
        ]], [[
            os.path.join(os.path.dirname(__file__), "examples/Florida_HLS.S30.T17RMP.2019119T155911.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Florida_HLS.S30.T17RMP.2019249T155901.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Florida_HLS.S30.T17RMP.2019349T160651.v2.0_cropped.tif"),
            os.path.join(os.path.dirname(__file__), "examples/Florida_HLS.S30.T17RMP.2020039T160419.v2.0_cropped.tif")
        ]]],
            inputs=inp_files,
            outputs=original + masked + predicted,
            fn=run_inference,
            cache_examples=True
        )

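    # Show one timestamp column per uploaded file (up to four); unused columns stay hidden.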
    def update_visibility(files):
        timestamps = [gr.Column(visible=t < len(files)) for t in range(4)]

        return timestamps

    inp_files.change(update_visibility, inp_files, timestamps)

demo.launch()