haowu11 committed on
Commit 33934d8 · verified · 1 Parent(s): 6539cd1

Upload app.py

Files changed (1)
  1. app.py +203 -148
app.py CHANGED
@@ -1,154 +1,209 @@
- import gradio as gr
- import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)

-     generator = torch.Generator().manual_seed(seed)

      image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
          generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
-
- if __name__ == "__main__":
-     demo.launch()
  import torch
+ from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
+ from diffusers.utils import load_image
+ import os,sys
+ import gradio as gr
+
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img_face import StableDiffusionXLControlNetImg2ImgPipeline
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models.controlnet import ControlNetModel
+
+ from diffusers import AutoencoderKL
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
+
+ from diffusers import EulerDiscreteScheduler
+ from PIL import Image
+ import numpy as np
+ import cv2
+ from insightface.app import FaceAnalysis
+ from insightface.data import get_image as ins_get_image
+
+ example_path = os.path.join(os.path.dirname(__file__), 'examples')
+
+
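+ # Face detection helper: wraps insightface's antelopev2 model pack and keeps only the
+ # largest detected face; its embedding is later passed to the pipeline as face_insightface_embeds.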
+ class FaceInfoGenerator():
+     def __init__(self, root_dir = "./"):
+         self.app = FaceAnalysis(name = 'antelopev2', root = root_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         self.app.prepare(ctx_id = 0, det_size = (640, 640))
+
+     def get_faceinfo_one_img(self, face_image):
+         face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+
+         if len(face_info) == 0:
+             face_info = None
+         else:
+             face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the largest face
+         return face_info
+
+ def face_bbox_to_square(bbox):
+     ## l, t, r, b to square l, t, r, b
+     l,t,r,b = bbox
+     cent_x = (l + r) / 2
+     cent_y = (t + b) / 2
+     w, h = r - l, b - t
+     r = max(w, h) / 2
+
+     l0 = cent_x - r
+     r0 = cent_x + r
+     t0 = cent_y - r
+     b0 = cent_y + r
+
+     return [l0, t0, r0, b0]
+
+
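+ # Model loading. The paths below assume the Kolors base weights, the Kolors-Controlnet-Pose-Tryon
+ # ControlNet and the Kolors-IP-Adapter-FaceID-Plus weights have already been downloaded into ./weights.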
+ ckpt_dir = f'weights/Kolors'
+ text_encoder = ChatGLMModel.from_pretrained(
+     f'{ckpt_dir}/text_encoder').to(dtype=torch.bfloat16)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).to(dtype=torch.bfloat16)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).to(dtype=torch.bfloat16)
+
+ control_path = f'weights/Kolors-Controlnet-Pose-Tryon'
+ controlnet = ControlNetModel.from_pretrained( control_path , revision=None).to(dtype=torch.bfloat16)
+
+ face_info_generator = FaceInfoGenerator(root_dir = "./")
+
+ clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'weights/Kolors-IP-Adapter-FaceID-Plus/clip-vit-large-patch14-336', ignore_mismatched_sizes=True)
+ clip_image_encoder.to('cuda')
+ clip_image_processor = CLIPImageProcessor(size = 336, crop_size = 336)
+
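+ # Assemble the Kolors ControlNet img2img pipeline; the CLIP vision encoder/processor appear to be
+ # wired in as the face branch of the IP-Adapter (face_clip_*) rather than as the usual image_encoder.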
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline(
+     vae=vae,
+     controlnet = controlnet,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=scheduler,
+     # image_encoder=image_encoder,
+     # feature_extractor=clip_image_processor,
+     force_zeros_for_empty_prompt=False,
+     face_clip_encoder=clip_image_encoder,
+     face_clip_processor=clip_image_processor,
+ )
+ if hasattr(pipe.unet, 'encoder_hid_proj'):
+     pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
+ ip_scale = 0.5
+ pipe.load_ip_adapter_faceid_plus(f'weights/Kolors-IP-Adapter-FaceID-Plus/ipa-faceid-plus.bin', device = 'cuda')
+ pipe.set_face_fidelity_scale(ip_scale)
+ pipe = pipe.to("cuda")
+ pipe.enable_model_cpu_offload()
+
+ def infer(face_img,pose_img, garm_img, prompt,negative_prompt, n_samples, n_steps, seed):
+     face_img = Image.open(face_img)
+     pose_img = Image.open(pose_img)
+     garm_img = Image.open(garm_img)
+     face_img = face_img.resize((336, 336))
+     pose_img = pose_img.resize((768, 1024))
+     garm_img = garm_img.resize((768, 1024))
+
+     background = Image.new("RGB", (768, 768), (255, 255, 255))
+     # paste face_img at the center of the background
+     background.paste(face_img, (int((768 - 336) / 2), int((768 - 336) / 2)))
+
+     face_info = face_info_generator.get_faceinfo_one_img(background)
+
+     face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
+     face_embeds = face_embeds.to('cuda', dtype = torch.bfloat16)
+
+     controlnet_conditioning_scale = 1.0
+     control_guidance_end = 0.9
+     # the smaller strength is, the more the generated image depends on the original image
+     strength = 1.0
+
+     im1 = np.array(pose_img)
+     im2 = np.array(garm_img)
+
+     condi_img = Image.fromarray( np.concatenate( (im1, im2), axis=1 ) )
+
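+     # Single pipeline call: the pose+garment concatenation conditions the ControlNet,
+     # while the cropped face and its insightface embedding condition the FaceID IP-Adapter.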
+     generator = torch.Generator(device="cpu").manual_seed(seed)
      image = pipe(
+         prompt= prompt ,
+         # image = init_image,
+         controlnet_conditioning_scale = controlnet_conditioning_scale,
+         control_guidance_end = control_guidance_end,
+         # ip_adapter_image=[ ip_adapter_img ],
+         face_crop_image = face_img,
+         face_insightface_embeds = face_embeds,
+         strength= strength ,
+         control_image = condi_img,
+         negative_prompt= negative_prompt ,
+         num_inference_steps=n_steps ,
+         guidance_scale= 5.0,
+         num_images_per_prompt=n_samples,
          generator=generator,
+     ).images
+     return image
+
+
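+ # Gradio UI: pose, garment and face images (with bundled examples) as inputs;
+ # generated try-on results are shown in a gallery.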
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("# KolorsControlnetTryon Demo")
+     with gr.Row():
+         with gr.Column():
+             pose_img = gr.Image(label="Pose", sources='upload', type="filepath", height=768, value=os.path.join(example_path, 'pose/1.jpg'))
+             example = gr.Examples(
+                 inputs=pose_img,
+                 examples_per_page=10,
+                 examples=[
+                     os.path.join(example_path, 'pose/1.jpg'),
+                     os.path.join(example_path, 'pose/2.jpg'),
+                     os.path.join(example_path, 'pose/3.jpg'),
+                     os.path.join(example_path, 'pose/4.jpg'),
+                     os.path.join(example_path, 'pose/5.jpg'),
+                     os.path.join(example_path, 'pose/6.jpg'),
+                     os.path.join(example_path, 'pose/7.jpg'),
+                     os.path.join(example_path, 'pose/8.jpg'),
+                     os.path.join(example_path, 'pose/9.jpg'),
+                     os.path.join(example_path, 'pose/10.jpg'),
+                 ])
+         with gr.Column():
+             garm_img = gr.Image(label="Garment", sources='upload', type="filepath", height=768, value=os.path.join(example_path, 'garment/1.jpg'),)
+             example = gr.Examples(
+                 inputs=garm_img,
+                 examples_per_page=10,
+                 examples=[
+                     os.path.join(example_path, 'garment/1.jpg'),
+                     os.path.join(example_path, 'garment/2.jpg'),
+                     os.path.join(example_path, 'garment/3.jpg'),
+                     os.path.join(example_path, 'garment/4.jpg'),
+                     os.path.join(example_path, 'garment/5.jpg'),
+                     os.path.join(example_path, 'garment/6.jpg'),
+                     os.path.join(example_path, 'garment/7.jpg'),
+                     os.path.join(example_path, 'garment/8.jpg'),
+                     os.path.join(example_path, 'garment/9.jpg'),
+                     os.path.join(example_path, 'garment/10.jpg'),
+                 ])
+     with gr.Row():
+         with gr.Column():
+             face_img = gr.Image(label="Face", sources='upload', type="filepath", height=336, value=os.path.join(example_path, 'face/1.png'),)
+             example = gr.Examples(
+                 inputs=face_img,
+                 examples_per_page=10,
+                 examples=[
+                     os.path.join(example_path, 'face/1.png'),
+                     os.path.join(example_path, 'face/2.png'),
+                     os.path.join(example_path, 'face/3.png'),
+                     os.path.join(example_path, 'face/4.png'),
+                     os.path.join(example_path, 'face/5.png'),
+                     os.path.join(example_path, 'face/6.png'),
+                     os.path.join(example_path, 'face/7.png'),
+                     os.path.join(example_path, 'face/8.png'),
+                     os.path.join(example_path, 'face/9.png'),
+                     os.path.join(example_path, 'face/10.png'),
+                 ])
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
+         with gr.Column():
+             prompt = gr.Textbox(value="这张图片上的模特穿着一件黑色的长袖T恤,T恤上印着彩色的字母'OBEY'。她还穿着一条牛仔裤。", show_label=False, elem_id="prompt")
+             negative_prompt = gr.Textbox(value="nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯", show_label=False, elem_id="negative_prompt")
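+             # Default prompt (English gloss): "The model in this picture is wearing a black long-sleeved T-shirt printed with the colorful letters 'OBEY'. She is also wearing a pair of jeans."
+             # Default negative prompt (English gloss): "nsfw, facial shadows, low resolution, bad anatomy, bad hands, missing fingers, worst quality, low quality, jpeg artifacts, blurry, bad, black face, neon lights"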
+             n_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
+             n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
+             seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
+             run_button = gr.Button(value="Run")
+     ips = [face_img,pose_img, garm_img, prompt,negative_prompt, n_samples, n_steps, seed]
+     run_button.click(fn=infer, inputs=ips, outputs=[result_gallery])
+
+ block.launch(server_name='0.0.0.0', server_port=7865)