Commit
·
6c5760f
1
Parent(s):
11eb53a
Try1
Browse files
app.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1 |
|
2 |
-
import
|
3 |
-
import subprocess
|
4 |
-
import sys
|
5 |
-
import requests
|
6 |
-
import zipfile
|
7 |
-
import gradio as gr
|
8 |
-
import torch
|
9 |
import numpy as np
|
|
|
10 |
from torchvision.transforms import ToTensor
|
11 |
from PIL import Image
|
|
|
12 |
import cv2
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Ensure the necessary model files are available
|
15 |
def download_file(url, destination):
|
@@ -24,38 +24,12 @@ if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
|
|
24 |
|
25 |
# Clone EfficientSAM repository if not already cloned
|
26 |
if not os.path.exists("EfficientSAM"):
|
27 |
-
|
28 |
-
|
29 |
-
# Add EfficientSAM to Python path
|
30 |
-
sys.path.append(os.path.abspath("EfficientSAM"))
|
31 |
-
|
32 |
-
# Install dependencies
|
33 |
-
subprocess.run(["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"])
|
34 |
-
|
35 |
-
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
36 |
-
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
|
37 |
-
|
38 |
-
# Constants
|
39 |
-
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
40 |
-
MODEL_TYPE = "vit_h"
|
41 |
-
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
|
42 |
-
EFF_SAM_ZIP_PATH = "weights/efficient_sam_vits.pt.zip"
|
43 |
-
EFF_SAM_EXTRACT_DIR = "weights/efficient_sam_vits"
|
44 |
|
45 |
-
|
46 |
-
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
|
47 |
-
mask_generator_sam = SamAutomaticMaskGenerator(sam)
|
48 |
|
49 |
-
|
50 |
-
if not os.path.exists(EFF_SAM_ZIP_PATH):
|
51 |
-
download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", EFF_SAM_ZIP_PATH)
|
52 |
-
|
53 |
-
# Extract EfficientSAM weights if not already extracted
|
54 |
-
if not os.path.exists(EFF_SAM_EXTRACT_DIR):
|
55 |
-
with zipfile.ZipFile(EFF_SAM_ZIP_PATH, 'r') as zip_ref:
|
56 |
-
zip_ref.extractall(EFF_SAM_EXTRACT_DIR)
|
57 |
-
|
58 |
-
efficient_sam_vits_model = build_efficient_sam_vits()
|
59 |
|
60 |
from segment_anything.utils.amg import (
|
61 |
batched_mask_to_box,
|
@@ -79,7 +53,6 @@ def process_small_region(rles):
|
|
79 |
unchanged = unchanged and not changed
|
80 |
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
81 |
scores.append(float(unchanged))
|
82 |
-
|
83 |
masks = torch.cat(new_masks, dim=0)
|
84 |
boxes = batched_mask_to_box(masks)
|
85 |
keep_by_nms = batched_nms(
|
@@ -170,6 +143,18 @@ def process_image(image):
|
|
170 |
|
171 |
return [image, sam_annotated_image, efficient_sam_annotated_image]
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
# Gradio interface
|
174 |
interface = gr.Interface(
|
175 |
fn=process_image,
|
@@ -181,5 +166,3 @@ interface = gr.Interface(
|
|
181 |
|
182 |
interface.launch(debug=True)
|
183 |
|
184 |
-
|
185 |
-
|
|
|
1 |
|
2 |
+
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
+
import torch
|
5 |
from torchvision.transforms import ToTensor
|
6 |
from PIL import Image
|
7 |
+
import io
|
8 |
import cv2
|
9 |
+
import gradio as gr
|
10 |
+
import os
|
11 |
+
import requests
|
12 |
+
import zipfile
|
13 |
|
14 |
# Ensure the necessary model files are available
|
15 |
def download_file(url, destination):
|
|
|
24 |
|
25 |
# Clone EfficientSAM repository if not already cloned
|
26 |
if not os.path.exists("EfficientSAM"):
|
27 |
+
os.makedirs("EfficientSAM", exist_ok=True)
|
28 |
+
os.system("git clone https://github.com/yformer/EfficientSAM.git EfficientSAM")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
os.chdir("EfficientSAM")
|
|
|
|
|
31 |
|
32 |
+
!pip install git+https://github.com/facebookresearch/segment-anything.git
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
from segment_anything.utils.amg import (
|
35 |
batched_mask_to_box,
|
|
|
53 |
unchanged = unchanged and not changed
|
54 |
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
55 |
scores.append(float(unchanged))
|
|
|
56 |
masks = torch.cat(new_masks, dim=0)
|
57 |
boxes = batched_mask_to_box(masks)
|
58 |
keep_by_nms = batched_nms(
|
|
|
143 |
|
144 |
return [image, sam_annotated_image, efficient_sam_annotated_image]
|
145 |
|
146 |
+
# Download EfficientSAM model
|
147 |
+
if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
|
148 |
+
download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", "weights/efficient_sam_vits.pt.zip")
|
149 |
+
|
150 |
+
# Extract EfficientSAM model
|
151 |
+
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
|
152 |
+
zip_ref.extractall("weights")
|
153 |
+
|
154 |
+
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
|
155 |
+
efficient_sam_vits_model = build_efficient_sam_vits()
|
156 |
+
efficient_sam_vits_model.eval()
|
157 |
+
|
158 |
# Gradio interface
|
159 |
interface = gr.Interface(
|
160 |
fn=process_image,
|
|
|
166 |
|
167 |
interface.launch(debug=True)
|
168 |
|
|
|
|