import argparse
import os
import time

import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
    PNDMScheduler,
    AmusedPipeline,
    AmusedScheduler,
    VQModel,
    UVit2DModel,
)
from transformers import AutoTokenizer, CLIPFeatureExtractor
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
from diffusers.utils import load_image

from utils.mclip import *


def parse_args():
    parser = argparse.ArgumentParser(description="Generate images with M3Face.")
    parser.add_argument(
        "--prompt",
        type=str,
        default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
        help="The input text prompt for image generation."
    )
    parser.add_argument(
        "--condition",
        type=str,
        default="mask",
        choices=["mask", "landmark"],
        help="Use a segmentation mask or facial landmarks for image generation."
    )
    parser.add_argument(
        "--condition_path",
        type=str,
        default=None,
        help="Path to the condition mask/landmark image. The condition will be generated if it is not given."
    )
    parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
    parser.add_argument("--use_english", action="store_true", help="Use the English models.")
    parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument(
        "--additional_prompt",
        type=str,
        default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD"
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting"
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/",
        help="The output directory where the results will be written.",
    )
    args = parser.parse_args()
    return args


def get_controlnet(args):
    """Build the ControlNet face-generation pipeline (English or multilingual)."""
    if args.use_english:
        sd_model_name = 'runwayml/stable-diffusion-v1-5'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-english'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-english'

        controlnet = ControlNetModel.from_pretrained(
            controlnet_model_name, use_safetensors=True, revision=controlnet_revision
        )
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
        )
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
        # Model CPU offload manages device placement itself, so the pipeline is
        # not moved to CUDA explicitly here.
        pipeline.enable_model_cpu_offload()
    else:
        sd_model_name = 'BAAI/AltDiffusion-m18'
        controlnet_model_name = 'm3face/FaceControlNet'
        if args.condition == 'mask':
            controlnet_revision = 'segmentation-mlin'
        elif args.condition == 'landmark':
            controlnet_revision = 'landmark-mlin'

        # AltDiffusion uses a Roberta-based text encoder, so the pipeline is
        # assembled from its components instead of `from_pretrained`.
        vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
        tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
        text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
        controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
        scheduler = PNDMScheduler.from_pretrained(sd_model_name, subfolder='scheduler')
        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        feature_extractor = CLIPFeatureExtractor.from_pretrained(sd_model_name, subfolder='feature_extractor')

        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        ).to('cuda')

    return pipeline


def get_muse(args):
    """Build the aMUSEd pipeline that generates the mask/landmark condition image."""
    muse_model_name = 'm3face/FaceConditioning'
    if args.condition == 'mask':
        muse_revision = 'segmentation'
    elif args.condition == 'landmark':
        muse_revision = 'landmark'

    scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
    vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
    uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
    text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
    tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')

    pipeline = AmusedPipeline(
        vqvae=vqvae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=uvit2,
        scheduler=scheduler,
    ).to("cuda")
    return pipeline


if __name__ == '__main__':
    args = parse_args()

    # ========== set up face generation pipeline ==========
    controlnet = get_controlnet(args)

    # ========== set output directory ==========
    os.makedirs(args.output_dir, exist_ok=True)

    # ========== set random seed ==========
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator().manual_seed(args.seed)

    # ========== generation ==========
    run_id = int(time.time())  # timestamp used to name this run's outputs
    if args.condition_path:
        condition = load_image(args.condition_path).resize((512, 512))
    else:
        # No condition given: generate a mask/landmark image from the prompt.
        muse = get_muse(args)
        if args.condition == 'mask':
            muse_added_prompt = 'Generate face segmentation | '
        elif args.condition == 'landmark':
            muse_added_prompt = 'Generate face landmark | '
        muse_prompt = muse_added_prompt + args.prompt
        condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
        if args.save_condition:
            condition.save(f'{args.output_dir}/{run_id}_condition.png')

    # Draw the initial latents explicitly (4 channels at 64x64 for a 512x512
    # output) so every sample in the batch is reproducible from the seed.
    latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
    prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
    images = controlnet(
        prompt,
        image=condition,
        num_inference_steps=args.num_inference_steps,
        negative_prompt=args.negative_prompt,
        generator=generator,
        latents=latents,
        num_images_per_prompt=args.num_samples,
    ).images
    for i, image in enumerate(images):
        image.save(f'{args.output_dir}/{run_id}_{i}.png')
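# Example invocation, as a rough sketch: "generate.py" is a placeholder for this
# script's actual filename, and all flags below come from parse_args() above.
# Model weights are fetched from the Hugging Face Hub on first run.
#
#   python generate.py \
#       --prompt "A young man with curly hair and glasses." \
#       --condition landmark \
#       --num_samples 2 \
#       --seed 42 \
#       --save_condition \
#       --output_dir output/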