File size: 4,831 Bytes
12fa055 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
#!/usr/bin/env python3
"""
Download SAM 2 Model Script
This script downloads the SAM 2 model checkpoints and sets up the environment
for few-shot and zero-shot segmentation experiments.
"""
import argparse
import os
import subprocess
import sys
import zipfile
from pathlib import Path

import requests
from tqdm import tqdm
def download_file(url: str, destination: str, chunk_size: int = 8192,
                  timeout: float = 60.0) -> None:
    """Download ``url`` to ``destination`` while showing a tqdm progress bar.

    The body is streamed in ``chunk_size``-byte pieces into a temporary
    ``<destination>.part`` file. That file is renamed onto ``destination``
    only after the whole transfer succeeds. A failed or interrupted download
    therefore never leaves a truncated file at ``destination``, which callers
    would otherwise mistake for a finished download.

    Args:
        url: HTTP(S) URL to fetch.
        destination: Path of the file to create (or overwrite).
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot hang the script forever.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection or timeout errors.
        OSError: If the file cannot be written or renamed.
    """
    partial_path = f"{destination}.part"
    try:
        # Using the response as a context manager releases the pooled
        # connection even when the body is not fully consumed.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page would be written to disk
            # and reported by callers as a successful download.
            response.raise_for_status()
            # 0 when the server omits Content-Length; tqdm then shows an
            # open-ended counter instead of a percentage.
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                desc=os.path.basename(destination),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    pbar.update(size)
        # Atomically publish the completed file (overwrites any existing one).
        os.replace(partial_path, destination)
    except BaseException:
        # Also runs on KeyboardInterrupt: drop the incomplete temp file, then
        # propagate the original error unchanged.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
        raise
def setup_sam2_environment():
    """Create the working directories and fetch the SAM 2 checkpoints.

    Steps (all paths are relative to the current working directory):

    1. Ensure ``models/checkpoints``, ``data`` and ``results`` exist.
    2. For every entry in ``sam2_urls``, download it to
       ``models/checkpoints/sam2_<name>.pth`` unless that file already exists.
       A failed download is reported and skipped rather than aborting the run.
    3. Try to create a ``sam2_checkpoint`` symlink pointing at the ``vit_h``
       checkpoint, unless something named ``sam2_checkpoint`` is already there.
    """
    print("Setting up SAM 2 environment...")
    # Create directories
    os.makedirs("models/checkpoints", exist_ok=True)
    os.makedirs("data", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    # SAM 2 model URLs (these are example URLs - replace with actual SAM 2 URLs)
    # NOTE(review): as stated above these are placeholders; confirm them
    # against the official SAM 2 release before relying on the downloads.
    sam2_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_h.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_l.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_b.pth"
    }
    # Download SAM 2 checkpoints
    for model_name, url in sam2_urls.items():
        checkpoint_path = f"models/checkpoints/sam2_{model_name}.pth"
        if os.path.exists(checkpoint_path):
            print(f"SAM 2 {model_name} checkpoint already exists")
            continue
        print(f"Downloading SAM 2 {model_name} checkpoint...")
        try:
            download_file(url, checkpoint_path)
        except Exception as e:
            # Best-effort: report the failure and continue with the next model.
            print(f"Failed to download {model_name} checkpoint: {e}")
            print("Please download manually from the SAM 2 repository")
            # A truncated file would be reported as "already exists" on the
            # next run and never re-downloaded, so discard it now.
            if os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
        else:
            print(f"Successfully downloaded {model_name} checkpoint")
    # Create symbolic links for easier access
    link_name = "sam2_checkpoint"
    link_target = "models/checkpoints/sam2_vit_h.pth"
    # lexists (not exists): os.path.exists follows the link and returns False
    # for a dangling one, which previously led to a FileExistsError here.
    if not os.path.lexists(link_name):
        try:
            os.symlink(link_target, link_name)
            print(f"Created symbolic link: {link_name} -> {link_target}")
        except OSError as e:
            # Windows needs Developer Mode or an extra privilege for symlinks;
            # other OS errors are surfaced via the error text instead of hidden.
            print(f"Could not create symbolic link (this is normal on Windows): {e}")
def _pip_install(*pip_args):
    """Run ``<this python> -m pip install <pip_args>``; warn (don't raise) on failure.

    Returns the pip process exit code.
    """
    command = [sys.executable, "-m", "pip", "install", *pip_args]
    # Flush our own progress messages so they appear before pip's output.
    sys.stdout.flush()
    result = subprocess.run(command)
    if result.returncode != 0:
        print(f"Warning: '{' '.join(command)}' failed with exit code {result.returncode}")
    return result.returncode


def install_dependencies():
    """Install the Python packages this project needs.

    Installs ``requirements.txt`` (from the current directory), then SAM 2 and
    CLIP straight from their GitHub repositories.

    pip is invoked as ``sys.executable -m pip`` so packages are installed into
    the interpreter running this script, not whichever ``pip`` happens to be
    first on PATH. A failing step prints a warning and the remaining steps
    still run.
    """
    print("Installing dependencies...")
    # Install from requirements.txt
    _pip_install("-r", "requirements.txt")
    # Install SAM 2 specifically
    print("Installing SAM 2...")
    _pip_install("git+https://github.com/facebookresearch/segment-anything-2.git")
    # Install CLIP
    print("Installing CLIP...")
    _pip_install("git+https://github.com/openai/CLIP.git")
def create_demo_data():
    """Ensure the empty demo dataset folder layout exists under ``data/``.

    For each demo domain (satellite, fashion, robotics) this makes sure
    ``data/<domain>_demo/images`` and ``data/<domain>_demo/masks`` exist.
    Folders that already exist are left as they are. No files are written;
    per the closing message, the experiments generate the dummy data.
    """
    print("Creating demo data...")
    subfolders = ("images", "masks")
    for domain in ("satellite", "fashion", "robotics"):
        root = f"data/{domain}_demo"
        for sub in subfolders:
            os.makedirs(f"{root}/{sub}", exist_ok=True)
    print("Demo data directories created. Run experiments to generate dummy data.")
def main(argv=None):
    """Command-line entry point that prepares the SAM 2 project.

    With ``--demo-only`` it only creates the demo folders. Otherwise it:

    - installs dependencies (unless ``--skip-install``);
    - downloads checkpoints (unless ``--skip-download``);
    - creates the demo folders;
    - prints suggested next commands.

    Args:
        argv: Optional list of arguments to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps normal command-line
            behavior. Passing a list lets other code or tests drive the
            script without touching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Set up SAM 2 environment")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading SAM 2 checkpoints")
    parser.add_argument("--skip-install", action="store_true",
                        help="Skip installing dependencies")
    parser.add_argument("--demo-only", action="store_true",
                        help="Only create demo data directories")
    args = parser.parse_args(argv)
    if args.demo_only:
        create_demo_data()
        return
    if not args.skip_install:
        install_dependencies()
    if not args.skip_download:
        setup_sam2_environment()
    create_demo_data()
    print("\nSetup complete!")
    print("\nNext steps:")
    print("1. Run few-shot satellite experiment:")
    print(" python experiments/few_shot_satellite.py --sam2_checkpoint sam2_checkpoint --data_dir data/satellite_demo")
    print("\n2. Run zero-shot fashion experiment:")
    print(" python experiments/zero_shot_fashion.py --sam2_checkpoint sam2_checkpoint --data_dir data/fashion_demo")
    print("\n3. Check the results/ directory for experiment outputs")
# Run the setup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()