from transformers import DPTImageProcessor, DPTForDepthEstimation
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import gradio as gr
import supervision as sv
import torch
import numpy as np
from PIL import Image
import requests

class DepthPredictor:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # note: the image processor is not a torch module, so it cannot be moved to a device
        self.processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        self.model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(self.device)
        self.model.eval()

    def predict(self, image):
        # `image` is expected to be a PIL.Image; the processor handles
        # resizing and normalization for the model
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # interpolate to original size
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )

        # scale the raw depth values to an 8-bit grayscale image for display
        output = prediction.squeeze().cpu().numpy()
        formatted = (output * 255 / np.max(output)).astype("uint8")
        depth = Image.fromarray(formatted)
        return depth
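
# Example usage (a minimal sketch; the URL below is the COCO sample image used
# throughout the transformers docs, not part of this app):
#
#   url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#   image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
#   depth = DepthPredictor().predict(image)
#   depth.save("depth.png")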

class SamInference:
    def __init__(self):
        MODEL_TYPE = "vit_b"
        # the checkpoint file is expected in the working directory; download it
        # from the segment-anything release assets if missing
        checkpoint = "sam_vit_b_01ec64.pth"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)
    
    def predict(self, image):
        # `image` must be an RGB numpy array of shape (H, W, 3); convert
        # PIL images with np.array(...) before calling
        sam_result = self.mask_generator.generate(image)
        mask_annotator = sv.MaskAnnotator()
        detections = sv.Detections.from_sam(sam_result=sam_result)
        annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
        return [annotated_image]
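

# Minimal wiring sketch (an assumption based on the gradio import; the original
# app layout is not shown here). DPT consumes PIL images while SAM expects RGB
# numpy arrays, hence the np.array conversion below.
if __name__ == "__main__":
    depth_predictor = DepthPredictor()
    sam = SamInference()

    def run(image: Image.Image):
        depth = depth_predictor.predict(image)
        masks = sam.predict(np.array(image))[0]
        return depth, masks

    demo = gr.Interface(
        fn=run,
        inputs=gr.Image(type="pil", label="Input image"),
        outputs=[gr.Image(label="Depth map"), gr.Image(label="SAM masks")],
    )
    demo.launch()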