"""
Download SAM 2 Model Script

This script downloads the SAM 2 model checkpoints and sets up the environment
for few-shot and zero-shot segmentation experiments.
"""
|
|
|
import argparse
import os
import subprocess
import sys
import zipfile
from pathlib import Path

import requests
from tqdm import tqdm
|
|
|
|
|
def download_file(url: str, destination: str, chunk_size: int = 8192):
    """Download ``url`` to ``destination`` while showing a progress bar.

    The response body is streamed in ``chunk_size``-byte pieces into a
    temporary ``<destination>.part`` file, which is moved onto ``destination``
    only after the whole transfer succeeded. A failed or interrupted download
    therefore never leaves a truncated file at ``destination``.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: Local path the downloaded file is written to.
        chunk_size: Number of bytes requested from the response per iteration.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the file cannot be written or moved into place.
    """
    partial_path = f"{destination}.part"
    try:
        # Without a timeout requests waits forever on a stalled server;
        # the tuple is (connect timeout, read timeout) in seconds.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving the error page as if
            # it were the requested file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as file, tqdm(
                desc=os.path.basename(destination),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    pbar.update(size)

        # Publish the file only once it is complete.
        os.replace(partial_path, destination)
    finally:
        # After a successful os.replace the partial file is gone, so this
        # only fires when the download did not finish.
        if os.path.exists(partial_path):
            os.remove(partial_path)
|
|
|
|
|
def setup_sam2_environment():
    """Create the working directories and download the SAM 2 checkpoints.

    Creates ``models/checkpoints``, ``data`` and ``results`` relative to the
    current working directory, downloads every checkpoint that is not already
    present, and finally tries to create a ``sam2_checkpoint`` symbolic link
    pointing at the ``vit_h`` checkpoint.

    Download and symlink failures are reported on stdout and do not abort the
    setup; the function returns ``None`` in every case.
    """
    print("Setting up SAM 2 environment...")

    for directory in ("models/checkpoints", "data", "results"):
        os.makedirs(directory, exist_ok=True)

    # NOTE(review): confirm these URLs against the checkpoint list published
    # in the SAM 2 repository before relying on them.
    sam2_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_h.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_l.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_b.pth",
    }

    for model_name, url in sam2_urls.items():
        checkpoint_path = f"models/checkpoints/sam2_{model_name}.pth"

        if os.path.exists(checkpoint_path):
            print(f"SAM 2 {model_name} checkpoint already exists")
            continue

        print(f"Downloading SAM 2 {model_name} checkpoint...")
        try:
            download_file(url, checkpoint_path)
        except Exception as e:
            print(f"Failed to download {model_name} checkpoint: {e}")
            print("Please download manually from the SAM 2 repository")
            # A failed transfer may leave a truncated file behind; remove it
            # so the next run retries instead of treating it as a valid
            # checkpoint. The path did not exist before this attempt.
            if os.path.exists(checkpoint_path):
                try:
                    os.remove(checkpoint_path)
                except OSError as cleanup_error:
                    print(f"Could not remove partial file {checkpoint_path}: {cleanup_error}")
        else:
            print(f"Successfully downloaded {model_name} checkpoint")

    # lexists() is also true for a dangling link, so a rerun does not try to
    # recreate a link that is already there and then report a bogus failure.
    if not os.path.lexists("sam2_checkpoint"):
        try:
            os.symlink("models/checkpoints/sam2_vit_h.pth", "sam2_checkpoint")
        except (OSError, NotImplementedError):
            # Only filesystem/platform errors are expected here; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            print("Could not create symbolic link (this is normal on Windows)")
        else:
            print("Created symbolic link: sam2_checkpoint -> models/checkpoints/sam2_vit_h.pth")
|
|
|
|
|
def install_dependencies():
    """Install the project requirements, SAM 2 and CLIP with pip.

    pip is invoked as ``<current interpreter> -m pip`` so the packages land in
    the environment that is running this script, and the arguments are passed
    as a list so no shell is involved. Installation stays best-effort: a
    failing step prints a warning and the remaining steps still run.
    """
    print("Installing dependencies...")

    # (banner printed before the step, arguments appended to "pip install")
    install_steps = [
        (None, ["-r", "requirements.txt"]),
        (
            "Installing SAM 2...",
            ["git+https://github.com/facebookresearch/segment-anything-2.git"],
        ),
        (
            "Installing CLIP...",
            ["git+https://github.com/openai/CLIP.git"],
        ),
    ]

    for banner, pip_args in install_steps:
        if banner is not None:
            print(banner)

        command = [sys.executable, "-m", "pip", "install", *pip_args]
        try:
            completed = subprocess.run(command, check=False)
        except OSError as e:
            # Raised when the interpreter itself cannot be launched.
            print(f"Warning: could not run pip install {' '.join(pip_args)}: {e}")
            continue

        if completed.returncode != 0:
            print(
                f"Warning: pip install {' '.join(pip_args)} "
                f"failed with exit code {completed.returncode}"
            )
|
|
|
|
|
def create_demo_data():
    """Create the empty ``images``/``masks`` folders for each demo dataset.

    The folders are created under ``data/`` relative to the current working
    directory; folders that already exist are left untouched.
    """
    print("Creating demo data...")

    demo_roots = (
        "data/satellite_demo",
        "data/fashion_demo",
        "data/robotics_demo",
    )
    for root in demo_roots:
        for leaf in ("images", "masks"):
            os.makedirs(f"{root}/{leaf}", exist_ok=True)

    print("Demo data directories created. Run experiments to generate dummy data.")
|
|
|
|
|
def main():
    """Parse the command-line flags and run the requested setup steps.

    ``--demo-only`` creates the demo directories and stops. Otherwise the
    dependencies are installed (unless ``--skip-install``), the checkpoints
    are downloaded (unless ``--skip-download``), the demo directories are
    created, and a short "next steps" message is printed.
    """
    parser = argparse.ArgumentParser(description="Set up SAM 2 environment")
    switches = (
        ("--skip-download", "Skip downloading SAM 2 checkpoints"),
        ("--skip-install", "Skip installing dependencies"),
        ("--demo-only", "Only create demo data directories"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    args = parser.parse_args()

    # Demo-only mode short-circuits every other step.
    if args.demo_only:
        create_demo_data()
        return

    if not args.skip_install:
        install_dependencies()
    if not args.skip_download:
        setup_sam2_environment()
    create_demo_data()

    closing_lines = (
        "\nSetup complete!",
        "\nNext steps:",
        "1. Run few-shot satellite experiment:",
        " python experiments/few_shot_satellite.py --sam2_checkpoint sam2_checkpoint --data_dir data/satellite_demo",
        "\n2. Run zero-shot fashion experiment:",
        " python experiments/zero_shot_fashion.py --sam2_checkpoint sam2_checkpoint --data_dir data/fashion_demo",
        "\n3. Check the results/ directory for experiment outputs",
    )
    for line in closing_lines:
        print(line)
|
|
|
|
|
# Run the setup workflow only when executed as a script, not when imported.
if __name__ == "__main__":
    main()