Commit
Β·
540601d
1
Parent(s):
58a73c7
Switched to config.json
Browse files
app.py
CHANGED
|
@@ -10,8 +10,8 @@ from huggingface_hub import hf_hub_download
|
|
| 10 |
|
| 11 |
# pull files from hub
|
| 12 |
token = os.environ.get("HF_TOKEN", None)
|
| 13 |
-
|
| 14 |
-
filename="
|
| 15 |
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
| 16 |
filename='Prithvi_EO_V2_300M_TL.pt', token=token)
|
| 17 |
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
|
@@ -67,7 +67,7 @@ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
|
|
| 67 |
return outputs
|
| 68 |
|
| 69 |
|
| 70 |
-
def predict_on_images(data_files: list,
|
| 71 |
try:
|
| 72 |
data_files = [x.name for x in data_files]
|
| 73 |
print('Path extracted from example')
|
|
@@ -77,18 +77,17 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
| 77 |
# Get parameters --------
|
| 78 |
print('This is the printout', data_files)
|
| 79 |
|
| 80 |
-
with open(
|
| 81 |
-
config = yaml.safe_load(f)
|
| 82 |
|
| 83 |
batch_size = 8
|
| 84 |
-
bands = config['
|
| 85 |
num_frames = len(data_files)
|
| 86 |
-
mean = config['
|
| 87 |
-
std = config['
|
| 88 |
-
coords_encoding = config['
|
| 89 |
-
img_size = config['
|
| 90 |
-
|
| 91 |
-
mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
|
| 92 |
|
| 93 |
assert num_frames <= 4, "Demo only supports up to four timestamps"
|
| 94 |
|
|
@@ -110,21 +109,12 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
| 110 |
|
| 111 |
# Create model and load checkpoint -------------------------------------------------------------
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
num_heads=config['MODEL']['NUM_HEADS'],
|
| 120 |
-
decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
|
| 121 |
-
decoder_depth=config['MODEL']['DECODER_DEPTH'],
|
| 122 |
-
decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
|
| 123 |
-
mlp_ratio=config['MODEL']['MLP_RATIO'],
|
| 124 |
-
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
| 125 |
-
norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
|
| 126 |
-
coords_encoding=coords_encoding,
|
| 127 |
-
coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
|
| 128 |
|
| 129 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 130 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
|
@@ -196,7 +186,7 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
| 196 |
return outputs
|
| 197 |
|
| 198 |
|
| 199 |
-
run_inference = partial(predict_on_images,
|
| 200 |
|
| 201 |
with gr.Blocks() as demo:
|
| 202 |
|
|
|
|
| 10 |
|
| 11 |
# pull files from hub
|
| 12 |
token = os.environ.get("HF_TOKEN", None)
|
| 13 |
+
config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
| 14 |
+
filename="config.json", token=token)
|
| 15 |
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
| 16 |
filename='Prithvi_EO_V2_300M_TL.pt', token=token)
|
| 17 |
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
|
|
|
| 67 |
return outputs
|
| 68 |
|
| 69 |
|
| 70 |
+
def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
|
| 71 |
try:
|
| 72 |
data_files = [x.name for x in data_files]
|
| 73 |
print('Path extracted from example')
|
|
|
|
| 77 |
# Get parameters --------
|
| 78 |
print('This is the printout', data_files)
|
| 79 |
|
| 80 |
+
with open(config_path, 'r') as f:
|
| 81 |
+
config = yaml.safe_load(f)['pretrained_cfg']
|
| 82 |
|
| 83 |
batch_size = 8
|
| 84 |
+
bands = config['bands']
|
| 85 |
num_frames = len(data_files)
|
| 86 |
+
mean = config['mean']
|
| 87 |
+
std = config['std']
|
| 88 |
+
coords_encoding = config['coords_encoding']
|
| 89 |
+
img_size = config['img_size']
|
| 90 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
|
|
|
| 91 |
|
| 92 |
assert num_frames <= 4, "Demo only supports up to four timestamps"
|
| 93 |
|
|
|
|
| 109 |
|
| 110 |
# Create model and load checkpoint -------------------------------------------------------------
|
| 111 |
|
| 112 |
+
config.update(
|
| 113 |
+
num_frames=num_frames,
|
| 114 |
+
coords_encoding=coords_encoding,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
model = PrithviMAE(**config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 120 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
|
|
|
| 186 |
return outputs
|
| 187 |
|
| 188 |
|
| 189 |
+
run_inference = partial(predict_on_images, config_path=config_path,checkpoint=checkpoint)
|
| 190 |
|
| 191 |
with gr.Blocks() as demo:
|
| 192 |
|