import argparse
import os
import time

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from transformers import AutoTokenizer, CLIPFeatureExtractor

from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
from diffusers.utils import load_image

from utils.mclip import *
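
# Example invocation (hypothetical script name; assumes the repo's utils/
# package is importable and a CUDA GPU is available):
#   python generate.py --condition mask --use_english --seed 42 \
#       --num_samples 2 --output_dir output/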


def parse_args():
    """Parse the command-line arguments for M3Face image generation."""
    parser = argparse.ArgumentParser(description="Generate images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation."
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use a segmentation mask or facial landmarks for image generation."
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the condition mask/landmark image. The condition is generated if this is not given."
    )
    parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument(
        "--additional_prompt",
        type=str,
        default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD"
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting"
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/",
        help="The output directory where the results will be written.",
    )
    args = parser.parse_args()
    return args
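
# Note: --enhance_prompt is parsed but not consumed anywhere in this script;
# presumably it is handled by other entry points in the repository.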


def get_controlnet(args):
    """Build the ControlNet generation pipeline (English or multilingual)."""
    if args.use_english:
        # English branch: Stable Diffusion v1.5 plus the M3Face ControlNet.
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'
        controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name, use_safetensors=True, revision=controlnet_revision
        )
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
        # CPU offload manages device placement itself, so the pipeline must not
        # be moved to CUDA beforehand.
        pipeline.enable_model_cpu_offload()
    else:
        # Multilingual branch: AltDiffusion-m18 components are assembled by hand
        # because its XLM-R-based text encoder is not a standard CLIP encoder.
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'
        vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
        tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
        text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)

        scheduler = PNDMScheduler.from_pretrained(sd_model_name, subfolder='scheduler')
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(sd_model_name, subfolder='feature_extractor')
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
            requires_safety_checker=False,  # silence the disabled-safety-checker warning
        ).to('cuda')

    return pipeline
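
# Optional (not part of the original script): if GPU memory is tight, attention
# slicing can be enabled on the returned pipeline at some cost in speed:
#   pipeline.enable_attention_slicing()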


def get_muse(args):
    """Build the aMUSEd pipeline that generates the mask/landmark condition."""
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'
    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    # MultilingualCLIP (from utils.mclip) stands in for the pipeline's usual
    # CLIP text encoder so that non-English prompts work as well.
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')

    pipeline = AmusedPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=uvit2,
        scheduler=scheduler,
    ).to("cuda")

    return pipeline
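
# The 'Generate face segmentation | ' / 'Generate face landmark | ' prefixes
# used below select the task for this checkpoint; presumably they match the
# prompt format the conditioning model was trained with.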


if __name__ == '__main__':
    args = parse_args()

    controlnet = get_controlnet(args)

    os.makedirs(args.output_dir, exist_ok=True)

    if args.seed is None:
        generator = None
    else:
        # A CPU generator keeps the run reproducible regardless of the GPU.
        generator = torch.Generator().manual_seed(args.seed)

    # Timestamp that groups all files written by this run.
    run_id = int(time.time())
    if args.condition_path:
        condition = load_image(args.condition_path).resize((512, 512))
    else:
        # No condition image given, so generate one with the aMUSEd model.
        muse = get_muse(args)
        if args.condition == 'mask':
            muse_added_prompt = 'Generate face segmentation | '
        elif args.condition == 'landmark':
            muse_added_prompt = 'Generate face landmark | '
        muse_prompt = muse_added_prompt + args.prompt
        condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
        if args.save_condition:
            condition.save(f'{args.output_dir}/{run_id}_condition.png')

    # Pre-sample the initial latents: (batch, 4, 64, 64) matches a 512x512
    # image through the SD v1.x VAE's 8x downscale and 4 latent channels.
    latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
    prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
    images = controlnet(
        prompt,
        image=condition,
        num_inference_steps=args.num_inference_steps,
        negative_prompt=args.negative_prompt,
        generator=generator,
        latents=latents,
        num_images_per_prompt=args.num_samples,
    ).images

    for i, image in enumerate(images):
        image.save(f'{args.output_dir}/{run_id}_{i}.png')
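
# Expected outputs under --output_dir:
#   {run_id}_condition.png            (only with --save_condition)
#   {run_id}_0.png ... {run_id}_{num_samples - 1}.png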