"""
SAM 2 Zero-Shot Segmentation Model

This module implements zero-shot segmentation using SAM 2 with advanced
text prompting, visual grounding, and attention-based prompt generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
import clip
from segment_anything_2 import sam_model_registry, SamPredictor
from transformers import CLIPModel, CLIPTokenizer
import cv2


class SAM2ZeroShot(nn.Module):
    """
    SAM 2 Zero-Shot Segmentation Model

    Performs zero-shot segmentation using SAM 2 with advanced text prompting
    and visual grounding techniques.
    """

    def __init__(
        self,
        sam2_checkpoint: str,
        clip_model_name: str = "ViT-B/32",
        device: str = "cuda",
        use_attention_maps: bool = True,
        use_grounding_dino: bool = False,
        temperature: float = 0.1
    ):
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.use_attention_maps = use_attention_maps
        self.use_grounding_dino = use_grounding_dino

        # Load the segmentation backbone and wrap it in a predictor
        # (assumes a SAM-style registry/predictor API, as imported above).
        self.sam2 = sam_model_registry["vit_h"](checkpoint=sam2_checkpoint)
        self.sam2.to(device)
        self.sam2_predictor = SamPredictor(self.sam2)

        # OpenAI CLIP for global text-image similarity scoring.
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.eval()

        if self.use_attention_maps:
            # CLIP has no text-to-image cross-attention, so the Hugging Face
            # CLIPModel (with its projection heads) is used to build
            # patch-level text-image relevance maps instead.
            self.clip_hf_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
            self.clip_hf_model.eval()
            self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        self.advanced_prompts = {
            "satellite": {
                "building": [
                    "satellite view of buildings", "aerial photograph of structures",
                    "overhead view of houses", "urban development from above",
                    "rooftop structures", "architectural features from space"
                ],
                "road": [
                    "satellite view of roads", "aerial photograph of streets",
                    "overhead view of highways", "transportation network from above",
                    "paved surfaces", "road infrastructure from space"
                ],
                "vegetation": [
                    "satellite view of vegetation", "aerial photograph of forests",
                    "overhead view of trees", "green areas from above",
                    "natural landscape", "plant life from space"
                ],
                "water": [
                    "satellite view of water", "aerial photograph of lakes",
                    "overhead view of rivers", "water bodies from above",
                    "aquatic features", "water resources from space"
                ]
            },
            "fashion": {
                "shirt": [
                    "fashion photography of shirts", "clothing item top",
                    "apparel garment", "upper body clothing",
                    "casual wear", "formal attire top"
                ],
                "pants": [
                    "fashion photography of pants", "lower body clothing",
                    "trousers garment", "leg wear",
                    "casual pants", "formal trousers"
                ],
                "dress": [
                    "fashion photography of dresses", "full body garment",
                    "formal dress", "evening wear",
                    "casual dress", "party dress"
                ],
                "shoes": [
                    "fashion photography of shoes", "footwear item",
                    "foot covering", "walking shoes",
                    "casual footwear", "formal shoes"
                ]
            },
            "robotics": {
                "robot": [
                    "robotics environment with robot", "automation equipment",
                    "mechanical arm", "industrial robot",
                    "automated system", "robotic device"
                ],
                "tool": [
                    "robotics environment with tools", "industrial equipment",
                    "mechanical tools", "work equipment",
                    "hand tools", "power tools"
                ],
                "safety": [
                    "robotics environment with safety equipment", "protective gear",
                    "safety helmet", "safety vest",
                    "protective clothing", "safety equipment"
                ]
            }
        }

        self.prompt_strategies = {
            "descriptive": lambda x: f"a clear image showing {x}",
            "contextual": lambda x: f"in a typical environment, {x}",
            "detailed": lambda x: f"high quality photograph of {x} with clear details",
            "contrastive": lambda x: f"{x} standing out from the background"
        }

    def generate_attention_maps(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> Optional[torch.Tensor]:
        """Generate patch-level text-image relevance maps. CLIP has no true
        text-to-image cross-attention, so projected patch features are used
        as an approximation."""
        if not self.use_attention_maps:
            return None

        # clip_preprocess expects a PIL image; convert CHW float tensors
        # (assumed to be in [0, 1]) accordingly.
        if isinstance(image, torch.Tensor):
            image = Image.fromarray(
                (image.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            )

        text_inputs = self.clip_tokenizer(
            text_prompts,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        image_inputs = self.clip_preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            text_embeds = self.clip_hf_model.get_text_features(**text_inputs)      # [P, D]
            vision_outputs = self.clip_hf_model.vision_model(pixel_values=image_inputs)
            patch_tokens = vision_outputs.last_hidden_state[:, 1:, :]              # drop CLS token
            patch_embeds = self.clip_hf_model.visual_projection(patch_tokens)      # [1, N, D]

        # Cosine similarity between every prompt and every image patch,
        # reshaped into one low-resolution map per prompt.
        patch_embeds = F.normalize(patch_embeds, dim=-1).squeeze(0)                # [N, D]
        text_embeds = F.normalize(text_embeds, dim=-1)                             # [P, D]
        similarity = text_embeds @ patch_embeds.T                                  # [P, N]
        grid = int(similarity.shape[-1] ** 0.5)
        attention_maps = similarity.reshape(len(text_prompts), grid, grid)

        return attention_maps

    def extract_attention_points(
        self,
        attention_maps: torch.Tensor,
        image_size: Tuple[int, int],
        num_points: int = 5
    ) -> List[Tuple[int, int]]:
        """Extract peak-response points from the relevance maps for SAM 2 prompting."""
        if attention_maps is None:
            return []

        # Upsample the low-resolution maps to the image size so the argmax
        # locations are valid pixel coordinates.
        h, w = image_size
        attention_maps = F.interpolate(
            attention_maps.unsqueeze(1).float(),
            size=(h, w),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)

        points = []
        for i in range(min(num_points, attention_maps.shape[0])):
            attention_map = attention_maps[i]
            max_idx = torch.argmax(attention_map)
            y, x = max_idx // w, max_idx % w
            points.append((int(x), int(y)))

        return points

    def generate_enhanced_prompts(
        self,
        domain: str,
        class_names: List[str]
    ) -> List[str]:
        """Generate enhanced prompts using multiple strategies."""
        enhanced_prompts = []

        for class_name in class_names:
            if domain in self.advanced_prompts and class_name in self.advanced_prompts[domain]:
                base_prompts = self.advanced_prompts[domain][class_name]
                enhanced_prompts.extend(base_prompts)

                # Expand the first two base prompts with each prompting strategy.
                for strategy_name, strategy_func in self.prompt_strategies.items():
                    for base_prompt in base_prompts[:2]:
                        enhanced_prompt = strategy_func(base_prompt)
                        enhanced_prompts.append(enhanced_prompt)
            else:
                # Fall back to the raw class name when no domain template exists.
                enhanced_prompts.append(class_name)
                enhanced_prompts.append(f"object: {class_name}")

        return enhanced_prompts

    def compute_text_image_similarity(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> torch.Tensor:
        """Compute similarity between the image and each text prompt."""
        text_tokens = clip.tokenize(text_prompts).to(self.device)

        # clip_preprocess expects a PIL image; convert CHW float tensors
        # (assumed to be in [0, 1]) accordingly.
        if isinstance(image, torch.Tensor):
            image = Image.fromarray(
                (image.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            )

        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=-1)

            image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
            image_features = self.clip_model.encode_image(image_input)
            image_features = F.normalize(image_features, dim=-1)

            # Temperature-scaled cosine similarity, shape [1, num_prompts].
            similarity = torch.matmul(image_features, text_features.T) / self.temperature

        return similarity

    def generate_sam2_prompts(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> List[Dict]:
        """Generate comprehensive SAM 2 prompts for zero-shot segmentation."""
        prompts = []

        # Text prompts, global text-image similarities, and patch-level relevance maps.
        text_prompts = self.generate_enhanced_prompts(domain, class_names)
        similarities = self.compute_text_image_similarity(image, text_prompts)
        attention_maps = self.generate_attention_maps(image, text_prompts)
        attention_points = self.extract_attention_points(attention_maps, image.shape[-2:])

        for i, class_name in enumerate(class_names):
            # Find the text prompts that mention this class.
            class_text_indices = []
            for j, prompt in enumerate(text_prompts):
                if class_name.lower() in prompt.lower():
                    class_text_indices.append(j)

            if class_text_indices:
                class_similarities = similarities[0, class_text_indices]
                best_idx = torch.argmax(class_similarities)
                best_similarity = class_similarities[best_idx]

                if best_similarity > 0.2:
                    # Point prompts from the relevance-map peaks.
                    if attention_points:
                        for point in attention_points[:3]:
                            prompts.append({
                                'type': 'point',
                                'data': point,
                                'label': 1,
                                'class': class_name,
                                'confidence': best_similarity.item(),
                                'source': 'attention'
                            })

                    # Fallback point prompt at the image center.
                    h, w = image.shape[-2:]
                    center_point = [w // 2, h // 2]
                    prompts.append({
                        'type': 'point',
                        'data': center_point,
                        'label': 1,
                        'class': class_name,
                        'confidence': best_similarity.item(),
                        'source': 'center'
                    })

                    # Coarse box prompt for high-confidence classes.
                    if best_similarity > 0.4:
                        box = [w // 4, h // 4, 3 * w // 4, 3 * h // 4]
                        prompts.append({
                            'type': 'box',
                            'data': box,
                            'class': class_name,
                            'confidence': best_similarity.item(),
                            'source': 'similarity'
                        })

        return prompts

    def segment(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """
        Perform zero-shot segmentation.

        Args:
            image: Input image tensor [C, H, W]
            domain: Domain name (satellite, fashion, robotics)
            class_names: List of class names to segment

        Returns:
            Dictionary with masks for each class
        """
        # SamPredictor expects an HWC uint8 RGB array; convert CHW float
        # tensors (assumed to be in [0, 1]) accordingly.
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).cpu().numpy()
            if image_np.dtype != np.uint8:
                image_np = (image_np * 255).clip(0, 255).astype(np.uint8)
        else:
            image_np = image

        self.sam2_predictor.set_image(image_np)

        prompts = self.generate_sam2_prompts(image, domain, class_names)

        results = {}

        for prompt in prompts:
            class_name = prompt['class']

            if prompt['type'] == 'point':
                point = prompt['data']
                label = prompt['label']

                masks, scores, logits = self.sam2_predictor.predict(
                    point_coords=np.array([point]),
                    point_labels=np.array([label]),
                    multimask_output=True
                )

                # Keep the highest-scoring mask for this prompt.
                best_mask_idx = np.argmax(scores)
                mask = torch.from_numpy(masks[best_mask_idx]).float()

                if prompt['confidence'] > 0.2:
                    if class_name not in results:
                        results[class_name] = mask
                    else:
                        # Union with any mask already collected for this class.
                        results[class_name] = torch.max(results[class_name], mask)

            elif prompt['type'] == 'box':
                box = prompt['data']

                masks, scores, logits = self.sam2_predictor.predict(
                    box=np.array(box),
                    multimask_output=True
                )

                best_mask_idx = np.argmax(scores)
                mask = torch.from_numpy(masks[best_mask_idx]).float()

                if prompt['confidence'] > 0.3:
                    if class_name not in results:
                        results[class_name] = mask
                    else:
                        results[class_name] = torch.max(results[class_name], mask)

        return results

    def forward(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Forward pass."""
        return self.segment(image, domain, class_names)


class ZeroShotEvaluator:
    """Evaluator for zero-shot segmentation."""

    def __init__(self):
        self.metrics = {}

    def compute_iou(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Compute Intersection over Union."""
        intersection = (pred_mask & gt_mask).sum()
        union = (pred_mask | gt_mask).sum()
        return (intersection / union).item() if union > 0 else 0.0

    def compute_dice(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Compute Dice coefficient."""
        intersection = (pred_mask & gt_mask).sum()
        total = pred_mask.sum() + gt_mask.sum()
        return (2 * intersection / total).item() if total > 0 else 0.0

    def evaluate(
        self,
        predictions: Dict[str, torch.Tensor],
        ground_truth: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """Evaluate zero-shot segmentation results."""
        results = {}

        for class_name in ground_truth.keys():
            if class_name in predictions:
                pred_mask = predictions[class_name] > 0.5
                gt_mask = ground_truth[class_name] > 0.5

                iou = self.compute_iou(pred_mask, gt_mask)
                dice = self.compute_dice(pred_mask, gt_mask)

                results[f"{class_name}_iou"] = iou
                results[f"{class_name}_dice"] = dice

        if results:
            results['mean_iou'] = np.mean([v for k, v in results.items() if 'iou' in k])
            results['mean_dice'] = np.mean([v for k, v in results.items() if 'dice' in k])

        return results
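

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a local checkpoint and
# an RGB test image on disk; the file names below ("sam2_vit_h.pth",
# "example.jpg") and the commented ground-truth step are placeholders, not
# part of this module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.transforms.functional import to_tensor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the zero-shot segmenter (checkpoint path is a placeholder).
    model = SAM2ZeroShot(sam2_checkpoint="sam2_vit_h.pth", device=device)

    # Load an image as a CHW float tensor in [0, 1], as segment() expects.
    image = to_tensor(Image.open("example.jpg").convert("RGB"))

    # Segment two satellite classes without any task-specific training.
    masks = model.segment(image, domain="satellite", class_names=["building", "road"])
    for name, mask in masks.items():
        print(f"{name}: mask of shape {tuple(mask.shape)}")

    # If ground-truth masks are available, score the predictions:
    # ground_truth = {"building": ..., "road": ...}  # placeholder tensors
    # evaluator = ZeroShotEvaluator()
    # print(evaluator.evaluate(masks, ground_truth))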