NandiniLokeshReddy commited on
Commit
11eb53a
·
1 Parent(s): cad241c

Fix import error

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -39,14 +39,22 @@ from efficient_sam.build_efficient_sam import build_efficient_sam_vits
39
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
40
  MODEL_TYPE = "vit_h"
41
  CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
 
 
42
 
43
  # Load SAM model
44
  sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
45
  mask_generator_sam = SamAutomaticMaskGenerator(sam)
46
 
47
- # Load EfficientSAM model
48
- with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
49
- zip_ref.extractall("weights")
 
 
 
 
 
 
50
  efficient_sam_vits_model = build_efficient_sam_vits()
51
 
52
  from segment_anything.utils.amg import (
@@ -174,3 +182,4 @@ interface = gr.Interface(
174
  interface.launch(debug=True)
175
 
176
 
 
 
39
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
40
  MODEL_TYPE = "vit_h"
41
  CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
42
+ EFF_SAM_ZIP_PATH = "weights/efficient_sam_vits.pt.zip"
43
+ EFF_SAM_EXTRACT_DIR = "weights/efficient_sam_vits"
44
 
45
  # Load SAM model
46
  sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
47
  mask_generator_sam = SamAutomaticMaskGenerator(sam)
48
 
49
+ # Download EfficientSAM weights if not present
50
+ if not os.path.exists(EFF_SAM_ZIP_PATH):
51
+ download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", EFF_SAM_ZIP_PATH)
52
+
53
+ # Extract EfficientSAM weights if not already extracted
54
+ if not os.path.exists(EFF_SAM_EXTRACT_DIR):
55
+ with zipfile.ZipFile(EFF_SAM_ZIP_PATH, 'r') as zip_ref:
56
+ zip_ref.extractall(EFF_SAM_EXTRACT_DIR)
57
+
58
  efficient_sam_vits_model = build_efficient_sam_vits()
59
 
60
  from segment_anything.utils.amg import (
 
182
  interface.launch(debug=True)
183
 
184
 
185
+