# Working artist concept selector
#@title 1. General Setup
!pip install -qq diffusers==0.11.1 transformers ftfy accelerate
!pip install -Uq diffusers transformers
!pip install -Uq gradio
!pip install -Uq accelerate
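# diffusers supplies the Stable Diffusion pipeline, transformers the CLIP tokenizer/text encoder,
# gradio the web UI, and accelerate faster model loading.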
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from accelerate import init_empty_weights
import gradio
import torch
import os
# FOR DEPLOYMENT: uncomment these and delete the notebook_login() below
# api_key = os.environ['api_key']
# my_token = api_key
from huggingface_hub import notebook_login
notebook_login()
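# Authenticate with Hugging Face; the cached token is used when downloading model weights and concept embeddings.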
import PIL
from PIL import Image
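# Utility to tile a list of PIL images into a single rows x cols preview image.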
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"
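# Base Stable Diffusion checkpoint that every concept pipeline below is built from.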
from IPython.display import Markdown
from huggingface_hub import hf_hub_download
#@title 2. Tell it What Concepts to Load
models_to_load = [
    "ahx-model-3",
    "ahx-model-5",
    "ahx-model-6",
    "ahx-model-7",
    "ahx-model-8",
    "ahx-model-9",
    "ahx-model-10",
    "ahx-model-11",
]
models_to_load = [f"sd-concepts-library/{model}" for model in models_to_load]
completed_concept_pipes = {}
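# Finished pipelines are cached here, keyed by the full sd-concepts-library repo id.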
#@title 3. Load the Concepts as Distinct Pipes
for repo_id_embeds in models_to_load:
    print(f"loading {repo_id_embeds}")
    print("----------------------")

    # repo_id_embeds = "sd-concepts-library/ahx-model-3"
    embeds_url = ""  # Add the URL or path to a learned_embeds.bin file in case you have one
    placeholder_token_string = ""  # Set the token string here in case you are uploading your own embed

    downloaded_embedding_folder = "./downloaded_embedding"
    if not os.path.exists(downloaded_embedding_folder):
        os.mkdir(downloaded_embedding_folder)

    if not embeds_url:
        embeds_path = hf_hub_download(repo_id=repo_id_embeds, filename="learned_embeds.bin")
        token_path = hf_hub_download(repo_id=repo_id_embeds, filename="token_identifier.txt")
        !cp $embeds_path $downloaded_embedding_folder
        !cp $token_path $downloaded_embedding_folder
        with open(f'{downloaded_embedding_folder}/token_identifier.txt', 'r') as file:
            placeholder_token_string = file.read()
    else:
        !wget -q -O $downloaded_embedding_folder/learned_embeds.bin $embeds_url

    learned_embeds_path = f"{downloaded_embedding_folder}/learned_embeds.bin"
    # ----
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16
    )
    # ----
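    # Inject a learned textual-inversion embedding into the text encoder's vocabulary under its placeholder token.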
    def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # separate token and the embeds
        trained_token = list(loaded_learned_embeds.keys())[0]
        embeds = loaded_learned_embeds[trained_token]

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = embeds.to(dtype)

        # add the token in tokenizer
        token = token if token is not None else trained_token
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer)
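    # Assemble a full Stable Diffusion pipeline that reuses the patched tokenizer and text encoder for this concept.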
    # FOR DEPLOYMENT: add use_auth_token=my_token to pipe keyword args
    # ie --> pipe = pipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch.float16,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    ).to("cuda")

    completed_concept_pipes[repo_id_embeds] = pipe
    print("--> complete !")
    print("----------------------")
#@title 4. Print Available Concept Strings
print("AVAILABLE CONCEPTS TO SELECT FROM")
print("copy one and paste below under 'model'")
print("------------------------------------------------------")
# list(completed_concept_pipes)
for model in completed_concept_pipes:
    print(f"{model}")
#@title 5. Optionally Test without Gradio
model = "" #@param {type: "string"}
prompt = "" #@param {type:"string"}
if prompt and model:
    if model not in completed_concept_pipes:
        raise ValueError("Invalid Model Name")

    model_token = model.split("/")[1]
    prompt = f"{prompt} in the style of <{model_token}>"
    if model == "sd-concepts-library/ahx-model-5":
        prompt = f"{prompt} in the style of "

    num_samples = 1
    num_rows = 1
    all_images = []
    pipe = completed_concept_pipes[model]

    for _ in range(num_rows):
        images = pipe(prompt, num_images_per_prompt=num_samples, height=512, width=512, num_inference_steps=30, guidance_scale=7.5).images
        all_images.extend(images)

    grid = image_grid(all_images, num_samples, num_rows)
    # display() is needed because the grid is not the cell's last top-level expression
    display(grid)
#@title 6. Define Custom CSS for Gradio
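# Toggle to False to fall back to Gradio's default styling instead of the custom CSS below.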
use_custom_css = True
gradio_css = """
#output-image {
  border: 1px solid black;
  background-color: white;
  width: 500px;
  display: block;
  margin-left: auto;
  margin-right: auto;
}
"""
gradio_css_alternative = """
#go-button {
  background-color: white;
  border-radius: 0;
  border: none;
  font-family: serif;
  background-image: none;
  font-weight: 100;
  width: fit-content;
  display: block;
  margin-left: auto;
  margin-right: auto;
  text-decoration: underline;
  box-shadow: none;
  color: blue;
}
.rounded-lg {
  border: none;
}
.gr-box {
  border-radius: 0;
  border: 1px solid black;
}
.text-gray-500 {
  color: black;
  font-family: serif;
  font-size: 15px;
}
.border-gray-200 {
  border: 1px solid black;
}
.bg-gray-200 {
  background-color: white;
  --tw-bg-opacity: 0;
}
footer {
  display: none;
}
footer {
  opacity: 0;
}
#output-image {
  border: 1px solid black;
  background-color: white;
  width: 500px;
  display: block;
  margin-left: auto;
  margin-right: auto;
}
.absolute {
  display: none;
}
#input-text {
  width: 500px;
  display: block;
  margin-left: auto;
  margin-right: auto;
  padding: 0 0 0 0;
}
.py-6 {
  padding-top: 0;
  padding-bottom: 0;
}
.px-4 {
  padding-left: 0;
  padding-right: 0;
}
.rounded-lg {
  border-radius: 0;
}
.gr-padded {
  padding: 0 0;
  margin-bottom: 12.5px;
}
.col > *, .col > .gr-form > * {
  width: 500px;
  margin-left: auto;
  margin-right: auto;
}
"""
#@title 7. Build and Launch the Gradio Interface
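# Map each loaded concept repo to the prompt suffix that invokes its placeholder token.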
DROPDOWNS = {}
for model in models_to_load:
    token = model.split("/")[1]
    DROPDOWNS[model] = f" in the style of <{token}>"
if "sd-concepts-library/ahx-model-5" in DROPDOWNS:
DROPDOWNS["sd-concepts-library/ahx-model-5"] = f"{prompt} in the style of "
def image_prompt(prompt, dropdown):
    prompt = prompt + DROPDOWNS[dropdown]
    pipe = completed_concept_pipes[dropdown]
    return pipe(prompt=prompt, height=512, width=512).images[0]
with gradio.Blocks(css=gradio_css if use_custom_css else "") as demo:
    dropdown = gradio.Dropdown(list(DROPDOWNS), label="choose style...")
    prompt = gradio.Textbox(label="image prompt...", elem_id="input-text")
    output = gradio.Image(elem_id="output-image")
    go_button = gradio.Button("draw it!", elem_id="go-button")
    go_button.click(fn=image_prompt, inputs=[prompt, dropdown], outputs=output)
demo.launch(share=True)
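# share=True also prints a temporary public *.gradio.live link so the demo can be opened outside the notebook.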