# SAM 2 Few-Shot and Zero-Shot Segmentation Analysis

This notebook provides comprehensive analysis and experimentation with SAM 2 for few-shot and zero-shot segmentation across multiple domains.

## Overview

This research project explores the capabilities of SAM 2 (Segment Anything Model 2) for:
- **Few-shot learning**: Learning from a small number of examples
- **Zero-shot learning**: Performing segmentation without prior examples
- **Domain adaptation**: Applying SAM 2 to satellite imagery, fashion, and robotics

## Setup and Imports

In [None]:
# Install required packages if not already installed
!pip install -q torch torchvision torchaudio
!pip install -q transformers
!pip install -q opencv-python
!pip install -q matplotlib seaborn
!pip install -q numpy pandas
!pip install -q scikit-learn
!pip install -q pillow
!pip install -q tqdm

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set up plotting.
# Matplotlib >= 3.6 exposes the seaborn look under the name 'seaborn-v0_8';
# older releases only know 'seaborn'. Pick whichever this installation offers
# so the cell does not fail on an older Matplotlib.
plot_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(plot_style)
sns.set_palette("husl")

# Seed the random number generators so a Restart & Run All gives repeatable
# results for the stochastic steps further down (e.g. fine-tuning).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Report the runtime environment so saved outputs record what they ran on.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

## Model Loading and Setup

In [None]:
from models.sam2_fewshot import SAM2FewShot
from models.sam2_zeroshot import SAM2ZeroShot
from utils.data_loader import DataLoader
from utils.metrics import SegmentationMetrics
from utils.visualization import VisualizationUtils

# Initialize models.
# Both wrappers are built from the same checkpoint id on the same device, so
# each value is defined once here instead of being repeated per constructor.
# NOTE(review): confirm "facebook/sam2-base" is an identifier the project
# wrappers can actually resolve.
MODEL_NAME = "facebook/sam2-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SAM 2 models...")

# Few-shot model
few_shot_model = SAM2FewShot(model_name=MODEL_NAME, device=DEVICE)

# Zero-shot model
zero_shot_model = SAM2ZeroShot(model_name=MODEL_NAME, device=DEVICE)

print("Models loaded successfully!")

## Data Loading and Preprocessing

In [None]:
# Initialize data loader.
# `DataLoader` here is the project-local class imported from
# `utils.data_loader` in the model-setup cell, not `torch.utils.data.DataLoader`.
data_loader = DataLoader()

# Load sample datasets (you'll need to provide your own data).
# The three loader calls below are left commented out as templates; until one
# of them is uncommented and pointed at a real path, this cell loads nothing.
print("Loading sample data...")

# Example: Load satellite imagery
# satellite_data = data_loader.load_satellite_data("path/to/satellite/data")

# Example: Load fashion data
# fashion_data = data_loader.load_fashion_data("path/to/fashion/data")

# Example: Load robotics data
# robotics_data = data_loader.load_robotics_data("path/to/robotics/data")

# NOTE(review): this message is printed unconditionally, even while every
# loader call above is still commented out.
print("Data loading complete!")

## Few-Shot Learning Experiments

In [None]:
def run_few_shot_experiment(model, support_images, support_masks, query_images,
                            k_shots=(1, 3, 5), query_masks=None,
                            metrics_calculator=None):
    """Run few-shot experiments with different numbers of support examples.

    For every ``k`` in ``k_shots`` the model is fine-tuned on the first ``k``
    support image/mask pairs, asked for a mask for each query image, and the
    resulting predictions are scored.

    Args:
        model: Object exposing ``fine_tune(images, masks, epochs=...)`` and
            ``predict(image)``.
        support_images: Sequence of support images; sliced as ``[:k]``.
        support_masks: Sequence of masks, sliced the same way as
            ``support_images``.
        query_images: Sequence of images to predict on.
        k_shots: Iterable of support-set sizes to try.
        query_masks: Ground-truth masks aligned with ``query_images``. When
            omitted, the legacy behaviour is kept: predictions are scored
            against ``query_images`` themselves, which compares masks with
            input images and is rarely meaningful.
        metrics_calculator: Object exposing
            ``calculate_metrics(predictions, targets)``. Defaults to a fresh
            ``SegmentationMetrics()``.

    Returns:
        dict mapping each ``k`` to the value returned by
        ``calculate_metrics`` for that run.

    Raises:
        ValueError: If ``query_masks`` is given and its length differs from
            that of ``query_images``.
    """
    if query_masks is not None and len(query_masks) != len(query_images):
        raise ValueError(
            f"query_masks has {len(query_masks)} items but query_images has "
            f"{len(query_images)}; they must be aligned one-to-one"
        )

    if metrics_calculator is None:
        metrics_calculator = SegmentationMetrics()

    if query_masks is None:
        # print() rather than warnings.warn(): this notebook silences all
        # warnings at import time, so a warning would never be seen.
        print("Warning: query_masks not provided; scoring predictions against "
              "query_images (legacy behaviour). Pass query_masks to get "
              "meaningful metrics.")
        targets = query_images
    else:
        targets = query_masks

    results = {}

    for k in k_shots:
        print(f"Running {k}-shot experiment...")

        # Select k support examples. Slicing silently yields fewer than k
        # items when k exceeds the number of available support examples.
        support_subset = support_images[:k]
        mask_subset = support_masks[:k]

        # Fine-tune model on support set.
        # NOTE(review): the same model object is fine-tuned for every k; unless
        # fine_tune() resets the weights, later runs start from the weights of
        # earlier runs. Confirm against the model implementation.
        model.fine_tune(support_subset, mask_subset, epochs=10)

        # Evaluate on query set
        predictions = [model.predict(query_img) for query_img in query_images]

        # Score predictions against the ground-truth targets
        results[k] = metrics_calculator.calculate_metrics(predictions, targets)

    return results

# Example usage (uncomment when you have data)
# few_shot_results = run_few_shot_experiment(
#     few_shot_model,
#     support_images,
#     support_masks,
#     query_images,
#     query_masks=query_masks,
# )

## Zero-Shot Learning Experiments

In [None]:
def run_zero_shot_experiment(model, test_images, prompts, test_masks=None,
                             metrics_calculator=None):
    """Run zero-shot experiments, one per entry in ``prompts``.

    For every ``(prompt_type, prompt)`` pair the model predicts a mask for
    each test image using that prompt, and the predictions are scored.

    Args:
        model: Object exposing ``predict_with_prompt(image, prompt)``.
        test_images: Sequence of images to predict on.
        prompts: Mapping of prompt label to the prompt passed to the model.
        test_masks: Ground-truth masks aligned with ``test_images``. When
            omitted, the legacy behaviour is kept: predictions are scored
            against ``test_images`` themselves, which compares masks with
            input images and is rarely meaningful.
        metrics_calculator: Object exposing
            ``calculate_metrics(predictions, targets)``. Defaults to a fresh
            ``SegmentationMetrics()``.

    Returns:
        dict mapping each prompt label to the value returned by
        ``calculate_metrics`` for that prompt.

    Raises:
        ValueError: If ``test_masks`` is given and its length differs from
            that of ``test_images``.
    """
    if test_masks is not None and len(test_masks) != len(test_images):
        raise ValueError(
            f"test_masks has {len(test_masks)} items but test_images has "
            f"{len(test_images)}; they must be aligned one-to-one"
        )

    if metrics_calculator is None:
        metrics_calculator = SegmentationMetrics()

    if test_masks is None:
        # print() rather than warnings.warn(): this notebook silences all
        # warnings at import time, so a warning would never be seen.
        print("Warning: test_masks not provided; scoring predictions against "
              "test_images (legacy behaviour). Pass test_masks to get "
              "meaningful metrics.")
        targets = test_images
    else:
        targets = test_masks

    results = {}

    for prompt_type, prompt in prompts.items():
        print(f"Running zero-shot experiment with prompt: {prompt_type}")

        predictions = [model.predict_with_prompt(img, prompt) for img in test_images]

        # Score predictions against the ground-truth targets
        results[prompt_type] = metrics_calculator.calculate_metrics(predictions, targets)

    return results

# Per-domain prompt sets: each maps a short label (used as the result key)
# to the prompt text handed to the zero-shot model.

# Satellite imagery
satellite_prompts = dict(
    buildings="segment all buildings in the image",
    roads="identify and segment road networks",
    vegetation="segment areas with vegetation",
)

# Fashion
fashion_prompts = dict(
    clothing="segment all clothing items",
    accessories="identify fashion accessories",
    person="segment the person in the image",
)

# Robotics
robotics_prompts = dict(
    objects="segment all objects in the scene",
    graspable="identify graspable objects",
    obstacles="segment obstacles to avoid",
)

# Example usage (uncomment when you have data)
# zero_shot_results = run_zero_shot_experiment(
# zero_shot_model, 
# test_images, 
# satellite_prompts
# )

## Visualization and Analysis

In [None]:
def visualize_results(images, predictions, ground_truth=None, title="Segmentation Results"):
    """Show each image next to its predicted mask (and ground truth, if given).

    The figure has one column per image. Row 0 shows the image, row 1 the
    image with the predicted mask overlaid, and, when ``ground_truth`` is
    supplied, row 2 the image with the ground-truth mask overlaid.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        predictions: Sequence of masks paired with ``images`` by position.
        ground_truth: Optional sequence of ground-truth masks paired with
            ``images`` by position.
        title: Figure-level title.

    Raises:
        ValueError: If ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("visualize_results needs at least one image to plot")

    n_rows = 2 if ground_truth is None else 3

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # image; without it a one-column grid comes back 1-D and axes[row, col]
    # indexing fails.
    fig, axes = plt.subplots(
        n_rows, n_images, figsize=(4 * n_images, 4 * n_rows), squeeze=False
    )

    # Hide every axis frame up front so columns without a matching
    # prediction or ground-truth mask stay blank.
    for ax in axes.flat:
        ax.axis('off')

    for i, (img, pred) in enumerate(zip(images, predictions)):
        # Original image
        axes[0, i].imshow(img)
        axes[0, i].set_title(f"Original {i+1}")

        # Prediction overlaid on the image
        axes[1, i].imshow(img)
        axes[1, i].imshow(pred, alpha=0.5, cmap='jet')
        axes[1, i].set_title(f"Prediction {i+1}")

    if ground_truth is not None:
        for i, (img, gt_mask) in enumerate(zip(images, ground_truth)):
            # Ground truth overlaid on the image
            axes[2, i].imshow(img)
            axes[2, i].imshow(gt_mask, alpha=0.5, cmap='jet')
            axes[2, i].set_title(f"Ground Truth {i+1}")

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

def plot_metrics_comparison(few_shot_results, zero_shot_results):
    """Plot few-shot and zero-shot metrics side by side and return the table.

    Args:
        few_shot_results: Mapping of ``k`` to a metrics dict with the keys
            ``'iou'``, ``'dice'``, ``'precision'`` and ``'recall'``.
        zero_shot_results: Mapping of prompt label to a metrics dict with the
            same keys.

    Returns:
        pandas.DataFrame with one row per method and the columns
        ``Method``, ``IoU``, ``Dice``, ``Precision``, ``Recall``.
    """

    def _as_row(label, scores):
        # One table row: method label plus the four plotted metrics.
        return {
            'Method': label,
            'IoU': scores['iou'],
            'Dice': scores['dice'],
            'Precision': scores['precision'],
            'Recall': scores['recall']
        }

    # Few-shot rows first, then zero-shot rows, in the order they were given.
    rows = [
        _as_row(f'{k}-shot', scores)
        for k, scores in few_shot_results.items()
    ]
    rows.extend(
        _as_row(f'Zero-shot ({prompt_type})', scores)
        for prompt_type, scores in zero_shot_results.items()
    )

    df = pd.DataFrame(rows)

    # 2x2 grid, one bar chart per metric; axes.flat walks it row by row.
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, metric in zip(axes.flat, ['IoU', 'Dice', 'Precision', 'Recall']):
        sns.barplot(data=df, x='Method', y=metric, ax=ax)
        ax.set_title(f'{metric} Comparison')
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    return df

## Conclusion and Next Steps

In [None]:
# Closing summary, emitted as one print call with one entry per output line.
# NOTE(review): the findings below are fixed text, not computed from any
# result produced in this notebook; confirm they match the actual experiments.
print(
    "SAM 2 Segmentation Analysis Complete!",
    "\nKey Findings:",
    "• SAM 2 demonstrates excellent few-shot learning capabilities",
    "• Zero-shot performance is domain-dependent",
    "• Prompt engineering is crucial for zero-shot success",
    "• Few-shot learning significantly improves performance",
    "• Cross-domain generalization shows promising results",
    "\nNext Steps:",
    "• Experiment with larger support sets",
    "• Test on more diverse domains",
    "• Optimize prompt engineering strategies",
    "• Explore ensemble methods",
    "• Investigate real-time applications",
    sep="\n",
)