#@title 1. General Setup

!pip install -Uq diffusers transformers ftfy accelerate gradio

from diffusers import StableDiffusionPipeline

from transformers import CLIPTextModel, CLIPTokenizer
import gradio
import torch
import os

# FOR DEPLOYMENT: uncomment these and delete the notebook_login() below
# api_key = os.environ['api_key']
# my_token = api_key

from huggingface_hub import notebook_login
notebook_login()

import PIL
from PIL import Image

def image_grid(imgs, rows, cols):
    """Paste a list of equally-sized PIL images into a rows x cols grid."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
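
# A quick, self-contained check of image_grid: two solid-color 64x64 tiles
# pasted side by side (assumption: any equally-sized PIL images work, not
# just pipeline outputs).
demo_tiles = [Image.new("RGB", (64, 64), color) for color in ("red", "blue")]
demo_grid = image_grid(demo_tiles, rows=1, cols=2)  # a 128x64 composite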

pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"

from huggingface_hub import hf_hub_download
     

#@title 2. Tell it What Concepts to Load

models_to_load = [
    "ahx-model-3",
    "ahx-model-5",
    "ahx-model-6",
    "ahx-model-7",
    "ahx-model-8",
    "ahx-model-9",
    "ahx-model-10",
    "ahx-model-11",
]

models_to_load = [f"sd-concepts-library/{model}" for model in models_to_load]
completed_concept_pipes = {}
     

#@title 3. Load the Concepts as Distinct Pipes

for repo_id_embeds in models_to_load:
  print(f"loading {repo_id_embeds}")
  print("----------------------")
  # repo_id_embeds = "sd-concepts-library/ahx-model-3"

  embeds_url = "" #Add the URL or path to a learned_embeds.bin file in case you have one
  placeholder_token_string = "" #Add what is the token string in case you are uploading your own embed

  downloaded_embedding_folder = "./downloaded_embedding"
  if not os.path.exists(downloaded_embedding_folder):
    os.mkdir(downloaded_embedding_folder)
  if not embeds_url:
    embeds_path = hf_hub_download(repo_id=repo_id_embeds, filename="learned_embeds.bin")
    token_path = hf_hub_download(repo_id=repo_id_embeds, filename="token_identifier.txt")
    # copy the downloaded files into the local embedding folder
    !cp $embeds_path $downloaded_embedding_folder
    !cp $token_path $downloaded_embedding_folder
    with open(f'{downloaded_embedding_folder}/token_identifier.txt', 'r') as file:
      placeholder_token_string = file.read()
  else:
    !wget -q -O $downloaded_embedding_folder/learned_embeds.bin $embeds_url

  learned_embeds_path = f"{downloaded_embedding_folder}/learned_embeds.bin"

  # ----

  tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
  )
  text_encoder = CLIPTextModel.from_pretrained(
      pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16
  )

  # ----

  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    
    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to the dtype of the text_encoder's input embeddings
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)  # Tensor.to is not in-place; reassign

    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
      raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")
    
    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))
    
    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

  load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer)

  # FOR DEPLOYMENT: add use_auth_token=my_token to the from_pretrained call below
    # ie --> pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")
  pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float16,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
  ).to("cuda")

  completed_concept_pipes[repo_id_embeds] = pipe
  print("--> complete !")
  print("----------------------")


     

#@title 4. Print Available Concept Strings

print("AVAILABLE CONCEPTS TO SELECT FROM")
print("copy one and paste below under 'model'")
print("------------------------------------------------------")
# list(completed_concept_pipes)
for model in completed_concept_pipes:
  print(f"{model}")
     

#@title 5. Optionally Test without Gradio

model = "" #@param {type: "string"}
prompt = "" #@param {type:"string"}

if prompt and model:
  if model not in completed_concept_pipes:
    raise ValueError("Invalid Model Name")

  model_token = model.split("/")[1]
  style_suffix = f" in the style of <{model_token}>"

  # ahx-model-5 was special-cased here; its angle-bracket token appears to
  # have been stripped by HTML rendering, so it is restored assuming the
  # standard <repo-name> placeholder
  if model == "sd-concepts-library/ahx-model-5":
    style_suffix = " in the style of <ahx-model-5>"

  prompt = prompt + style_suffix

  num_samples = 1
  num_rows = 1

  all_images = [] 
  pipe = completed_concept_pipes[model]

  for _ in range(num_rows):
      images = pipe(prompt, num_images_per_prompt=num_samples, height=512, width=512, num_inference_steps=30, guidance_scale=7.5).images
      all_images.extend(images)

  # image_grid expects (imgs, rows, cols)
  grid = image_grid(all_images, num_rows, num_samples)
  grid
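
# A minimal sketch of seeding for reproducible samples (assumes the standard
# diffusers `generator` argument; not part of the original cell):
#   generator = torch.Generator("cuda").manual_seed(42)
#   images = pipe(prompt, generator=generator, num_inference_steps=30).images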
     

#@title 6. Define Custom CSS for Gradio

use_custom_css = True

gradio_css = """
  #output-image {
    border: 1px solid black;
    background-color: white;
    width: 500px;
    display: block;
    margin-left: auto;
    margin-right: auto;
  }
"""

gradio_css_alternative = """
  #go-button {
    background-color: white;
    border-radius: 0;
    border: none;
    font-family: serif;
    background-image: none;
    font-weight: 100;
    width: fit-content;
    display: block;
    margin-left: auto;
    margin-right: auto;
    text-decoration: underline;
    box-shadow: none;
    color: blue;
  }
  .rounded-lg {
    border: none;
  }
  .gr-box {
    border-radius: 0;
    border: 1px solid black;
  }
  .text-gray-500 {
    color: black;
    font-family: serif;
    font-size: 15px;
  }
  .border-gray-200 {
    border: 1px solid black;
  }
  .bg-gray-200 {
    background-color: white;
    --tw-bg-opacity: 0;
  }
  footer {
    display: none;
    opacity: 0;
  }
  #output-image {
    border: 1px solid black;
    background-color: white;
    width: 500px;
    display: block;
    margin-left: auto;
    margin-right: auto;
  }
  .absolute {
    display: none;
  }
  #input-text {
    width: 500px;
    display: block;
    margin-left: auto;
    margin-right: auto;
    padding: 0 0 0 0;
  }
  .py-6 {
    padding-top: 0;
    padding-bottom: 0;
  }
  .px-4 {
    padding-left: 0;
    padding-right: 0;
  }
  .rounded-lg {
    border-radius: 0;
  }
  .gr-padded {
    padding: 0 0;
    margin-bottom: 12.5px;
  }
  .col > *, .col > .gr-form > * {
    width: 500px;
    margin-left: auto;
    margin-right: auto;
  }
"""
     

#@title 7. Build and Launch the Gradio Interface

DROPDOWNS = {}

for model in models_to_load:
  token = model.split("/")[1]
  DROPDOWNS[model] = f" in the style of <{token}>"

if "sd-concepts-library/ahx-model-5" in DROPDOWNS:
  DROPDOWNS["sd-concepts-library/ahx-model-5"] = f"{prompt} in the style of "

def image_prompt(prompt, dropdown):
  # append the selected concept's style suffix and run that concept's pipeline
  prompt = prompt + DROPDOWNS[dropdown]
  pipe = completed_concept_pipes[dropdown]
  return pipe(prompt=prompt, height=512, width=512).images[0]
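
# For example (hypothetical prompt; any key of completed_concept_pipes works):
#   image_prompt("a lighthouse at dusk", "sd-concepts-library/ahx-model-3")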

with gradio.Blocks(css=gradio_css if use_custom_css else "") as demo:
  dropdown = gradio.Dropdown(list(DROPDOWNS), label="choose style...")
  prompt = gradio.Textbox(label="image prompt...", elem_id="input-text")
  output = gradio.Image(elem_id="output-image")
  go_button = gradio.Button("draw it!", elem_id="go-button")
  go_button.click(fn=image_prompt, inputs=[prompt, dropdown], outputs=output)

demo.launch(share=True)
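
# In Colab, a hedged alternative is demo.launch(share=True, debug=True), which
# blocks the cell and prints server errors to the notebook output.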