"""
SAM 2 Few-Shot Learning Model

This module implements a few-shot segmentation model that combines SAM 2 with CLIP
for domain adaptation using minimal labeled examples.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from PIL import Image

import clip
from segment_anything_2 import sam_model_registry, SamPredictor
from transformers import CLIPTextModel, CLIPTokenizer


class SAM2FewShot(nn.Module):
    """
    SAM 2 Few-Shot Learning Model

    Combines SAM 2 with CLIP for few-shot and zero-shot segmentation
    across different domains (satellite, fashion, robotics).
    """

    def __init__(
        self,
        sam2_checkpoint: str,
        clip_model_name: str = "ViT-B/32",
        device: str = "cuda",
        prompt_engineering: bool = True,
        visual_similarity: bool = True,
        temperature: float = 0.1
    ):
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.prompt_engineering = prompt_engineering
        self.visual_similarity = visual_similarity

        # SAM 2 backbone and its promptable mask predictor
        self.sam2 = sam_model_registry["vit_h"](checkpoint=sam2_checkpoint)
        self.sam2.to(device)
        self.sam2_predictor = SamPredictor(self.sam2)

        # Frozen CLIP model used for text and image embeddings
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.eval()

        # Synonym prompts per domain, used to enrich the text query for each class
        self.domain_prompts = {
            "satellite": {
                "building": ["building", "house", "structure", "rooftop"],
                "road": ["road", "street", "highway", "pavement"],
                "vegetation": ["vegetation", "forest", "trees", "green area"],
                "water": ["water", "lake", "river", "ocean", "pond"]
            },
            "fashion": {
                "shirt": ["shirt", "t-shirt", "blouse", "top"],
                "pants": ["pants", "trousers", "jeans", "legs"],
                "dress": ["dress", "gown", "outfit"],
                "shoes": ["shoes", "footwear", "sneakers", "boots"]
            },
            "robotics": {
                "robot": ["robot", "automation", "mechanical arm"],
                "tool": ["tool", "wrench", "screwdriver", "equipment"],
                "safety": ["safety equipment", "helmet", "vest", "protection"]
            }
        }

        # Memory bank of few-shot support examples: domain -> class -> list of entries
        self.few_shot_memory = {}

    def encode_text_prompts(self, domain: str, class_names: List[str]) -> torch.Tensor:
        """Encode text prompts for the given domain and classes.

        Returns one mean-pooled CLIP embedding per class, so the output rows
        align with ``class_names`` by index.
        """
        class_features = []
        for class_name in class_names:
            # Expand each class into its domain-specific synonym prompts, if available
            if domain in self.domain_prompts and class_name in self.domain_prompts[domain]:
                prompts = list(self.domain_prompts[domain][class_name])
            else:
                prompts = [class_name]

            # Add domain context to every prompt variant
            if domain == "satellite":
                prompts = [f"satellite image of {p}" for p in prompts]
            elif domain == "fashion":
                prompts = [f"fashion item {p}" for p in prompts]
            elif domain == "robotics":
                prompts = [f"robotics environment {p}" for p in prompts]

            text_tokens = clip.tokenize(prompts).to(self.device)
            with torch.no_grad():
                text_features = self.clip_model.encode_text(text_tokens)
                text_features = F.normalize(text_features, dim=-1)

            # Pool the prompt variants into a single normalized class embedding
            class_features.append(F.normalize(text_features.mean(dim=0), dim=-1))

        return torch.stack(class_features)

    def encode_image(self, image: Union[torch.Tensor, np.ndarray, Image.Image]) -> torch.Tensor:
        """Encode an image with CLIP and return a normalized feature vector."""
        if isinstance(image, torch.Tensor):
            if image.dim() == 4:
                image = image.squeeze(0)
            image = image.permute(1, 2, 0).cpu().numpy()

        if isinstance(image, np.ndarray):
            # PIL expects uint8 data; float arrays are assumed to be in [0, 1]
            if image.dtype != np.uint8:
                image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
            image = Image.fromarray(image)

        clip_image = self.clip_preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(clip_image)
            image_features = F.normalize(image_features, dim=-1)

        return image_features

    def compute_similarity(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Compute similarity between image and text features."""
        similarity = torch.matmul(image_features, text_features.T) / self.temperature
        return similarity

    def add_few_shot_example(
        self,
        domain: str,
        class_name: str,
        image: torch.Tensor,
        mask: torch.Tensor
    ):
        """Add a few-shot example to the memory bank."""
        if domain not in self.few_shot_memory:
            self.few_shot_memory[domain] = {}

        if class_name not in self.few_shot_memory[domain]:
            self.few_shot_memory[domain][class_name] = []

        image_features = self.encode_image(image)

        self.few_shot_memory[domain][class_name].append({
            'image_features': image_features,
            'mask': mask,
            'image': image
        })

    def get_few_shot_similarity(
        self,
        query_image: torch.Tensor,
        domain: str,
        class_name: str
    ) -> torch.Tensor:
        """Compute similarity with few-shot examples."""
        if domain not in self.few_shot_memory or class_name not in self.few_shot_memory[domain]:
            return torch.zeros(1, device=self.device)

        query_features = self.encode_image(query_image)
        similarities = []

        for example in self.few_shot_memory[domain][class_name]:
            similarity = F.cosine_similarity(
                query_features,
                example['image_features'],
                dim=-1
            )
            similarities.append(similarity)

        return torch.stack(similarities).mean()

    def generate_sam2_prompts(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> List[Dict]:
        """Generate SAM 2 prompts based on text and few-shot similarity."""
        prompts = []

        # Zero-shot path: score each class with CLIP text prompts
        if self.prompt_engineering:
            text_features = self.encode_text_prompts(domain, class_names)
            image_features = self.encode_image(image)
            # Shape [1, num_classes]; column i corresponds to class_names[i]
            text_similarities = self.compute_similarity(image_features, text_features)

            for i, class_name in enumerate(class_names):
                if text_similarities[0, i] > 0.3:
                    # Use the image center as a coarse point prompt for this class
                    h, w = image.shape[-2:]
                    point = [w // 2, h // 2]
                    prompts.append({
                        'type': 'point',
                        'data': point,
                        'label': 1,
                        'class': class_name,
                        'confidence': text_similarities[0, i].item()
                    })

        # Few-shot path: score the query against stored support examples
        if use_few_shot and self.visual_similarity:
            for class_name in class_names:
                few_shot_sim = self.get_few_shot_similarity(image, domain, class_name)
                if few_shot_sim > 0.5:
                    h, w = image.shape[-2:]
                    point = [w // 2, h // 2]
                    prompts.append({
                        'type': 'point',
                        'data': point,
                        'label': 1,
                        'class': class_name,
                        'confidence': few_shot_sim.item()
                    })

        return prompts

    def segment(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Perform few-shot/zero-shot segmentation.

        Args:
            image: Input image tensor [C, H, W]
            domain: Domain name (satellite, fashion, robotics)
            class_names: List of class names to segment
            use_few_shot: Whether to use few-shot examples

        Returns:
            Dictionary with masks for each class
        """
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).cpu().numpy()
        else:
            image_np = image

        # The predictor expects a uint8 HWC RGB image; float inputs are assumed to be in [0, 1]
        if image_np.dtype != np.uint8:
            image_np = (np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8)

        self.sam2_predictor.set_image(image_np)

        # Build point prompts from CLIP text and few-shot similarity
        prompts = self.generate_sam2_prompts(image, domain, class_names, use_few_shot)

        results = {}

        for prompt in prompts:
            class_name = prompt['class']

            if prompt['type'] == 'point':
                point = prompt['data']
                label = prompt['label']

                masks, scores, logits = self.sam2_predictor.predict(
                    point_coords=np.array([point]),
                    point_labels=np.array([label]),
                    multimask_output=True
                )

                # Keep the highest-scoring candidate mask
                best_mask_idx = np.argmax(scores)
                mask = torch.from_numpy(masks[best_mask_idx]).float()

                # Only keep predictions above the confidence threshold
                if prompt['confidence'] > 0.3:
                    results[class_name] = mask

        return results

    def forward(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Forward pass for training."""
        return self.segment(image, domain, class_names, use_few_shot)


class FewShotTrainer:
    """Trainer for few-shot segmentation."""

    def __init__(self, model: SAM2FewShot, learning_rate: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.criterion = nn.BCELoss()

    def train_step(
        self,
        support_images: List[torch.Tensor],
        support_masks: List[torch.Tensor],
        query_image: torch.Tensor,
        query_mask: torch.Tensor,
        domain: str,
        class_name: str
    ):
        """Single training step."""
        self.model.train()

        # Register the support set in the model's memory bank.
        # Note: the memory bank is not cleared here, so repeated calls accumulate examples.
        for img, mask in zip(support_images, support_masks):
            self.model.add_few_shot_example(domain, class_name, img, mask)

        # Predict masks for the query image
        predictions = self.model(query_image, domain, [class_name], use_few_shot=True)

        if class_name in predictions:
            pred_mask = predictions[class_name].to(query_mask.device)
            loss = self.criterion(pred_mask, query_mask.float())
        else:
            # No confident prediction for this class; contribute zero loss
            loss = torch.tensor(0.0, device=self.model.device, requires_grad=True)

        # SAM 2 masks are produced through NumPy and carry no gradients, so only
        # backpropagate when the loss is actually differentiable.
        self.optimizer.zero_grad()
        if loss.requires_grad:
            loss.backward()
            self.optimizer.step()

        return loss.item()
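

# Minimal usage sketch (illustrative only): the checkpoint path below is a
# placeholder, and this assumes a CUDA-capable machine with the `clip` and
# SAM 2 packages imported above installed.
if __name__ == "__main__":
    model = SAM2FewShot(
        sam2_checkpoint="checkpoints/sam2_vit_h.pth",  # hypothetical path
        clip_model_name="ViT-B/32",
        device="cuda",
    )

    # Dummy RGB query image in [0, 1], shape [C, H, W]
    query = torch.rand(3, 512, 512)

    # Zero-shot segmentation driven purely by CLIP text prompts
    masks = model.segment(
        query,
        domain="satellite",
        class_names=["building", "road"],
        use_few_shot=False,
    )
    for name, mask in masks.items():
        print(name, tuple(mask.shape))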