File size: 4,831 Bytes
12fa055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
"""
Download SAM 2 Model Script

This script downloads the SAM 2 model checkpoints and sets up the environment
for few-shot and zero-shot segmentation experiments.
"""

import os
import sys
import requests
import zipfile
from pathlib import Path
import argparse
from tqdm import tqdm


def download_file(url: str, destination: str, chunk_size: int = 8192,
                  timeout: float = 30.0):
    """Stream *url* to *destination*, showing a tqdm progress bar.

    The body is first written to ``<destination>.part`` and only renamed onto
    *destination* after the whole response has been read. A failed or
    interrupted download therefore never leaves a truncated file at
    *destination*, which callers that skip already-existing files would
    otherwise treat as complete.

    Args:
        url: HTTP(S) URL to fetch.
        destination: Local file path to write.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang the script indefinitely.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
        OSError: if the temporary file cannot be written or renamed.
    """
    tmp_path = f"{destination}.part"
    # Context-managed response so the connection is released on every path.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this, an HTML error page would be saved as the checkpoint.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        try:
            with open(tmp_path, 'wb') as file, tqdm(
                desc=os.path.basename(destination),
                # 0 means "unknown length"; None makes tqdm show a plain counter.
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    pbar.update(size)
        except BaseException:
            # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
            raise

    # Atomic on the same filesystem: destination is either absent or complete.
    os.replace(tmp_path, destination)


def setup_sam2_environment():
    """Create the working directories and fetch SAM 2 checkpoints.

    All paths are relative to the current working directory:

    * creates ``models/checkpoints``, ``data`` and ``results``;
    * downloads each checkpoint in ``sam2_urls`` to
      ``models/checkpoints/sam2_<name>.pth`` unless that file already exists;
    * creates a ``sam2_checkpoint`` symlink pointing at the ``vit_h`` file.

    Download and symlink failures are printed, not raised, so the rest of the
    setup can continue.
    """
    print("Setting up SAM 2 environment...")

    # Create directories
    os.makedirs("models/checkpoints", exist_ok=True)
    os.makedirs("data", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    # SAM 2 model URLs (these are example URLs - replace with actual SAM 2 URLs)
    # NOTE(review): per the comment above these are placeholders; confirm the
    # real checkpoint URLs and file names against the SAM 2 repository.
    sam2_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_h.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_l.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_b.pth"
    }

    # Download SAM 2 checkpoints
    for model_name, url in sam2_urls.items():
        checkpoint_path = f"models/checkpoints/sam2_{model_name}.pth"

        if os.path.exists(checkpoint_path):
            print(f"SAM 2 {model_name} checkpoint already exists")
            continue

        print(f"Downloading SAM 2 {model_name} checkpoint...")
        try:
            download_file(url, checkpoint_path)
        except Exception as e:
            # A truncated file would pass the exists() check above on the next
            # run and be treated as a complete checkpoint, so remove it.
            try:
                os.remove(checkpoint_path)
            except OSError:
                pass  # nothing was written, or it cannot be removed; keep going
            print(f"Failed to download {model_name} checkpoint: {e}")
            print("Please download manually from the SAM 2 repository")
        else:
            print(f"Successfully downloaded {model_name} checkpoint")

    # Create symbolic links for easier access
    # lexists(), not exists(): exists() is False for a dangling symlink, and
    # os.symlink() would then fail with FileExistsError.
    if not os.path.lexists("sam2_checkpoint"):
        try:
            os.symlink("models/checkpoints/sam2_vit_h.pth", "sam2_checkpoint")
            print("Created symbolic link: sam2_checkpoint -> models/checkpoints/sam2_vit_h.pth")
        except (OSError, NotImplementedError):
            # Typically missing symlink privileges on Windows.
            print("Could not create symbolic link (this is normal on Windows)")


def install_dependencies():
    """Install the experiment's Python dependencies with pip.

    pip is run through the interpreter executing this script
    (``sys.executable -m pip``), so packages land in this same environment
    rather than in whichever ``pip`` happens to be first on PATH. The steps
    run in this order:

    1. ``requirements.txt`` (resolved relative to the current directory);
    2. SAM 2 from its GitHub repository;
    3. OpenAI CLIP from its GitHub repository.

    A failed step is reported and the remaining steps still run, matching the
    best-effort style of the rest of the setup.

    Returns:
        True if every pip invocation exited with status 0, otherwise False.
    """
    import subprocess

    def _pip_install(*pip_args):
        # Argument list, no shell: nothing is interpreted by a shell.
        cmd = [sys.executable, "-m", "pip", "install", *pip_args]
        returncode = subprocess.run(cmd).returncode
        if returncode != 0:
            print(f"Warning: command failed with exit code {returncode}: {' '.join(cmd)}")
        return returncode == 0

    print("Installing dependencies...")

    results = []

    # Install from requirements.txt
    results.append(_pip_install("-r", "requirements.txt"))

    # Install SAM 2 specifically
    print("Installing SAM 2...")
    results.append(_pip_install("git+https://github.com/facebookresearch/segment-anything-2.git"))

    # Install CLIP
    print("Installing CLIP...")
    results.append(_pip_install("git+https://github.com/openai/CLIP.git"))

    return all(results)


def create_demo_data():
    """Create empty ``images``/``masks`` folders for each demo dataset.

    Makes ``data/<domain>_demo/images`` and ``data/<domain>_demo/masks`` for
    the satellite, fashion and robotics domains, relative to the current
    working directory. Directories that already exist are left alone. No
    files are written here.
    """
    print("Creating demo data...")

    subfolders = ("images", "masks")
    for domain in ("satellite", "fashion", "robotics"):
        root = f"data/{domain}_demo"
        for sub in subfolders:
            os.makedirs(f"{root}/{sub}", exist_ok=True)

    print("Demo data directories created. Run experiments to generate dummy data.")


def main():
    """Command-line entry point: parse flags and run the setup steps.

    Flags:
        --skip-download  do not call ``setup_sam2_environment()``
        --skip-install   do not call ``install_dependencies()``
        --demo-only      only create the demo directories and exit; the other
                         two flags are ignored and no summary is printed

    Otherwise the steps run in this order: install, download, demo
    directories. A "next steps" summary is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Set up SAM 2 environment")
    flag_help = (
        ("--skip-download", "Skip downloading SAM 2 checkpoints"),
        ("--skip-install", "Skip installing dependencies"),
        ("--demo-only", "Only create demo data directories"),
    )
    for flag, text in flag_help:
        parser.add_argument(flag, action="store_true", help=text)
    args = parser.parse_args()

    if args.demo_only:
        # Scaffolding only: no installs, no downloads, no summary.
        create_demo_data()
        return

    if not args.skip_install:
        install_dependencies()
    if not args.skip_download:
        setup_sam2_environment()
    create_demo_data()

    summary = (
        "\nSetup complete!",
        "\nNext steps:",
        "1. Run few-shot satellite experiment:",
        "   python experiments/few_shot_satellite.py --sam2_checkpoint sam2_checkpoint --data_dir data/satellite_demo",
        "\n2. Run zero-shot fashion experiment:",
        "   python experiments/zero_shot_fashion.py --sam2_checkpoint sam2_checkpoint --data_dir data/fashion_demo",
        "\n3. Check the results/ directory for experiment outputs",
    )
    for line in summary:
        print(line)


if __name__ == "__main__":
    main()