File size: 2,020 Bytes
6b8e3c4
 
 
 
 
 
 
 
 
 
 
 
9780d7b
1d564b4
6b8e3c4
 
 
 
9780d7b
 
 
6b8e3c4
9780d7b
6b8e3c4
9780d7b
 
 
 
 
 
 
 
 
 
 
62a1e0a
 
9780d7b
6b8e3c4
 
 
67eca52
6b8e3c4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import DPTImageProcessor, DPTForDepthEstimation
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import gradio as gr
import supervision as sv
import torch
import numpy as np
from PIL import Image
import requests

class DepthPredictor:
    """Monocular depth estimation with a DPT model from Hugging Face ``transformers``.

    The model weights and every input tensor are placed on ``self.device``
    so inference runs on the selected device (CUDA when available by default).
    """

    def __init__(self, model_name="Intel/dpt-large", device=None):
        """Load the DPT image processor and depth-estimation model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``
                for both the processor and the model.
            device: ``torch.device`` or device string to run on. When ``None``,
                CUDA is used if available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.feature_extractor = DPTImageProcessor.from_pretrained(model_name)
        self.model = DPTForDepthEstimation.from_pretrained(model_name)
        # Put the weights on the chosen device; predict() moves inputs to match.
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _spatial_size(image):
        """Return ``(height, width)`` of a PIL image or an ``(H, W[, C])`` NumPy array."""
        if isinstance(image, np.ndarray):
            return tuple(int(dim) for dim in image.shape[:2])
        # PIL's Image.size is (width, height); interpolate wants (height, width).
        width, height = image.size
        return (int(height), int(width))

    def predict(self, image):
        """Estimate a depth map for ``image`` and scale it to an 8-bit array.

        Args:
            image: A PIL image, or a NumPy array of shape ``(H, W[, C])``.

        Returns:
            ``np.ndarray`` of dtype ``uint8`` with shape ``(H, W)`` matching the
            input's spatial size. The raw prediction is scaled so its maximum
            maps to 255, and values below 0 are clipped to 0. When the
            prediction's maximum is not positive, an all-zero array is returned
            instead of dividing by zero.
        """
        target_size = self._spatial_size(image)

        # Prepare the image for the model and move tensors next to the weights.
        encoding = self.feature_extractor(image, return_tensors="pt")
        encoding = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in encoding.items()
        }

        with torch.no_grad():
            predicted_depth = self.model(**encoding).predicted_depth
            # predicted_depth is 3-D (batch, h, w): add a channel axis for the
            # 2-D resize back to the input resolution, then take the first item.
            # Indexing (rather than squeeze) keeps H or W even when either is 1.
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=target_size,
                mode="bicubic",
                align_corners=False,
            )[0, 0]

        output = prediction.cpu().numpy()
        max_value = float(np.max(output))
        if max_value <= 0:
            # Nothing meaningful to normalise; avoid 0/0 -> NaN in the cast below.
            return np.zeros(output.shape, dtype=np.uint8)
        scaled = output * 255 / max_value
        # Bicubic resampling can undershoot below zero; clip before casting so
        # negatives don't wrap around to bright values in uint8.
        return np.clip(scaled, 0, 255).astype("uint8")
    



class SegmentPredictor:
    """Automatic "segment everything" masks from SAM, drawn with supervision."""

    def __init__(self, model_type="vit_b", checkpoint="sam_vit_b_01ec64.pth", device=None):
        """Build a SAM model, its automatic mask generator, and a mask annotator.

        Args:
            model_type: Key into ``sam_model_registry``; must match the
                checkpoint's architecture.
            checkpoint: Path to the SAM weights file.
            device: ``torch.device`` or device string to run on. When ``None``,
                CUDA is used if available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        # Place SAM on the chosen device so mask generation can use the GPU.
        sam.to(device=self.device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)
        # Built once and reused for every prediction.
        # NOTE(review): newer supervision releases colour masks by class_id by
        # default, which SAM detections lack; if annotate() raises, pass
        # color_lookup=sv.ColorLookup.INDEX here — verify against installed version.
        self.mask_annotator = sv.MaskAnnotator()

    def predict(self, image):
        """Segment everything in ``image`` and return the mask overlay.

        Args:
            image: Image array handed to ``SamAutomaticMaskGenerator.generate``
                (presumably ``(H, W, 3)`` uint8 RGB — confirm against SAM docs).
                Array-likes such as PIL images are converted with ``np.asarray``.

        Returns:
            A one-element list holding the annotated image; masks are drawn
            on a copy, not on the array passed to the generator.
        """
        frame = np.asarray(image)
        sam_result = self.mask_generator.generate(frame)
        detections = sv.Detections.from_sam(sam_result=sam_result)
        annotated_image = self.mask_annotator.annotate(scene=frame.copy(), detections=detections)
        return [annotated_image]