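"""M3Face generation script.

Generates face images from a text prompt with a ControlNet conditioned on a
segmentation mask or facial landmarks. If no condition image is supplied, one
is first generated from the prompt with an aMUSEd pipeline.
"""
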
import argparse
import os
import time

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from transformers import AutoTokenizer, CLIPFeatureExtractor
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
from diffusers.utils import load_image
from utils.mclip import *  # provides MultilingualCLIP


def parse_args():
    parser = argparse.ArgumentParser(description="Generate images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation.",
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use a segmentation mask or facial landmarks for image generation.",
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the condition mask/landmark image. The condition is generated if it is not given.",
    )
    parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument(
        "--additional_prompt",
        type=str,
        default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/",
        help="The output directory where the results will be written.",
    )
    args = parser.parse_args()
    return args


def get_controlnet(args):
    """Build the Stable Diffusion ControlNet pipeline (English or multilingual)."""
    if args.use_english:
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, use_safetensors=True, revision=controlnet_revision)
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
        # Model CPU offload manages device placement itself; moving the whole
        # pipeline to CUDA first would conflict with it.
        pipeline.enable_model_cpu_offload()
    else:
        # The multilingual branch assembles the pipeline from AltDiffusion-m18
        # components, with its Roberta-series text encoder.
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'
        vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
        tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
        text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
        scheduler = PNDMScheduler.from_pretrained(sd_model_name, subfolder='scheduler')
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(sd_model_name, subfolder='feature_extractor')
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        ).to('cuda')
    return pipeline


def get_muse(args):
    """Build the aMUSEd pipeline used to generate the condition (mask/landmark) image."""
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'
    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    transformer = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
    pipeline = AmusedPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    ).to("cuda")
    return pipeline


if __name__ == '__main__':
    args = parse_args()

    # ========== set up face generation pipeline ==========
    pipeline = get_controlnet(args)

    # ========== set output directory ==========
    os.makedirs(args.output_dir, exist_ok=True)

    # ========== set random seed ==========
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator().manual_seed(args.seed)

    # ========== generation ==========
    run_id = int(time.time())  # timestamp used to name the output files
    if args.condition_path:
        condition = load_image(args.condition_path).resize((512, 512))
    else:
        # No condition given: generate a mask/landmark image from the prompt.
        muse = get_muse(args)
        if args.condition == 'mask':
            muse_added_prompt = 'Generate face segmentation | '
        elif args.condition == 'landmark':
            muse_added_prompt = 'Generate face landmark | '
        muse_prompt = muse_added_prompt + args.prompt
        condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
        if args.save_condition:
            condition.save(f'{args.output_dir}/{run_id}_condition.png')

    # Fixed initial latents: one 4x64x64 latent per sample for 512x512 output.
    latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
    prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
    images = pipeline(
        prompt,
        image=condition,
        num_inference_steps=args.num_inference_steps,
        negative_prompt=args.negative_prompt,
        generator=generator,
        latents=latents,
        num_images_per_prompt=args.num_samples,
    ).images
    for i, image in enumerate(images):
        image.save(f'{args.output_dir}/{run_id}_{i}.png')
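
# Example invocation (hypothetical prompt and paths; flags as defined in parse_args above):
#   python generate.py --prompt "A young man with curly hair and glasses." \
#       --condition landmark --num_samples 2 --seed 42 --save_condition --output_dir output/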