Segmentation / models /sam2_zeroshot.py
Edwin Salguero
Initial commit: SAM 2 Few-Shot/Zero-Shot Segmentation Research Framework
12fa055
"""
SAM 2 Zero-Shot Segmentation Model
This module implements zero-shot segmentation using SAM 2 with advanced
text prompting, visual grounding, and attention-based prompt generation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
import clip
from segment_anything_2 import sam_model_registry, SamPredictor
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
import cv2
class SAM2ZeroShot(nn.Module):
"""
SAM 2 Zero-Shot Segmentation Model
Performs zero-shot segmentation using SAM 2 with advanced text prompting
and visual grounding techniques.
"""
def __init__(
self,
sam2_checkpoint: str,
clip_model_name: str = "ViT-B/32",
device: str = "cuda",
use_attention_maps: bool = True,
use_grounding_dino: bool = False,
temperature: float = 0.1
):
super().__init__()
self.device = device
self.temperature = temperature
self.use_attention_maps = use_attention_maps
self.use_grounding_dino = use_grounding_dino
# Initialize SAM 2
self.sam2 = sam_model_registry["vit_h"](checkpoint=sam2_checkpoint)
self.sam2.to(device)
self.sam2_predictor = SamPredictor(self.sam2)
# Initialize CLIP
self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
self.clip_model.eval()
# Initialize CLIP text and vision models for attention
if self.use_attention_maps:
self.clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
self.clip_text_model.to(device)
self.clip_vision_model.to(device)
# Advanced prompt templates with domain-specific variations
self.advanced_prompts = {
"satellite": {
"building": [
"satellite view of buildings", "aerial photograph of structures",
"overhead view of houses", "urban development from above",
"rooftop structures", "architectural features from space"
],
"road": [
"satellite view of roads", "aerial photograph of streets",
"overhead view of highways", "transportation network from above",
"paved surfaces", "road infrastructure from space"
],
"vegetation": [
"satellite view of vegetation", "aerial photograph of forests",
"overhead view of trees", "green areas from above",
"natural landscape", "plant life from space"
],
"water": [
"satellite view of water", "aerial photograph of lakes",
"overhead view of rivers", "water bodies from above",
"aquatic features", "water resources from space"
]
},
"fashion": {
"shirt": [
"fashion photography of shirts", "clothing item top",
"apparel garment", "upper body clothing",
"casual wear", "formal attire top"
],
"pants": [
"fashion photography of pants", "lower body clothing",
"trousers garment", "leg wear",
"casual pants", "formal trousers"
],
"dress": [
"fashion photography of dresses", "full body garment",
"formal dress", "evening wear",
"casual dress", "party dress"
],
"shoes": [
"fashion photography of shoes", "footwear item",
"foot covering", "walking shoes",
"casual footwear", "formal shoes"
]
},
"robotics": {
"robot": [
"robotics environment with robot", "automation equipment",
"mechanical arm", "industrial robot",
"automated system", "robotic device"
],
"tool": [
"robotics environment with tools", "industrial equipment",
"mechanical tools", "work equipment",
"hand tools", "power tools"
],
"safety": [
"robotics environment with safety equipment", "protective gear",
"safety helmet", "safety vest",
"protective clothing", "safety equipment"
]
}
}
# Prompt enhancement strategies
self.prompt_strategies = {
"descriptive": lambda x: f"a clear image showing {x}",
"contextual": lambda x: f"in a typical environment, {x}",
"detailed": lambda x: f"high quality photograph of {x} with clear details",
"contrastive": lambda x: f"{x} standing out from the background"
}
def generate_attention_maps(
self,
image: torch.Tensor,
text_prompts: List[str]
) -> torch.Tensor:
"""Generate attention maps using CLIP's cross-attention."""
if not self.use_attention_maps:
return None
# Tokenize text prompts
text_inputs = self.clip_tokenizer(
text_prompts,
padding=True,
return_tensors="pt"
).to(self.device)
# Get image features
image_inputs = self.clip_preprocess(image).unsqueeze(0).to(self.device)
# Get attention maps from CLIP
with torch.no_grad():
text_outputs = self.clip_text_model(**text_inputs, output_attentions=True)
vision_outputs = self.clip_vision_model(image_inputs, output_attentions=True)
# Extract cross-attention maps
cross_attention = text_outputs.cross_attentions[-1] # Last layer
attention_maps = cross_attention.mean(dim=1) # Average over heads
return attention_maps
def extract_attention_points(
self,
attention_maps: torch.Tensor,
num_points: int = 5
) -> List[Tuple[int, int]]:
"""Extract points from attention maps for SAM 2 prompting."""
if attention_maps is None:
return []
# Resize attention map to image size
h, w = attention_maps.shape[-2:]
attention_maps = F.interpolate(
attention_maps.unsqueeze(0),
size=(h, w),
mode='bilinear'
).squeeze(0)
# Find top attention points
points = []
for i in range(min(num_points, attention_maps.shape[0])):
attention_map = attention_maps[i]
max_idx = torch.argmax(attention_map)
y, x = max_idx // w, max_idx % w
points.append((int(x), int(y)))
return points
def generate_enhanced_prompts(
self,
domain: str,
class_names: List[str]
) -> List[str]:
"""Generate enhanced prompts using multiple strategies."""
enhanced_prompts = []
for class_name in class_names:
if domain in self.advanced_prompts and class_name in self.advanced_prompts[domain]:
base_prompts = self.advanced_prompts[domain][class_name]
# Add base prompts
enhanced_prompts.extend(base_prompts)
# Add strategy-enhanced prompts
for strategy_name, strategy_func in self.prompt_strategies.items():
for base_prompt in base_prompts[:2]: # Use first 2 base prompts
enhanced_prompt = strategy_func(base_prompt)
enhanced_prompts.append(enhanced_prompt)
else:
# Fallback for unknown classes
enhanced_prompts.append(class_name)
enhanced_prompts.append(f"object: {class_name}")
return enhanced_prompts
def compute_text_image_similarity(
self,
image: torch.Tensor,
text_prompts: List[str]
) -> torch.Tensor:
"""Compute similarity between image and text prompts."""
# Tokenize and encode text
text_tokens = clip.tokenize(text_prompts).to(self.device)
with torch.no_grad():
text_features = self.clip_model.encode_text(text_tokens)
text_features = F.normalize(text_features, dim=-1)
# Encode image
image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
image_features = self.clip_model.encode_image(image_input)
image_features = F.normalize(image_features, dim=-1)
# Compute similarity
similarity = torch.matmul(image_features, text_features.T) / self.temperature
return similarity
def generate_sam2_prompts(
self,
image: torch.Tensor,
domain: str,
class_names: List[str]
) -> List[Dict]:
"""Generate comprehensive SAM 2 prompts for zero-shot segmentation."""
prompts = []
# Generate enhanced text prompts
text_prompts = self.generate_enhanced_prompts(domain, class_names)
# Compute text-image similarity
similarities = self.compute_text_image_similarity(image, text_prompts)
# Generate attention maps
attention_maps = self.generate_attention_maps(image, text_prompts)
attention_points = self.extract_attention_points(attention_maps)
# Create prompts for each class
for i, class_name in enumerate(class_names):
class_prompts = []
# Find relevant text prompts for this class
class_text_indices = []
for j, prompt in enumerate(text_prompts):
if class_name.lower() in prompt.lower():
class_text_indices.append(j)
if class_text_indices:
# Get best similarity for this class
class_similarities = similarities[0, class_text_indices]
best_idx = torch.argmax(class_similarities)
best_similarity = class_similarities[best_idx]
if best_similarity > 0.2: # Threshold for relevance
# Add attention-based points
if attention_points:
for point in attention_points[:3]: # Use top 3 points
prompts.append({
'type': 'point',
'data': point,
'label': 1,
'class': class_name,
'confidence': best_similarity.item(),
'source': 'attention'
})
# Add center point as fallback
h, w = image.shape[-2:]
center_point = [w // 2, h // 2]
prompts.append({
'type': 'point',
'data': center_point,
'label': 1,
'class': class_name,
'confidence': best_similarity.item(),
'source': 'center'
})
# Add bounding box prompt (simple rectangle)
if best_similarity > 0.4: # Higher threshold for box prompts
box = [w // 4, h // 4, 3 * w // 4, 3 * h // 4]
prompts.append({
'type': 'box',
'data': box,
'class': class_name,
'confidence': best_similarity.item(),
'source': 'similarity'
})
return prompts
def segment(
self,
image: torch.Tensor,
domain: str,
class_names: List[str]
) -> Dict[str, torch.Tensor]:
"""
Perform zero-shot segmentation.
Args:
image: Input image tensor [C, H, W]
domain: Domain name (satellite, fashion, robotics)
class_names: List of class names to segment
Returns:
Dictionary with masks for each class
"""
# Convert image for SAM 2
if isinstance(image, torch.Tensor):
image_np = image.permute(1, 2, 0).cpu().numpy()
else:
image_np = image
# Set image in SAM 2 predictor
self.sam2_predictor.set_image(image_np)
# Generate prompts
prompts = self.generate_sam2_prompts(image, domain, class_names)
results = {}
for prompt in prompts:
class_name = prompt['class']
if prompt['type'] == 'point':
point = prompt['data']
label = prompt['label']
# Get SAM 2 prediction
masks, scores, logits = self.sam2_predictor.predict(
point_coords=np.array([point]),
point_labels=np.array([label]),
multimask_output=True
)
# Select best mask
best_mask_idx = np.argmax(scores)
mask = torch.from_numpy(masks[best_mask_idx]).float()
# Apply confidence threshold
if prompt['confidence'] > 0.2:
if class_name not in results:
results[class_name] = mask
else:
# Combine masks if multiple prompts for same class
results[class_name] = torch.max(results[class_name], mask)
elif prompt['type'] == 'box':
box = prompt['data']
# Get SAM 2 prediction with box
masks, scores, logits = self.sam2_predictor.predict(
box=np.array(box),
multimask_output=True
)
# Select best mask
best_mask_idx = np.argmax(scores)
mask = torch.from_numpy(masks[best_mask_idx]).float()
# Apply confidence threshold
if prompt['confidence'] > 0.3:
if class_name not in results:
results[class_name] = mask
else:
# Combine masks
results[class_name] = torch.max(results[class_name], mask)
return results
def forward(
self,
image: torch.Tensor,
domain: str,
class_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Forward pass."""
return self.segment(image, domain, class_names)
class ZeroShotEvaluator:
"""Evaluator for zero-shot segmentation."""
def __init__(self):
self.metrics = {}
def compute_iou(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
"""Compute Intersection over Union."""
intersection = (pred_mask & gt_mask).sum()
union = (pred_mask | gt_mask).sum()
return (intersection / union).item() if union > 0 else 0.0
def compute_dice(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
"""Compute Dice coefficient."""
intersection = (pred_mask & gt_mask).sum()
total = pred_mask.sum() + gt_mask.sum()
return (2 * intersection / total).item() if total > 0 else 0.0
def evaluate(
self,
predictions: Dict[str, torch.Tensor],
ground_truth: Dict[str, torch.Tensor]
) -> Dict[str, float]:
"""Evaluate zero-shot segmentation results."""
results = {}
for class_name in ground_truth.keys():
if class_name in predictions:
pred_mask = predictions[class_name] > 0.5 # Threshold
gt_mask = ground_truth[class_name] > 0.5
iou = self.compute_iou(pred_mask, gt_mask)
dice = self.compute_dice(pred_mask, gt_mask)
results[f"{class_name}_iou"] = iou
results[f"{class_name}_dice"] = dice
# Compute average metrics
if results:
results['mean_iou'] = np.mean([v for k, v in results.items() if 'iou' in k])
results['mean_dice'] = np.mean([v for k, v in results.items() if 'dice' in k])
return results