NandiniLokeshReddy commited on
Commit
6c5760f
·
1 Parent(s): 11eb53a
Files changed (1) hide show
  1. app.py +23 -40
app.py CHANGED
@@ -1,15 +1,15 @@
1
 
2
- import os
3
- import subprocess
4
- import sys
5
- import requests
6
- import zipfile
7
- import gradio as gr
8
- import torch
9
  import numpy as np
 
10
  from torchvision.transforms import ToTensor
11
  from PIL import Image
 
12
  import cv2
 
 
 
 
13
 
14
  # Ensure the necessary model files are available
15
  def download_file(url, destination):
@@ -24,38 +24,12 @@ if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
24
 
25
  # Clone EfficientSAM repository if not already cloned
26
  if not os.path.exists("EfficientSAM"):
27
- subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"])
28
-
29
- # Add EfficientSAM to Python path
30
- sys.path.append(os.path.abspath("EfficientSAM"))
31
-
32
- # Install dependencies
33
- subprocess.run(["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
34
-
35
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
36
- from efficient_sam.build_efficient_sam import build_efficient_sam_vits
37
-
38
- # Constants
39
- DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
40
- MODEL_TYPE = "vit_h"
41
- CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
42
- EFF_SAM_ZIP_PATH = "weights/efficient_sam_vits.pt.zip"
43
- EFF_SAM_EXTRACT_DIR = "weights/efficient_sam_vits"
44
 
45
- # Load SAM model
46
- sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
47
- mask_generator_sam = SamAutomaticMaskGenerator(sam)
48
 
49
- # Download EfficientSAM weights if not present
50
- if not os.path.exists(EFF_SAM_ZIP_PATH):
51
- download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", EFF_SAM_ZIP_PATH)
52
-
53
- # Extract EfficientSAM weights if not already extracted
54
- if not os.path.exists(EFF_SAM_EXTRACT_DIR):
55
- with zipfile.ZipFile(EFF_SAM_ZIP_PATH, 'r') as zip_ref:
56
- zip_ref.extractall(EFF_SAM_EXTRACT_DIR)
57
-
58
- efficient_sam_vits_model = build_efficient_sam_vits()
59
 
60
  from segment_anything.utils.amg import (
61
  batched_mask_to_box,
@@ -79,7 +53,6 @@ def process_small_region(rles):
79
  unchanged = unchanged and not changed
80
  new_masks.append(torch.as_tensor(mask).unsqueeze(0))
81
  scores.append(float(unchanged))
82
-
83
  masks = torch.cat(new_masks, dim=0)
84
  boxes = batched_mask_to_box(masks)
85
  keep_by_nms = batched_nms(
@@ -170,6 +143,18 @@ def process_image(image):
170
 
171
  return [image, sam_annotated_image, efficient_sam_annotated_image]
172
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  # Gradio interface
174
  interface = gr.Interface(
175
  fn=process_image,
@@ -181,5 +166,3 @@ interface = gr.Interface(
181
 
182
  interface.launch(debug=True)
183
 
184
-
185
-
 
1
 
2
+ import matplotlib.pyplot as plt
 
 
 
 
 
 
3
  import numpy as np
4
+ import torch
5
  from torchvision.transforms import ToTensor
6
  from PIL import Image
7
+ import io
8
  import cv2
9
+ import gradio as gr
10
+ import os
11
+ import requests
12
+ import zipfile
13
 
14
  # Ensure the necessary model files are available
15
  def download_file(url, destination):
 
24
 
25
  # Clone EfficientSAM repository if not already cloned
26
  if not os.path.exists("EfficientSAM"):
27
+ os.makedirs("EfficientSAM", exist_ok=True)
28
+ os.system("git clone https://github.com/yformer/EfficientSAM.git EfficientSAM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ os.chdir("EfficientSAM")
 
 
31
 
32
+ !pip install git+https://github.com/facebookresearch/segment-anything.git
 
 
 
 
 
 
 
 
 
33
 
34
  from segment_anything.utils.amg import (
35
  batched_mask_to_box,
 
53
  unchanged = unchanged and not changed
54
  new_masks.append(torch.as_tensor(mask).unsqueeze(0))
55
  scores.append(float(unchanged))
 
56
  masks = torch.cat(new_masks, dim=0)
57
  boxes = batched_mask_to_box(masks)
58
  keep_by_nms = batched_nms(
 
143
 
144
  return [image, sam_annotated_image, efficient_sam_annotated_image]
145
 
146
+ # Download EfficientSAM model
147
+ if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
148
+ download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", "weights/efficient_sam_vits.pt.zip")
149
+
150
+ # Extract EfficientSAM model
151
+ with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
152
+ zip_ref.extractall("weights")
153
+
154
+ from efficient_sam.build_efficient_sam import build_efficient_sam_vits
155
+ efficient_sam_vits_model = build_efficient_sam_vits()
156
+ efficient_sam_vits_model.eval()
157
+
158
  # Gradio interface
159
  interface = gr.Interface(
160
  fn=process_image,
 
166
 
167
  interface.launch(debug=True)
168