Spaces: Running on Zero
import os
import sys
import gc
import uuid
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import tqdm
import gradio as gr
import spaces
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download

# Make the parent directory importable for the repository's local modules
sys.path.append(os.path.abspath(os.path.join("", "..")))
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w

warnings.filterwarnings("ignore")
models_path = snapshot_download(repo_id="Snapchat/w2w")
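snapshot_download pulls the whole "Snapchat/w2w" repository into the local Hub cache and returns its path; everything below is loaded from that directory. As a quick sanity check (my sketch, not part of the original app), the statistics files used later should all be present:

# Sketch: confirm the statistics files referenced below were downloaded
for fname in ("mean.pt", "std.pt", "V.pt", "proj_1000pc.pt",
              "pinverse_1000pc.pt", "identity_df.pt", "weight_dimensions.pt"):
    assert os.path.exists(os.path.join(models_path, "files", fname)), f"missing {fname}"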
def load_models(device):
    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
    revision = None
    weight_dtype = torch.bfloat16

    # Load a throwaway pipeline only to grab its noise scheduler
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    noise_scheduler = pipe.scheduler
    del pipe

    # Load the individual components used for inference
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    # Inference only: freeze everything and move to the target device/dtype
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)

    return unet, vae, text_encoder, tokenizer, noise_scheduler
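Nothing in the excerpt actually uses the gc import; presumably it is meant for memory cleanup around model swaps like the del pipe above. A small helper in that spirit (an assumption, not the original code):

def free_memory():
    # Drop unreachable Python objects, then release cached CUDA blocks
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()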
device = "cuda"

# Statistics of the w2w weight space: per-dimension mean/std of the flattened
# LoRA weights, the principal directions (V), the training identities'
# coordinates in the first 1000 principal components (proj), and a
# pseudo-inverse for projecting into that subspace.
mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device("cpu")).bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device("cpu")).bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device("cpu")).bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")  # identity metadata
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")  # shapes for unflattening
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
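Judging only from the file names, v holds the principal directions and pinverse their pseudo-inverse, so a round trip between flattened LoRA weights and 1000-component coordinates would look roughly like this. The shapes are my assumption (v as weight_dim x 1000, pinverse as 1000 x weight_dim); treat it as a sketch, not the repository's actual API:

def to_coords(w_flat):
    # flattened weights -> principal-component coordinates (hypothetical)
    return (w_flat - mean) @ pinverse.T

def to_weights(coords):
    # principal-component coordinates -> flattened weights (hypothetical)
    return coords @ v.T + mean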
factor = 1.00  # std multiplier for sampling; not defined in the excerpt, 1.0 assumed

# On ZeroGPU, attach a GPU only for the duration of this call
@spaces.GPU
def sample_then_run():
    # Per-component mean and standard deviation across the training identities
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)

    # Draw a new identity: one Gaussian sample per principal component
    sample = torch.zeros([1, 1000]).to(device)
    for i in range(1000):
        sample[0, i] = torch.normal(m[i], factor * standev[i], (1, 1))

    # Save the sampled coefficients under a fresh name and return the filename
    # for the Gradio state (the excerpt created the name but never saved)
    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(sample, net)
    return net
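The per-component loop draws 1000 independent Gaussians, so it can also be written as a single vectorized draw with the same distribution:

def sample_coefficients(scale=factor):
    # One Gaussian per principal component: mean m, std scale * standev
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    return (m + scale * standev * torch.randn(1000, device=device)).unsqueeze(0)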
with gr.Blocks(css="style.css") as demo:
    net = gr.State()  # filename of the currently sampled model
    with gr.Column():
        with gr.Row():
            with gr.Column():
                sample = gr.Button("🎲 Sample New Model")
    # sample_then_run takes no inputs (the excerpt passed [net] to a
    # zero-argument function); it returns the new filename into the state
    sample.click(fn=sample_then_run, inputs=[], outputs=[net])

demo.queue().launch()
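The spaces import and the queue are the ZeroGPU plumbing behind the "Running on Zero" badge: on a ZeroGPU Space, a GPU is attached only while a @spaces.GPU-decorated callback (such as sample_then_run above) is executing, and queued requests wait their turn for a device.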