tombetthauser committed
Commit 7e14edd · Parent: 3491a10

Working artist concept selector

Files changed (1)
  1. app.py +212 -23
app.py CHANGED
@@ -1,32 +1,201 @@
- # !pip install -Uq diffusers transformers
- # !pip install -Uq gradio
- # !pip install -Uq accelerate

- import gradio
- from diffusers import StableDiffusionPipeline as pipeline
  from accelerate import init_empty_weights
  import torch
  import os

- api_key = os.environ['api_key']
- my_token = api_key

- with init_empty_weights():
-     pipe = pipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")

- DROPDOWNS = {
-     "gustav": " by dan mumford and gustav klimt and john harris and jean delville and victo ngai and josan gonzalez",
-     "hayao": " by studio ghibli",
-     "vinny": " painting by Vincent van Gogh",
-     "danny": " drawn by a child",
-     "jeff": " by jeff koons",
- }

- def image_prompt(prompt, dropdown):
-     prompt = prompt + DROPDOWNS[dropdown]
-     return pipe(prompt=prompt, height=512, width=512).images[0]

- with gradio.Blocks(css="""
  #go-button {
  background-color: white;
  border-radius: 0;
@@ -105,11 +274,31 @@ with gradio.Blocks(css="""
  margin-left: auto;
  margin-right: auto;
  }
- """) as demo:
-     dropdown = gradio.Dropdown(["danny", "gustav", "hayao", "vinny", "jeff"], label="choose style...")
      prompt = gradio.Textbox(label="image prompt...", elem_id="input-text")
      output = gradio.Image(elem_id="output-image")
      go_button = gradio.Button("draw it!", elem_id="go-button")
      go_button.click(fn=image_prompt, inputs=[prompt, dropdown], outputs=output)

- demo.launch()

+
+ #@title 1. General Setup
+
+ !pip install -qq diffusers==0.11.1 transformers ftfy accelerate
+ !pip install -Uq diffusers transformers
+ !pip install -Uq gradio
+ !pip install -Uq accelerate
+
+ from diffusers import StableDiffusionPipeline
+ pipeline = StableDiffusionPipeline
+
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
  from accelerate import init_empty_weights
+ import gradio
  import torch
  import os

+ # FOR DEPLOYMENT: uncomment these and delete the notebook_login() below
+ # api_key = os.environ['api_key']
+ # my_token = api_key

+ from huggingface_hub import notebook_login
+ notebook_login()

+ import PIL
+ from PIL import Image

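+ # Helper used by the optional test cell below: tiles a list of PIL images into a rows x cols grid.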
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows*cols
+
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols*w, rows*h))
+     grid_w, grid_h = grid.size
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i%cols*w, i//cols*h))
+     return grid
+
+ pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"
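+ # Base checkpoint that every concept pipe below starts from; each concept only adds
+ # a single extra token embedding on top of this model's text encoder.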
+
+ from IPython.display import Markdown
+ from huggingface_hub import hf_hub_download
+
+
+ #@title 2. Tell it What Concepts to Load
+
+ models_to_load = [
+     "ahx-model-3",
+     "ahx-model-5",
+     "ahx-model-6",
+     "ahx-model-7",
+     "ahx-model-8",
+     "ahx-model-9",
+     "ahx-model-10",
+     "ahx-model-11",
+ ]
+
+ models_to_load = [f"sd-concepts-library/{model}" for model in models_to_load]
+ completed_concept_pipes = {}
+
+
+ #@title 3. Load the Concepts as Distinct Pipes
+
+ for repo_id_embeds in models_to_load:
+     print(f"loading {repo_id_embeds}")
+     print("----------------------")
+     # repo_id_embeds = "sd-concepts-library/ahx-model-3"
+
+     embeds_url = "" #Add the URL or path to a learned_embeds.bin file in case you have one
+     placeholder_token_string = "" #Add what is the token string in case you are uploading your own embed
+
+     downloaded_embedding_folder = "./downloaded_embedding"
+     if not os.path.exists(downloaded_embedding_folder):
+         os.mkdir(downloaded_embedding_folder)
+     if(not embeds_url):
+         embeds_path = hf_hub_download(repo_id=repo_id_embeds, filename="learned_embeds.bin")
+         token_path = hf_hub_download(repo_id=repo_id_embeds, filename="token_identifier.txt")
+         !cp $embeds_path $downloaded_embedding_folder
+         !cp $token_path $downloaded_embedding_folder
+         with open(f'{downloaded_embedding_folder}/token_identifier.txt', 'r') as file:
+             placeholder_token_string = file.read()
+     else:
+         !wget -q -O $downloaded_embedding_folder/learned_embeds.bin $embeds_url
+
+     learned_embeds_path = f"{downloaded_embedding_folder}/learned_embeds.bin"
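+     # learned_embeds.bin holds the concept's single trained token embedding;
+     # token_identifier.txt holds the placeholder token string used in prompts.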
+
+     # ----
+
+     tokenizer = CLIPTokenizer.from_pretrained(
+         pretrained_model_name_or_path,
+         subfolder="tokenizer",
+     )
+     text_encoder = CLIPTextModel.from_pretrained(
+         pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16
+     )
+
+     # ----
+
+     def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
+         loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+
+         # separate token and the embeds
+         trained_token = list(loaded_learned_embeds.keys())[0]
+         embeds = loaded_learned_embeds[trained_token]
+
+         # cast to dtype of text_encoder
+         dtype = text_encoder.get_input_embeddings().weight.dtype
+         embeds.to(dtype)
+
+         # add the token in tokenizer
+         token = token if token is not None else trained_token
+         num_added_tokens = tokenizer.add_tokens(token)
+         if num_added_tokens == 0:
+             raise ValueError(f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")
+
+         # resize the token embeddings
+         text_encoder.resize_token_embeddings(len(tokenizer))
+
+         # get the id for the token and assign the embeds
+         token_id = tokenizer.convert_tokens_to_ids(token)
+         text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+
+     load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer)
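+     # Each concept gets its own tokenizer/text_encoder copy, so every pipe stored
+     # in completed_concept_pipes recognizes exactly one extra style token.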
+
+     # FOR DEPLOYMENT: add use_auth_token=my_token to pipe keyword args
+     # ie --> pipe = pipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=my_token).to("cuda")
+     pipe = StableDiffusionPipeline.from_pretrained(
+         pretrained_model_name_or_path,
+         torch_dtype=torch.float16,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+     ).to("cuda")
+
+     completed_concept_pipes[repo_id_embeds] = pipe
+     print("--> complete !")
+     print("----------------------")
+
+
+

+ #@title 4. Print Available Concept Strings
+
+ print("AVAILABLE CONCEPTS TO SELECT FROM")
+ print("copy one and paste below under 'model'")
+ print("------------------------------------------------------")
+ # list(completed_concept_pipes)
+ for model in completed_concept_pipes:
+     print(f"{model}")
+
+
+ #@title 5. Optionally Test without Gradio
+
+ model = "" #@param {type: "string"}
+ prompt = "" #@param {type:"string"}
+
+ if prompt and model:
+     if model not in completed_concept_pipes:
+         raise ValueError("Invalid Model Name")
+
+     model_token = model.split("/")[1]
+     prompt = f"{prompt} in the style of <{model_token}>"
+
+     if model == "sd-concepts-library/ahx-model-5":
+         prompt = f"{prompt} in the style of "
+
+     num_samples = 1
+     num_rows = 1
+
+     all_images = []
+     pipe = completed_concept_pipes[model]
+
+     for _ in range(num_rows):
+         images = pipe(prompt, num_images_per_prompt=num_samples, height=512, width=512, num_inference_steps=30, guidance_scale=7.5).images
+         all_images.extend(images)
+
+     grid = image_grid(all_images, num_samples, num_rows)
+     grid
+
+
+ #@title 6. Define Custom CSS for Gradio
+
+ use_custom_css = True
+
+ gradio_css = """
+ #output-image {
+ border: 1px solid black;
+ background-color: white;
+ width: 500px;
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+ }
+ """
+
+ gradio_css_alternative = """
  #go-button {
  background-color: white;
  border-radius: 0;

  margin-left: auto;
  margin-right: auto;
  }
+ """
+
+
+ #@title 7. Build and Launch the Gradio Interface
+
+ DROPDOWNS = {}
+
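+ # Map each loaded concept to the suffix appended to the user's prompt; the same
+ # keys are reused to look up the matching pipe in completed_concept_pipes.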
+ for model in models_to_load:
+     token = model.split("/")[1]
+     DROPDOWNS[model] = f" in the style of <{token}>"
+
+ if "sd-concepts-library/ahx-model-5" in DROPDOWNS:
+     DROPDOWNS["sd-concepts-library/ahx-model-5"] = f"{prompt} in the style of "
+
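+ # Gradio callback: append the selected style suffix, then generate with that concept's pipe.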
+ def image_prompt(prompt, dropdown):
+     prompt = prompt + DROPDOWNS[dropdown]
+     pipe = completed_concept_pipes[dropdown]
+     return pipe(prompt=prompt, height=512, width=512).images[0]
+
+ with gradio.Blocks(css=gradio_css if use_custom_css else "") as demo:
+     dropdown = gradio.Dropdown(list(DROPDOWNS), label="choose style...")
      prompt = gradio.Textbox(label="image prompt...", elem_id="input-text")
      output = gradio.Image(elem_id="output-image")
      go_button = gradio.Button("draw it!", elem_id="go-button")
      go_button.click(fn=image_prompt, inputs=[prompt, dropdown], outputs=output)

+ demo.launch(share=True)
+