NandiniLokeshReddy commited on
Commit
236b7bb
·
verified ·
1 Parent(s): b6243f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -1,16 +1,14 @@
1
-
2
- import matplotlib.pyplot as plt
3
- import numpy as np
 
 
 
4
  import torch
 
5
  from torchvision.transforms import ToTensor
6
  from PIL import Image
7
- import io
8
  import cv2
9
- import gradio as gr
10
- import os
11
- import requests
12
- import zipfile
13
- import subprocess
14
 
15
  # Ensure the necessary model files are available
16
  def download_file(url, destination):
@@ -18,12 +16,30 @@ def download_file(url, destination):
18
  with open(destination, 'wb') as f:
19
  f.write(response.content)
20
 
21
- # Install the necessary packages
22
- subprocess.run(["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
23
- subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"])
 
 
 
 
24
 
25
- # Change directory to EfficientSAM
26
- os.chdir("EfficientSAM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  from segment_anything.utils.amg import (
29
  batched_mask_to_box,
@@ -47,6 +63,7 @@ def process_small_region(rles):
47
  unchanged = unchanged and not changed
48
  new_masks.append(torch.as_tensor(mask).unsqueeze(0))
49
  scores.append(float(unchanged))
 
50
  masks = torch.cat(new_masks, dim=0)
51
  boxes = batched_mask_to_box(masks)
52
  keep_by_nms = batched_nms(
@@ -137,18 +154,6 @@ def process_image(image):
137
 
138
  return [image, sam_annotated_image, efficient_sam_annotated_image]
139
 
140
- # Download EfficientSAM model
141
- if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
142
- download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", "weights/efficient_sam_vits.pt.zip")
143
-
144
- # Extract EfficientSAM model
145
- with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
146
- zip_ref.extractall("weights")
147
-
148
- from efficient_sam.build_efficient_sam import build_efficient_sam_vits
149
- efficient_sam_vits_model = build_efficient_sam_vits()
150
- efficient_sam_vits_model.eval()
151
-
152
  # Gradio interface
153
  interface = gr.Interface(
154
  fn=process_image,
@@ -161,3 +166,4 @@ interface = gr.Interface(
161
  interface.launch(debug=True)
162
 
163
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ import requests
5
+ import zipfile
6
+ import gradio as gr
7
  import torch
8
+ import numpy as np
9
  from torchvision.transforms import ToTensor
10
  from PIL import Image
 
11
  import cv2
 
 
 
 
 
12
 
13
  # Ensure the necessary model files are available
14
  def download_file(url, destination):
 
16
  with open(destination, 'wb') as f:
17
  f.write(response.content)
18
 
19
+ # Download SAM model
20
+ if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
21
+ os.makedirs("weights", exist_ok=True)
22
+ download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "weights/sam_vit_h_4b8939.pth")
23
+
24
+ # Add EfficientSAM to Python path
25
+ sys.path.append(os.path.abspath("EfficientSAM-main"))
26
 
27
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
28
+ from efficient_sam.build_efficient_sam_vits import build_efficient_sam_vits
29
+
30
+ # Constants
31
+ DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
32
+ MODEL_TYPE = "vit_h"
33
+ CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
34
+
35
+ # Load SAM model
36
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
37
+ mask_generator_sam = SamAutomaticMaskGenerator(sam)
38
+
39
+ # Load EfficientSAM model
40
+ with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
41
+ zip_ref.extractall("weights")
42
+ efficient_sam_vits_model = build_efficient_sam_vits()
43
 
44
  from segment_anything.utils.amg import (
45
  batched_mask_to_box,
 
63
  unchanged = unchanged and not changed
64
  new_masks.append(torch.as_tensor(mask).unsqueeze(0))
65
  scores.append(float(unchanged))
66
+
67
  masks = torch.cat(new_masks, dim=0)
68
  boxes = batched_mask_to_box(masks)
69
  keep_by_nms = batched_nms(
 
154
 
155
  return [image, sam_annotated_image, efficient_sam_annotated_image]
156
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  # Gradio interface
158
  interface = gr.Interface(
159
  fn=process_image,
 
166
  interface.launch(debug=True)
167
 
168
 
169
+