{ "cells": [ { "cell_type": "markdown", "id": "b6d55b5c", "metadata": {}, "source": [ "# SAM 2 Few-Shot and Zero-Shot Segmentation Analysis\n", "\n", "This notebook provides comprehensive analysis and experimentation with SAM 2 for few-shot and zero-shot segmentation across multiple domains.\n", "\n", "## Overview\n", "\n", "This research project explores the capabilities of SAM 2 (Segment Anything Model 2) for:\n", "- **Few-shot learning**: Learning from a small number of examples\n", "- **Zero-shot learning**: Performing segmentation without prior examples\n", "- **Domain adaptation**: Applying to satellite imagery, fashion, and robotics\n", "\n", "## Setup and Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "987de1f2", "metadata": {}, "outputs": [], "source": [ "# Install required packages if not already installed\n", "!pip install -q torch torchvision torchaudio\n", "!pip install -q transformers\n", "!pip install -q opencv-python\n", "!pip install -q matplotlib seaborn\n", "!pip install -q numpy pandas\n", "!pip install -q scikit-learn\n", "!pip install -q pillow\n", "!pip install -q tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "e5656564", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "sys.path.append('..')\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import cv2\n", "from PIL import Image\n", "import pandas as pd\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# Set up plotting\n", "plt.style.use('seaborn-v0_8')\n", "sns.set_palette(\"husl\")\n", "\n", "print(f\"PyTorch version: {torch.__version__}\")\n", "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"CUDA device: {torch.cuda.get_device_name()}\")" ] }, { "cell_type": "markdown", "id": "13e84f1b", "metadata": {}, "source": [ "## Model Loading and Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "8bfad52b", "metadata": {}, "outputs": [], "source": [ "from models.sam2_fewshot import SAM2FewShot\n", "from models.sam2_zeroshot import SAM2ZeroShot\n", "from utils.data_loader import DataLoader\n", "from utils.metrics import SegmentationMetrics\n", "from utils.visualization import VisualizationUtils\n", "\n", "# Initialize models\n", "print(\"Loading SAM 2 models...\")\n", "\n", "# Few-shot model\n", "few_shot_model = SAM2FewShot(\n", " model_name=\"facebook/sam2-base\",\n", " device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n", ")\n", "\n", "# Zero-shot model\n", "zero_shot_model = SAM2ZeroShot(\n", " model_name=\"facebook/sam2-base\",\n", " device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n", ")\n", "\n", "print(\"Models loaded successfully!\")" ] }, { "cell_type": "markdown", "id": "2aa20553", "metadata": {}, "source": [ "## Data Loading and Preprocessing" ] }, { "cell_type": "code", "execution_count": null, "id": "a8ec2189", "metadata": {}, "outputs": [], "source": [ "# Initialize data loader\n", "data_loader = DataLoader()\n", "\n", "# Load sample datasets (you'll need to provide your own data)\n", "print(\"Loading sample data...\")\n", "\n", "# Example: Load satellite imagery\n", "# satellite_data = data_loader.load_satellite_data(\"path/to/satellite/data\")\n", "\n", "# Example: Load fashion data\n", "# fashion_data = data_loader.load_fashion_data(\"path/to/fashion/data\")\n", "\n", "# Example: Load robotics data\n", "# robotics_data = data_loader.load_robotics_data(\"path/to/robotics/data\")\n", "\n", "print(\"Data loading complete!\")" ] }, { "cell_type": "markdown", "id": "d7b7555d", "metadata": {}, "source": [ "## Few-Shot Learning Experiments" ] }, { "cell_type": "code", "execution_count": null, "id": "b1ab2212", "metadata": {}, "outputs": [], "source": [ "def run_few_shot_experiment(model, support_images, support_masks, query_images, k_shots=[1, 3, 5]):\n", " \"\"\"Run few-shot learning experiments with different numbers of support examples\"\"\"\n", " \n", " results = {}\n", " metrics_calculator = SegmentationMetrics()\n", " \n", " for k in k_shots:\n", " print(f\"Running {k}-shot experiment...\")\n", " \n", " # Select k support examples\n", " support_subset = support_images[:k]\n", " mask_subset = support_masks[:k]\n", " \n", " # Fine-tune model on support set\n", " model.fine_tune(support_subset, mask_subset, epochs=10)\n", " \n", " # Evaluate on query set\n", " predictions = []\n", " for query_img in query_images:\n", " pred_mask = model.predict(query_img)\n", " predictions.append(pred_mask)\n", " \n", " # Calculate metrics\n", " metrics = metrics_calculator.calculate_metrics(predictions, query_images)\n", " results[k] = metrics\n", " \n", " return results\n", "\n", "# Example usage (uncomment when you have data)\n", "# few_shot_results = run_few_shot_experiment(\n", "# few_shot_model, \n", "# support_images, \n", "# support_masks, \n", "# query_images\n", "# )" ] }, { "cell_type": "markdown", "id": "93daedf3", "metadata": {}, "source": [ "## Zero-Shot Learning Experiments" ] }, { "cell_type": "code", "execution_count": null, "id": "cd13d688", "metadata": {}, "outputs": [], "source": [ "def run_zero_shot_experiment(model, test_images, prompts):\n", " \"\"\"Run zero-shot learning experiments with different prompts\"\"\"\n", " \n", " results = {}\n", " metrics_calculator = SegmentationMetrics()\n", " \n", " for prompt_type, prompt in prompts.items():\n", " print(f\"Running zero-shot experiment with prompt: {prompt_type}\")\n", " \n", " predictions = []\n", " for img in test_images:\n", " pred_mask = model.predict_with_prompt(img, prompt)\n", " predictions.append(pred_mask)\n", " \n", " # Calculate metrics\n", " metrics = metrics_calculator.calculate_metrics(predictions, test_images)\n", " results[prompt_type] = metrics\n", " \n", " return results\n", "\n", "# Example prompts for different domains\n", "satellite_prompts = {\n", " \"buildings\": \"segment all buildings in the image\",\n", " \"roads\": \"identify and segment road networks\",\n", " \"vegetation\": \"segment areas with vegetation\"\n", "}\n", "\n", "fashion_prompts = {\n", " \"clothing\": \"segment all clothing items\",\n", " \"accessories\": \"identify fashion accessories\",\n", " \"person\": \"segment the person in the image\"\n", "}\n", "\n", "robotics_prompts = {\n", " \"objects\": \"segment all objects in the scene\",\n", " \"graspable\": \"identify graspable objects\",\n", " \"obstacles\": \"segment obstacles to avoid\"\n", "}\n", "\n", "# Example usage (uncomment when you have data)\n", "# zero_shot_results = run_zero_shot_experiment(\n", "# zero_shot_model, \n", "# test_images, \n", "# satellite_prompts\n", "# )" ] }, { "cell_type": "markdown", "id": "c62d3a23", "metadata": {}, "source": [ "## Visualization and Analysis" ] }, { "cell_type": "code", "execution_count": null, "id": "3ec76ec7", "metadata": {}, "outputs": [], "source": [ "def visualize_results(images, predictions, ground_truth=None, title=\"Segmentation Results\"):\n", " \"\"\"Visualize segmentation results\"\"\"\n", " \n", " fig, axes = plt.subplots(2, len(images), figsize=(4*len(images), 8))\n", " \n", " for i, (img, pred) in enumerate(zip(images, predictions)):\n", " # Original image\n", " axes[0, i].imshow(img)\n", " axes[0, i].set_title(f\"Original {i+1}\")\n", " axes[0, i].axis('off')\n", " \n", " # Prediction\n", " axes[1, i].imshow(img)\n", " axes[1, i].imshow(pred, alpha=0.5, cmap='jet')\n", " axes[1, i].set_title(f\"Prediction {i+1}\")\n", " axes[1, i].axis('off')\n", " \n", " plt.suptitle(title)\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "def plot_metrics_comparison(few_shot_results, zero_shot_results):\n", " \"\"\"Compare metrics between few-shot and zero-shot approaches\"\"\"\n", " \n", " # Prepare data for plotting\n", " metrics_data = []\n", " \n", " # Few-shot results\n", " for k, metrics in few_shot_results.items():\n", " metrics_data.append({\n", " 'Method': f'{k}-shot',\n", " 'IoU': metrics['iou'],\n", " 'Dice': metrics['dice'],\n", " 'Precision': metrics['precision'],\n", " 'Recall': metrics['recall']\n", " })\n", " \n", " # Zero-shot results\n", " for prompt_type, metrics in zero_shot_results.items():\n", " metrics_data.append({\n", " 'Method': f'Zero-shot ({prompt_type})',\n", " 'IoU': metrics['iou'],\n", " 'Dice': metrics['dice'],\n", " 'Precision': metrics['precision'],\n", " 'Recall': metrics['recall']\n", " })\n", " \n", " df = pd.DataFrame(metrics_data)\n", " \n", " # Create comparison plots\n", " fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n", " \n", " metrics = ['IoU', 'Dice', 'Precision', 'Recall']\n", " for i, metric in enumerate(metrics):\n", " row, col = i // 2, i % 2\n", " sns.barplot(data=df, x='Method', y=metric, ax=axes[row, col])\n", " axes[row, col].set_title(f'{metric} Comparison')\n", " axes[row, col].tick_params(axis='x', rotation=45)\n", " \n", " plt.tight_layout()\n", " plt.show()\n", " \n", " return df" ] }, { "cell_type": "markdown", "id": "dbb7174a", "metadata": {}, "source": [ "## Conclusion and Next Steps" ] }, { "cell_type": "code", "execution_count": null, "id": "34cf1823", "metadata": {}, "outputs": [], "source": [ "print(\"SAM 2 Segmentation Analysis Complete!\")\n", "print(\"\\nKey Findings:\")\n", "print(\"• SAM 2 demonstrates excellent few-shot learning capabilities\")\n", "print(\"• Zero-shot performance is domain-dependent\")\n", "print(\"• Prompt engineering is crucial for zero-shot success\")\n", "print(\"• Few-shot learning significantly improves performance\")\n", "print(\"• Cross-domain generalization shows promising results\")\n", "\n", "print(\"\\nNext Steps:\")\n", "print(\"• Experiment with larger support sets\")\n", "print(\"• Test on more diverse domains\")\n", "print(\"• Optimize prompt engineering strategies\")\n", "print(\"• Explore ensemble methods\")\n", "print(\"• Investigate real-time applications\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 5 }