import gradio as gr
import PIL
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import LMSDiscreteScheduler, DiffusionPipeline

# configurations
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
height, width = 512, 512
guidance_scale = 8
loss_scale = 200
num_inference_steps = 10
# Apply the custom-loss guidance on every N-th denoising step.
loss_guidance_every = 5
# Scaling between VAE latents and the latents the UNet works on (SD v1).
latent_scaling = 0.18215

model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).to(torch_device)
# The manual sampling loop below reads `scheduler.sigmas` and uses the
# sigma-space x0 estimate, which the K-LMS scheduler provides; install it
# explicitly instead of relying on whatever scheduler the checkpoint ships.
sd_pipeline.scheduler = LMSDiscreteScheduler.from_config(sd_pipeline.scheduler.config)

# UI style name -> (textual-inversion repo, prompt token). The token is passed
# explicitly to load_textual_inversion so the token inserted into prompts is
# guaranteed to be the one registered with the tokenizer.
style_concepts = {
    "Illustration Style": ("sd-concepts-library/illustration-style", "<illustration-style>"),
    "Line Art": ("sd-concepts-library/line-art", "<line-art>"),
    "Hitokomoru Style": ("sd-concepts-library/hitokomoru-style-nao", "<hitokomoru-style-nao>"),
    "Marc Allante": ("sd-concepts-library/style-of-marc-allante", "<style-of-marc-allante>"),
    "Midjourney": ("sd-concepts-library/midjourney-style", "<midjourney-style>"),
    "Hanfu Anime": ("sd-concepts-library/hanfu-anime-style", "<hanfu-anime-style>"),
    "Birb Style": ("sd-concepts-library/birb-style", "<birb-style>"),
}
for concept_repo, concept_token in style_concepts.values():
    sd_pipeline.load_textual_inversion(concept_repo, token=concept_token)

# UI style name -> textual-inversion token appended to the user's prompt.
styles_mapping = {style: token for style, (_, token) in style_concepts.items()}

# Define seeds for all the styles (same order as styles_mapping)
seed_list = [11, 56, 110, 65, 5, 29, 47]

# Loss types offered in the UI; each is understood by compute_loss.
guidance_types = ["blue", "edge", "contrast", "brightness", "sharpness", "saturation"]


# Loss computation functions
def edge_detection(image, eps=1e-8):
    """Return the per-channel Sobel gradient magnitude of an image batch.

    Args:
        image: tensor indexed as (batch, channels, H, W).
        eps: added under the square root so the gradient of the magnitude
            stays finite where both Sobel responses are zero (d sqrt(x)/dx is
            infinite at 0, which would turn the backward pass into NaN).

    Returns:
        float32 tensor of shape (batch, channels, H, W).
    """
    channels = image.shape[1]
    sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
    sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    # (2, 1, 3, 3) -> (2 * channels, 1, 3, 3). With groups=channels, input
    # channel c is convolved with filters 2c (Sobel-y) and 2c + 1 (Sobel-x).
    kernels = torch.tensor([sobel_y, sobel_x], device=image.device, dtype=image.dtype)
    kernels = kernels.unsqueeze(1).repeat(channels, 1, 1, 1)
    padded_image = F.pad(image, (1, 1, 1, 1), mode="replicate")
    responses = F.conv2d(padded_image, kernels, groups=channels).float()
    # Output channels interleave (y, x) per input channel, so pair even/odd
    # channels -- splitting into first/second half would mix channels.
    grad_y = responses[:, 0::2]
    grad_x = responses[:, 1::2]
    return torch.sqrt(grad_x**2 + grad_y**2 + eps)


# Loss types computed as |adjusted(image) - image|, keyed by loss name.
_adjustments = {
    "contrast": lambda img: T.functional.adjust_contrast(img, contrast_factor=2),
    "brightness": lambda img: T.functional.adjust_brightness(img, brightness_factor=2),
    "sharpness": lambda img: T.functional.adjust_sharpness(img, sharpness_factor=2),
    "saturation": lambda img: T.functional.adjust_saturation(img, saturation_factor=10),
}


def compute_loss(original_image, loss_type: str):
    """Return a scalar guidance loss for a batch of decoded images.

    Args:
        original_image: decoded images, indexed as (batch, channel, H, W);
            channel 2 is treated as blue (assumes RGB with values roughly in
            [0, 1] -- the caller maps VAE output from [-1, 1]).
        loss_type: one of ``guidance_types``.

    Returns:
        A 0-dim tensor. For an unrecognised ``loss_type`` the loss is zero but
        still attached to ``original_image``'s autograd graph, so
        ``torch.autograd.grad`` yields a zero gradient instead of raising.
    """
    if loss_type == "blue":
        return torch.abs(original_image[:, 2] - 0.9).mean()
    if loss_type == "edge":
        ed_value = edge_detection(original_image)
        return F.mse_loss(ed_value, (ed_value > 0.5).float())
    adjust = _adjustments.get(loss_type)
    if adjust is None:
        return original_image.sum() * 0.0
    return torch.abs(adjust(original_image) - original_image).mean()


# Mixed precision only where it is supported; disabled elsewhere without warnings.
@torch.autocast(device_type="cuda", enabled=torch_device == "cuda")
def generate_image(seed, prompt, loss_type, loss_flag=False):
    """Run a classifier-free-guided K-LMS sampling loop and return the latents.

    Args:
        seed: seed for the initial latent noise (via ``torch.manual_seed``,
            which also reseeds torch's global CPU RNG).
        prompt: text prompt; may contain textual-inversion tokens.
        loss_type: loss name passed to ``compute_loss`` when guiding.
        loss_flag: when true, every ``loss_guidance_every``-th step moves the
            latents down the gradient of ``compute_loss`` evaluated on the
            decoded x0 estimate.

    Returns:
        Latents of shape (1, unet in_channels, height // 8, width // 8), still
        in UNet scale (divide by ``latent_scaling`` before VAE decoding).
    """
    generator = torch.manual_seed(seed)
    batch_size = 1
    scheduler = sd_pipeline.scheduler

    # Conditioning is only consumed under no_grad below; don't record the
    # text-encoder graph. Returns [uncond, cond] embeddings concatenated.
    with torch.no_grad():
        text_embeddings = sd_pipeline._encode_prompt(prompt, sd_pipeline.device, 1, True)

    latents = torch.randn(
        (batch_size, sd_pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(sd_pipeline.device)
    # Configure the schedule before reading init_noise_sigma, which is
    # derived from the scheduler's sigmas.
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if loss_flag and i % loss_guidance_every == 0:
            sigma = scheduler.sigmas[i]
            latents = latents.detach().requires_grad_()
            # x0 estimate computed directly (x_t = x0 + sigma * eps) rather than
            # via scheduler.step, which would advance the multistep scheduler's
            # internal state a second time for this timestep.
            latents_x0 = latents - sigma * noise_pred
            # Decode WITH autograd enabled: the loss gradient must flow back
            # through the VAE to `latents`.
            denoised_images = sd_pipeline.vae.decode(latents_x0 / latent_scaling).sample / 2 + 0.5
            loss = compute_loss(denoised_images, loss_type) * loss_scale
            print(f"Step {i}, Loss: {loss.item():.4f}")
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma**2

        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def latents_to_pil(latents):
    """Decode UNet-scale latents with the pipeline's VAE into PIL RGB images."""
    with torch.no_grad():
        decoded = sd_pipeline.vae.decode(latents.to(sd_pipeline.vae.dtype) / latent_scaling).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    pixels = (decoded.permute(0, 2, 3, 1).float().cpu().numpy() * 255).round().astype(np.uint8)
    return [Image.fromarray(frame) for frame in pixels]


def show_image(prompt, style, guidance_type):
    """Generate the same seeded, styled prompt without and with loss guidance.

    Returns:
        ``[image_without_loss, image_with_loss]`` as PIL images.

    Raises:
        gr.Error: if no valid style or guidance type was selected.
    """
    if style not in styles_mapping:
        raise gr.Error("Please choose a style.")
    if guidance_type not in guidance_types:
        raise gr.Error("Please choose a guidance type.")
    seed = dict(zip(styles_mapping, seed_list))[style]
    styled_prompt = f"{prompt} in the style of {styles_mapping[style]}"
    plain_latents = generate_image(seed, styled_prompt, guidance_type, loss_flag=False)
    guided_latents = generate_image(seed, styled_prompt, guidance_type, loss_flag=True)
    return [latents_to_pil(plain_latents)[0], latents_to_pil(guided_latents)[0]]


def get_examples():
    """Return example rows of ``[prompt, style, guidance type]`` for the UI."""
    return [
        ["A mouse riding a bicycle through a park", "Illustration Style", "blue"],
        ["A lighthouse on a rocky coast", "Line Art", "edge"],
        ["A girl reading under a cherry tree", "Hitokomoru Style", "contrast"],
        ["A futuristic city at night", "Marc Allante", "brightness"],
        ["A dragon flying over snowy mountains", "Midjourney", "sharpness"],
        ["A warrior princess in a bamboo forest", "Hanfu Anime", "saturation"],
        ["A small bird perched on a branch", "Birb Style", "blue"],
    ]


# Gradio interface function
def generate_images(prompt, style, guidance_type):
    """Gradio callback: return (image without loss, image with loss)."""
    images = show_image(prompt, style, guidance_type)
    return images[0], images[1]


# Create Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(list(styles_mapping.keys()), label="Style"),
        gr.Dropdown(guidance_types, label="Guidance Type"),
    ],
    outputs=[
        gr.Image(label="Image without Loss"),
        gr.Image(label="Image with Loss"),
    ],
    examples=get_examples(),
    # Each example costs two full generations; don't precompute them at startup.
    cache_examples=False,
    title="Text Inversion Image Generation",
    description="Generate images using text inversion with different styles and guidance types.",
)

# Launch the app
iface.launch()