Commit
·
11eb53a
1
Parent(s):
cad241c
Fix import error
Browse files
app.py
CHANGED
|
@@ -39,14 +39,22 @@ from efficient_sam.build_efficient_sam import build_efficient_sam_vits
|
|
| 39 |
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 40 |
MODEL_TYPE = "vit_h"
|
| 41 |
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Load SAM model
|
| 44 |
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
|
| 45 |
mask_generator_sam = SamAutomaticMaskGenerator(sam)
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
efficient_sam_vits_model = build_efficient_sam_vits()
|
| 51 |
|
| 52 |
from segment_anything.utils.amg import (
|
|
@@ -174,3 +182,4 @@ interface = gr.Interface(
|
|
| 174 |
interface.launch(debug=True)
|
| 175 |
|
| 176 |
|
|
|
|
|
|
| 39 |
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 40 |
MODEL_TYPE = "vit_h"
|
| 41 |
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
|
| 42 |
+
EFF_SAM_ZIP_PATH = "weights/efficient_sam_vits.pt.zip"
|
| 43 |
+
EFF_SAM_EXTRACT_DIR = "weights/efficient_sam_vits"
|
| 44 |
|
| 45 |
# Load SAM model
|
| 46 |
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
|
| 47 |
mask_generator_sam = SamAutomaticMaskGenerator(sam)
|
| 48 |
|
| 49 |
+
# Download EfficientSAM weights if not present
|
| 50 |
+
if not os.path.exists(EFF_SAM_ZIP_PATH):
|
| 51 |
+
download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", EFF_SAM_ZIP_PATH)
|
| 52 |
+
|
| 53 |
+
# Extract EfficientSAM weights if not already extracted
|
| 54 |
+
if not os.path.exists(EFF_SAM_EXTRACT_DIR):
|
| 55 |
+
with zipfile.ZipFile(EFF_SAM_ZIP_PATH, 'r') as zip_ref:
|
| 56 |
+
zip_ref.extractall(EFF_SAM_EXTRACT_DIR)
|
| 57 |
+
|
| 58 |
efficient_sam_vits_model = build_efficient_sam_vits()
|
| 59 |
|
| 60 |
from segment_anything.utils.amg import (
|
|
|
|
| 182 |
interface.launch(debug=True)
|
| 183 |
|
| 184 |
|
| 185 |
+
|