from pathlib import Path
import gradio as gr
import torch
from finetuning import FineTunedModel
from StableDiffuser import StableDiffuser
from tqdm import tqdm
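

# Gradio demo for Erasing Concepts in Stable Diffusion (ESD): the left column
# fine-tunes the model to erase a user-named concept, and the right column
# compares generations from the original and the erased model under the same seed.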
class Demo:

    def __init__(self) -> None:

        self.training = False
        self.generating = False

        self.nsteps = 50

        self.diffuser = StableDiffuser(scheduler='DDIM', seed=42).to('cuda')
        self.finetuner = None

        with gr.Blocks() as demo:
            self.layout()
            demo.queue(concurrency_count=2).launch()

    def disable(self):
        # Greys out the Train and Generate buttons while a training run is in progress.
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):

        with gr.Row():

            self.explain = gr.HTML(interactive=False,
                value="<p>This page demonstrates Erasing Concepts in Stable Diffusion (Gandikota, Materzynska, Fiotto-Kaufman and Bau; paper and code linked from https://erasing.baulab.info/). <br> Use it in two steps: <br> 1. First, on the left, fine-tune your own custom model by naming the concept that you want to erase. For example, you can try erasing all cars from a model by entering 'car' as the prompt corresponding to the concept to erase. This can take a while; with the default settings, it takes about an hour. <br> 2. Second, on the right, once your model is fine-tuned, you can run it in inference. <br> If you want to run it yourself, you can create your own instance. Configuration, code, and details are at https://github.com/xxxx/xxxx/xxx</p>")
        with gr.Row():

            with gr.Column(scale=1) as training_column:

                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to the concept to erase"
                )
                self.train_method_input = gr.Dropdown(
                    choices=['ESD-x', 'ESD-self'],
                    value='ESD-x',
                    label='Train Method',
                    info='Which attention weights to fine-tune (ESD-x: cross-attention, ESD-self: self-attention)'
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Negative guidance scale used during training'
                )
                self.iterations_input = gr.Number(
                    value=150,
                    precision=0,
                    label="Iterations",
                    info='Number of training iterations'
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train'
                )
                self.train_button = gr.Button(
                    value="Train",
                )
                self.download = gr.Files()

            with gr.Column(scale=2) as inference_column:

                with gr.Row():

                    with gr.Column(scale=5):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate"
                        )

                    with gr.Column(scale=1):
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42
                        )

                with gr.Row():

                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False
                    )

                with gr.Row():

                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False
                    )

        self.infr_button.click(
            self.inference,
            inputs=[self.prompt_input_infr, self.seed_infr],
            outputs=[self.image_new, self.image_orig]
        )

        # Disable both buttons as soon as training starts; train() re-enables them when it finishes.
        self.train_button.click(
            self.disable,
            outputs=[self.train_button, self.infr_button]
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.prompt_input,
                self.train_method_input,
                self.neg_guidance_input,
                self.iterations_input,
                self.lr_input
            ],
            outputs=[self.train_button, self.infr_button, self.download]
        )
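
    # ESD fine-tuning: at each iteration, partially denoise a latent for the target
    # prompt, then train the selected attention modules so their noise prediction on
    # that prompt moves toward the negatively guided target
    # e_neutral - neg_guidance * (e_prompt - e_neutral) computed with the original weights.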
    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar=gr.Progress(track_tqdm=True)):

        if self.training:
            return [None, None, None]
        else:
            self.training = True

        del self.finetuner
        torch.cuda.empty_cache()

        self.diffuser = self.diffuser.train().float()

        # Choose which attention modules to fine-tune: cross-attention for ESD-x,
        # self-attention for ESD-self.
        if train_method == 'ESD-x':
            modules = ".*attn2$"
        elif train_method == 'ESD-self':
            modules = ".*attn1$"

        finetuner = FineTunedModel(self.diffuser, modules)

        optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()

        pbar = tqdm(range(iterations))

        with torch.no_grad():
            neutral_text_embeddings = self.diffuser.get_text_embeddings([''], n_imgs=1)
            positive_text_embeddings = self.diffuser.get_text_embeddings([prompt], n_imgs=1)

        for i in pbar:

            with torch.no_grad():
                self.diffuser.set_scheduler_timesteps(self.nsteps)

                optimizer.zero_grad()

                # Sample a random DDIM step and denoise up to it with the current fine-tuned weights.
                iteration = torch.randint(1, self.nsteps - 1, (1,)).item()

                latents = self.diffuser.get_initial_latents(1, 512, 1)

                with finetuner:
                    latents_steps, _ = self.diffuser.diffusion(
                        latents,
                        positive_text_embeddings,
                        start_iteration=0,
                        end_iteration=iteration,
                        guidance_scale=3,
                        show_progress=False
                    )

                # Map the DDIM step index onto the full 1000-step schedule.
                self.diffuser.set_scheduler_timesteps(1000)
                iteration = int(iteration / self.nsteps * 1000)

                # Noise predictions from the original weights for the concept prompt and the empty prompt.
                positive_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
                neutral_latents = self.diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

            # Prediction from the model being fine-tuned (gradients flow only here).
            with finetuner:
                negative_latents = self.diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

            positive_latents.requires_grad = False
            neutral_latents.requires_grad = False

            loss = criteria(negative_latents, neutral_latents - (neg_guidance * (positive_latents - neutral_latents)))  # loss = criteria(e_n, e_0) works the best; try 5000 epochs

            loss.backward()
            optimizer.step()

        torch.save(finetuner.state_dict(), 'ft.ckpt')

        self.finetuner = finetuner.eval().half()
        self.diffuser = self.diffuser.eval().half()

        torch.cuda.empty_cache()

        self.training = False

        return [gr.update(interactive=True), gr.update(interactive=True), 'ft.ckpt']
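
    # Inference compares the original and erased models: the same prompt and seed are
    # generated twice, once with the base weights and once with the fine-tuned weights
    # patched in via the FineTunedModel context manager.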
    def inference(self, prompt, seed, pbar=gr.Progress(track_tqdm=True)):

        if self.generating:
            return [None, None]
        else:
            self.generating = True

        self.diffuser._seed = seed

        images = self.diffuser(
            prompt,
            n_steps=50,
            reseed=True
        )
        orig_image = images[0][0]

        torch.cuda.empty_cache()

        with self.finetuner:
            images = self.diffuser(
                prompt,
                n_steps=50,
                reseed=True
            )
        edited_image = images[0][0]

        self.generating = False

        torch.cuda.empty_cache()

        return edited_image, orig_image


demo = Demo()