import base64
import glob
import os.path
from io import BytesIO
from pathlib import Path

import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from PIL import Image
from matplotlib import rcParams

from msclip.inference import run_inference_classification

rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9

IMG_PX = 300  # side length in pixels of the square image cards and bar charts
EXAMPLES = {
    "EuroSAT": {
        "images": glob.glob("examples/eurosat/*.tif"),
        "classes": [
            "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
            "Pasture", "Permanent crop", "Residential", "River", "Sea lake",
        ],
    },
    "Meter-ML": {
        "images": glob.glob("examples/meterml/*.tif"),
        "classes": [
            "Concentrated animal feeding operations",
            "Landfills",
            "Coal mines",
            "Other features",
            "Natural gas processing plants",
            "Oil refineries and petroleum terminals",
            "Wastewater treatment plants",
        ],
    },
    "TerraMesh": {
        "images": glob.glob("examples/terramesh/*.tif"),
        "classes": [
            "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert",
        ],
    },
}
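# Note: glob.glob returns an empty list if an example folder is missing, in
# which case the corresponding example button would load no images.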
def load_eurosat_example():
    return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"])


def load_meterml_example():
    return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"])


def load_terramesh_example():
    return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"])
pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]


def build_colormap(class_names):
    return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}
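# Classes are sorted first so a given class name maps to the same colour across
# runs, e.g. build_colormap(["River", "Forest"]) ->
# {"Forest": "#fbb4ae", "River": "#b3cde3"}.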
def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """
    array: numpy array (called here with channel-last RGB data, [H, W, C])
    returns array scaled to roughly 0-1
    """
    # Get scaling thresholds for smoothing the brightness
    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    limit_high = limit_high.clip(default)  # Scale only pixels above the default value
    limit_low = limit_low.clip(0, 1000)  # Scale only pixels below 1000
    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
    # Smooth very dark and bright values using linear scaling
    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
    # Update the scaling params, using a tenth of the tolerance for the max value
    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    limit_high = limit_high.clip(default, 20000)  # Scale only pixels above the default value
    limit_low = limit_low.clip(0, 500)  # Scale only pixels below 500
    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
    # Scale data to 0-1
    array = (array - limit_low) / (limit_high - limit_low)
    return array
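# The magic numbers above (1000, 2000, 20000) appear to assume raw Sentinel-2
# reflectance values, which typically fall in the 0-10000 range.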
def _s2_to_rgb(data, smooth_quantiles=True):
    # Select the RGB bands (indices 3, 2, 1; B04/B03/B02 assuming standard band order)
    if data.shape[0] > 13:
        # assuming channel-last data
        rgb = data[:, :, [3, 2, 1]]
    else:
        # assuming channel-first data
        rgb = data[[3, 2, 1]].transpose((1, 2, 0))

    if smooth_quantiles:
        rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000

    # Convert to uint8
    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
    return rgb
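# Minimal usage sketch (assuming `path` points to a Sentinel-2 GeoTIFF):
#   with rio.open(path) as src:
#       rgb = _s2_to_rgb(src.read())  # uint8 array of shape [H, W, 3]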
def _img_to_b64(path: str | Path) -> str:
    """Encode an image as a base64 PNG, padded to a square and resized to IMG_PX."""
    with rio.open(path) as src:
        data = src.read()
    rgb = _s2_to_rgb(data)
    img = Image.fromarray(rgb)
    side = max(img.size)
    # create a square canvas, paste the image centred, then resize
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    canvas = canvas.resize((IMG_PX, IMG_PX))
    buf = BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()
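# The white square canvas keeps non-square scenes centred, so every result card
# renders at the same IMG_PX x IMG_PX size in the output grid.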
def _bar_chart(top_scores, img_name, cmap) -> str:
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # Pad to three rows so every chart has the same height
    while len(scores) < 3:
        scores.append(0)
        labels.append("")

    fig, ax = plt.subplots(figsize=(3, 1))
    y_pos = np.arange(3)
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
              for cls, val in zip(labels, scores)]
    ax.barh(y_pos, scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.axis("off")

    img_name = os.path.splitext(img_name)[0]
    if len(img_name) > 25:
        img_name = img_name[:23] + "..."
    ax.set_title(img_name)

    for i, (cls, val) in enumerate(zip(labels, scores)):
        if len(cls) > 25:
            cls = cls[:23] + "..."
        if val > 0:  # skip padded rows
            ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")

    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
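# transparent=True together with the fully transparent bar colour (0, 0, 0, 0)
# keeps the padded rows invisible in the rendered chart.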
def classify(images, class_text):
    class_names = [c.strip() for c in class_text.split(",") if c.strip()]
    cmap = build_colormap(class_names)  # one shared colormap for all images
    cards = []
    df = run_inference_classification(image_path=images, class_names=class_names, verbose=False)
    for img_path, (_, row) in zip(images, df.iterrows()):
        scores = row[2:].astype(float)  # keep only the per-class score columns (the leading columns are metadata such as the filename)
        top = scores.sort_values(ascending=False)[:3]
        top = top[top > 0.01]  # filter out low scores
        cards.append(f"""
<div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
    <img src="data:image/png;base64,{_img_to_b64(img_path)}"
         style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
                border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
    {_bar_chart(top, os.path.basename(img_path), cmap)}
</div>""")
    return (
        "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
        + "".join(cards)
        + "</div>"
    )
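# run_inference_classification is assumed to return a DataFrame with one row
# per image; the per-class probabilities are taken from the columns after the
# first two (hence row[2:] above).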
# UI
with gr.Blocks(
    css="""
    .gradio-container
    #result_box,
    #result_box.gr-skeleton {min-height:280px !important;}
    """) as demo:
| gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP") | |
| gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. " | |
| "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifing forests as sea because of the differrently scaled values). " | |
| "We provide three sets of example images with class names that you can modify. The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). " | |
| "The images are classified based on the similarity between the images embeddings and text embeddings. " | |
| "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ") | |
    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images", file_count="multiple", type="filepath"
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            # some default classes
            label="Class names (comma‑separated)",
        )
    run_btn = gr.Button("Classify", variant="primary")
    # Examples
    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")

    out_html = gr.HTML(label="Results",
                       elem_id="result_box",
                       min_height=280)

    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)
    # Each example button first loads the images and class names, then runs classification
    btn_terramesh.click(
        load_terramesh_example,
        outputs=[img_in, cls_in],
    ).then(
        classify,
        inputs=[img_in, cls_in],
        outputs=out_html,
    )
    btn_eurosat.click(
        load_eurosat_example,
        outputs=[img_in, cls_in],
    ).then(
        classify,
        inputs=[img_in, cls_in],
        outputs=out_html,
    )
    btn_meterml.click(
        load_meterml_example,
        outputs=[img_in, cls_in],
    ).then(
        classify,
        inputs=[img_in, cls_in],
        outputs=out_html,
    )
if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)