File size: 2,438 Bytes
0aa7f8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# custom installation from this PR: https://github.com/huggingface/transformers/pull/34583
# !pip install git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
from transformers import DepthProConfig, DepthProImageProcessorFast, DepthProForDepthEstimation

# initialize model
# DepthPro without the field-of-view branch; its depth head is replaced by a
# single-logit segmentation head. The head layout must match the checkpoint
# loaded below, because load_state_dict is strict by default.
config = DepthProConfig(use_fov_model=False)
model = DepthProForDepthEstimation(config)
features = config.fusion_hidden_size
semantic_classifier_dropout = 0.1
num_labels = 1  # one logit per pixel -> binary mask via sigmoid in predict()
model.head.head = nn.Sequential(
    nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(features),
    nn.ReLU(),
    nn.Dropout(semantic_classifier_dropout),
    nn.Conv2d(features, features, kernel_size=1),
    # 2x learned upsampling down to num_labels output channels
    nn.ConvTranspose2d(features, num_labels, kernel_size=2, stride=2, padding=0, bias=True),
)

# load weights (onto CPU first so loading works on machines without CUDA)
weights_path = hf_hub_download(repo_id="geetu040/DepthPro_Segmentation_Human", filename="model_weights.pth")
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'), weights_only=True))

# load to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# A directly constructed nn.Module starts in training mode, and neither
# load_state_dict nor .to() changes that. Switch to eval mode so that Dropout
# is disabled, and so that BatchNorm uses its loaded running statistics
# instead of per-batch ones (it would otherwise also keep updating those
# statistics on every request, even under torch.no_grad()).
model.eval()

# load image processor
image_processor = DepthProImageProcessorFast()

def predict(image):
	"""Run the segmentation model on ``image`` and return a binary mask.

	Args:
		image: a PIL.Image in any mode; it is converted to RGB first.

	Returns:
		A single-channel uint8 PIL.Image with the same width and height as
		``image``. Pixels whose predicted probability exceeds 0.5 are 255;
		all others are 0.
	"""
	image = image.convert("RGB")

	# prepare image for the model and move every input tensor to the model's device
	inputs = image_processor(images=image, return_tensors="pt")
	inputs = {k: v.to(device) for k, v in inputs.items()}

	# inference (no autograd graph needed)
	with torch.no_grad():
		output = model(**inputs)

	# The first element of the model output holds the per-pixel logits.
	# NOTE(review): assumed to be 3-D (batch, h, w), i.e. the head's channel
	# dim already squeezed. The previous code only worked under that assumption.
	logits = output[0]

	# Resize the *logits* (not the thresholded mask) to the input resolution.
	# Bilinear resampling gives smooth mask boundaries, whereas the
	# F.interpolate default ("nearest") produces blocky, aliased edges.
	logits = F.interpolate(
		logits.unsqueeze(1),  # add channel dim -> (batch, 1, h, w)
		size=(image.height, image.width),
		mode="bilinear",
		align_corners=False,
	)

	# Take the first batch item's single channel explicitly. Unlike
	# .squeeze(), this keeps a 2-D (H, W) array even when H or W is 1.
	logits = logits[0, 0]

	# sigmoid + 0.5 threshold -> binary segmentation mask
	mask = logits.sigmoid() > 0.5

	# {False, True} -> {0, 255} as uint8 on the CPU, the format PIL expects
	mask = mask.to(torch.uint8).mul(255).cpu().numpy()
	return Image.fromarray(mask)