Vijish committed on
Commit dc30841 · 1 Parent(s): e4fdcb5

Upload 2 files

Files changed (2)
  1. bg.py +149 -0
  2. handler.py +140 -0
bg.py ADDED
@@ -0,0 +1,149 @@
import os
import tempfile
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

warnings.filterwarnings("ignore")

# fetch the IS-Net sources from the official DIS repository
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

# project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

# helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class GOSNormalize(object):
    '''
    Normalize the image using torchvision transforms.
    '''
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = normalize(image, self.mean, self.std)
        return image


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])


def load_image(im_path, hypar):
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)  # make a batch of image, shape


def build_model(hypar, device):
    net = hypar["model"]  # e.g. ISNetDIS()

    # convert to half precision, keeping BatchNorm layers in float32 for stability
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location=device))
        net.to(device)
    net.eval()
    return net


def predict(net, inputs_val, shapes_val, hypar, device):
    '''
    Given an image, predict the mask.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    inputs_val_v = inputs_val.to(device)

    ds_val = net(inputs_val_v)[0]  # list of six side outputs

    pred_val = ds_val[0][0, :, :, :]  # B x 1 x H x W; the first output is the most accurate prediction

    # recover the prediction to the original image size
    pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))

    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)  # normalize to [0, 1]

    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)  # the mask we need


# set parameters
hypar = {}  # parameters for inference

hypar["model_path"] = "./saved_models"  # load trained weights from this path
hypar["restore_model"] = "isnet.pth"  # name of the to-be-loaded weights
hypar["interm_sup"] = False  # whether to activate intermediate feature supervision

# choose floating-point precision
hypar["model_digit"] = "full"  # "half" or "full" precision
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024]  # cached input spatial resolution; can be configured to a different size

# data augmentation parameters
hypar["input_size"] = [1024, 1024]  # model input spatial size; usually the same as hypar["cache_size"], i.e. the images are not resized further
hypar["crop_size"] = [1024, 1024]  # random crop size from the input; usually smaller than hypar["cache_size"], e.g. [920, 920], for data augmentation

hypar["model"] = ISNetDIS()

# build model
net = build_model(hypar, device)


def inference(image: Image.Image):
    # save the image to a temporary file so the cache-based loader can read it
    with tempfile.NamedTemporaryFile(suffix='.jpg') as temp:
        image.convert("RGB").save(temp.name)  # JPEG cannot carry an alpha channel
        image_tensor, orig_size = load_image(temp.name, hypar)
        mask = predict(net, image_tensor, orig_size, hypar, device)

    # apply the predicted mask as the alpha channel of the input image
    pil_mask = Image.fromarray(mask).convert('L')
    im_rgba = image.convert("RGBA")
    im_rgba.putalpha(pil_mask)
    output_path = "output.png"
    im_rgba.save(output_path, format="PNG")

    return im_rgba


def paste_transparent_image(base_image, transparent_image):
    # resize the transparent foreground to the base image and composite it on top
    transparent_image = transparent_image.resize(base_image.size)
    result_image = Image.alpha_composite(base_image.convert('RGBA'), transparent_image)
    return result_image
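As a quick sanity check, a minimal local usage sketch for bg.py could look like the following (not part of the commit; it assumes the IS-Net weights are present at ./saved_models/isnet.pth, and the file names input.jpg and background.png are placeholders):

    from PIL import Image
    from bg import inference, paste_transparent_image

    # cut the foreground out of a local test image (also written to output.png)
    foreground = inference(Image.open("input.jpg"))

    # composite the transparent cut-out over an arbitrary backdrop
    background = Image.open("background.png").convert("RGBA")
    composite = paste_transparent_image(background, foreground)
    composite.save("composite.png")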
handler.py ADDED
@@ -0,0 +1,140 @@
from typing import Any
import base64
from io import BytesIO

import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import controlnet_hinter

from bg import inference, paste_transparent_image

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# set mixed-precision dtype: bfloat16 on Ampere or newer, float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# mapping from controlnet type to model id and control hinter
CONTROLNET_MAPPING = {
    "canny_edge": {
        "model_id": "lllyasviel/sd-controlnet-canny",
        "hinter": controlnet_hinter.hint_canny
    },
    "pose": {
        "model_id": "lllyasviel/sd-controlnet-openpose",
        "hinter": controlnet_hinter.hint_openpose
    },
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth
    },
    "scribble": {
        "model_id": "lllyasviel/sd-controlnet-scribble",
        "hinter": controlnet_hinter.hint_scribble
    },
    "segmentation": {
        "model_id": "lllyasviel/sd-controlnet-seg",
        "hinter": controlnet_hinter.hint_segmentation
    },
    "normal": {
        "model_id": "lllyasviel/sd-controlnet-normal",
        "hinter": controlnet_hinter.hint_normal
    },
    "hed": {
        "model_id": "lllyasviel/sd-controlnet-hed",
        "hinter": controlnet_hinter.hint_hed
    },
    "hough": {
        "model_id": "lllyasviel/sd-controlnet-mlsd",
        "hinter": controlnet_hinter.hint_hough
    }
}


class EndpointHandler():
    def __init__(self, path=""):
        # define the default controlnet type and load the corresponding ControlNet
        self.control_type = "normal"
        self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
                                                          torch_dtype=dtype).to(device)

        # load the StableDiffusionControlNetPipeline
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
                                                                      controlnet=self.controlnet,
                                                                      torch_dtype=dtype,
                                                                      safety_checker=None).to(device)
        # define a generator with a fixed seed for reproducible outputs
        self.generator = torch.Generator(device="cpu").manual_seed(3)

    def __call__(self, data: Any) -> Any:
        """
        :param data: A dictionary containing an `inputs` prompt, a base64-encoded `image`,
                     and an optional `controlnet_type` field.
        :return: The resulting PIL image, or a dictionary with an `error` field.
        """
        prompt = data.pop("inputs", None)
        image = data.pop("image", None)
        controlnet_type = data.pop("controlnet_type", None)

        # both a prompt and an image are required
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and a base64-encoded image."}

        # check if a new controlnet type is requested and swap the ControlNet if so
        if controlnet_type is not None and controlnet_type != self.control_type:
            print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            self.control_type = controlnet_type
            self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
                                                              torch_dtype=dtype).to(device)
            self.pipe.controlnet = self.controlnet

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)

        # process image: decode it, cut out the foreground, then compute the control hint
        image = self.decode_base64_image(image)
        image = inference(image)
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)

        # run the inference pipeline
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator
        )

        # composite the transparent foreground back over the generated image
        result_image = paste_transparent_image(out.images[0], image)

        # return the resulting PIL image
        return result_image

    # helper to decode the input image
    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
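For completeness, a hypothetical local invocation of the handler might look like this (a sketch, not part of the commit: a CUDA GPU is required, product.jpg is a placeholder file name, and the payload keys mirror what EndpointHandler.__call__ pops):

    import base64
    from handler import EndpointHandler

    # base64-encode a local test image (file name is illustrative)
    with open("product.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    result = handler({
        "inputs": "a product photo on a marble table, soft studio lighting",
        "image": image_b64,
        "controlnet_type": "canny_edge",  # any key of CONTROLNET_MAPPING
    })
    result.save("result.png")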