Spaces: Running on Zero
Upload 37 files
- .gitattributes +2 -0
- html/circular.html +32 -0
- html/denoising.html +16 -0
- html/embeddings.html +75 -0
- html/guidance.html +17 -0
- html/inpainting.html +14 -0
- html/interpolate.html +24 -0
- html/negative.html +15 -0
- html/perturbations.html +35 -0
- html/poke.html +21 -0
- html/seeds.html +25 -0
- images/circular.gif +3 -0
- images/circular.png +0 -0
- images/denoising.png +0 -0
- images/guidance.png +0 -0
- images/inpainting.png +0 -0
- images/interpolate.gif +3 -0
- images/interpolate.png +0 -0
- images/negative.png +0 -0
- images/perturbations.png +0 -0
- images/poke.png +0 -0
- images/seeds.png +0 -0
- run.py +1029 -0
- src/__init__.py +2 -0
- src/pipelines/__init__.py +9 -0
- src/pipelines/circular.py +52 -0
- src/pipelines/embeddings.py +196 -0
- src/pipelines/guidance.py +39 -0
- src/pipelines/inpainting.py +41 -0
- src/pipelines/interpolate.py +51 -0
- src/pipelines/negative.py +37 -0
- src/pipelines/perturbations.py +62 -0
- src/pipelines/poke.py +83 -0
- src/pipelines/seed.py +32 -0
- src/util/__init__.py +3 -0
- src/util/base.py +304 -0
- src/util/clip_config.py +114 -0
- src/util/params.py +96 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/circular.gif filter=lfs diff=lfs merge=lfs -text
+images/interpolate.gif filter=lfs diff=lfs merge=lfs -text
html/circular.html
ADDED
@@ -0,0 +1,32 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
This tab generates a circular trajectory through latent space that begins and ends with the same image.
If we specify a large number of steps around the circle, the successive images will be closely related, resulting in a gradual deformation that produces a nice animation.
</p>
<p style="font-weight: bold;">
Additional Controls:
</p>
<p style="font-weight: bold;">
Number of Steps around the Circle:
</p>
<p>
Specify the number of images to produce along the circular path.
</p>
<p style="font-weight: bold;">
Proportion of Circle:
</p>
<p>
Sets the proportion of the circle to cover during image generation.
Ranges from 0 to 360 degrees.
Using a high step count with a small number of degrees allows you to explore very subtle image transformations.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/circular.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
html/denoising.html
ADDED
@@ -0,0 +1,16 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
This tab displays the intermediate images generated during the denoising process.
Seeing these intermediate images provides insight into how the diffusion model progressively adds detail at each step.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/denoising.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
</details>
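As an aside for readers of this commit, here is a minimal sketch of one way such intermediate images can be captured with diffusers' step-end callback. This assumes a recent diffusers release (>= 0.23, where callback_on_step_end exists) and is illustrative only; it is not the code in this repo's src/util.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
frames = []  # one decoded image tensor per denoising step

def grab_step(pipeline, step, timestep, callback_kwargs):
    # "latents" is forwarded by default via callback_on_step_end_tensor_inputs.
    latents = callback_kwargs["latents"]
    with torch.no_grad():
        frames.append(
            pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor).sample
        )
    return callback_kwargs

pipe(
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    num_inference_steps=8,
    callback_on_step_end=grab_step,
)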
html/embeddings.html
ADDED
@@ -0,0 +1,75 @@
<head>
<link rel="stylesheet" type="text/css" href="styles.css">
</head>

<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="background-color: #D87F2B; padding-left: 10px;">
<p style="font-weight: bold;">
Basic Exploration
</p>
The top part of the embeddings tab is the 3D plot of semantic feature space.
At the bottom of the tab there are expandable panels that can be opened to reveal more advanced features.

<ul>
<li>
<strong>
Explore the 3D semantic feature space:
</strong>
Click and drag in the 3D semantic feature space to rotate the view.
Use the scroll wheel to zoom in and out.
Hold down the control key and click and drag to pan the view.
</li>
<li>
<strong>
Find the generated image:
</strong>
Hover over a point in the semantic feature space, and a window will pop up showing a generated image from this one-word prompt.
On left click, the image will be downloaded.
</li>
<li>
<strong>
Find the embedding vector display:
</strong>
Hover over a word in the 3D semantic feature space, and an embedding vector display at the bottom of the tab shows the corresponding embedding vector.
</li>
<li>
<strong>
Add/remove words from the 3D plot:
</strong>
Type a word in the Add/Remove word text box below the 3D plot to add a word to the plot, or, if the word is already present, remove it from the plot.
You can also type multiple words separated by spaces or commas.
</li>
<li>
<strong>
Change image for word in the 3D plot:
</strong>
Type a word in the Change image for word text box below the 3D plot to generate a new image for the corresponding word in the plot.
</li>
</ul>

<p style="font-weight: bold; margin-top: 10px;">
Semantic Dimensions
</p>
<ul>
<li>
<strong>Select a different semantic dimension.</strong><br>
Open the Custom Semantic Dimensions panel and choose another dimension for the X, Y, or Z axis.
See how the display changes.
</li>
<li>
<strong>Alter a semantic dimension.</strong><br>
Examine the positive and negative word pairs used to define the semantic dimension.
You can change these pairs to alter the semantic dimension.
</li>
<li>
<strong>Define a new semantic dimension.</strong><br>
Pick a new semantic dimension that you can define using pairs of opposed words.
For example, you could define a "tense" dimension with pairs such as eat/ate, go/went, see/saw, and is/was to contrast present and past tense forms of verbs.
</li>
</ul>
</div>
</details>
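A minimal sketch of how a semantic axis can be built from opposed word pairs with CLIP, assuming transformers' stock CLIP classes; this mirrors what get_axis_embeddings in src/pipelines/embeddings.py appears to do, but the details below are an assumption, not this repo's exact code.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def embed(words):
    # One pooled embedding vector per word.
    inputs = tok(words, padding=True, return_tensors="pt")
    with torch.no_grad():
        return enc(**inputs).pooler_output

# A "gender" axis: difference of the two groups' mean embeddings. Projecting
# any word's embedding onto this axis gives its coordinate along the dimension.
axis = embed(["prince", "husband", "father"]).mean(0) - \
       embed(["princess", "wife", "mother"]).mean(0)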
html/guidance.html
ADDED
@@ -0,0 +1,17 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Guidance is responsible for making the target image adhere to the prompt.
A higher value enforces this relation, whereas a lower value does not.
For example, a guidance scale of 1 produces a distorted grayscale image, whereas 50 produces a distorted, oversaturated image.
The default value of 8 produces normal-looking images that reasonably adhere to the prompt.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/guidance.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
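For reference, a minimal sketch of the classifier-free guidance update whose scale this tab varies; the random tensors below merely stand in for the UNet's two noise predictions.

import torch

noise_uncond = torch.randn(1, 4, 64, 64)  # prediction with the empty prompt
noise_text = torch.randn(1, 4, 64, 64)    # prediction with the user's prompt
guidance_scale = 8.0

# Push the prediction away from unconditional, toward the prompt direction.
# scale=1 reduces to the text prediction alone; large scales over-amplify it.
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)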
html/inpainting.html
ADDED
@@ -0,0 +1,14 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Unlike poke, which globally alters the target image via a perturbation in the initial latent noise, inpainting alters just the region of the perturbation and allows us to specify the change we want to make.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/inpainting.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
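A minimal sketch using diffusers' stock inpainting pipeline, not this repo's src/pipelines/inpainting.py (which is not shown in full here); init_image and mask are PIL images supplied by the caller.

from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting"
)
# White pixels in mask_image are regenerated from the prompt; black are kept.
result = pipe(prompt="sunglasses", image=init_image, mask_image=mask).images[0]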
html/interpolate.html
ADDED
@@ -0,0 +1,24 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
This tab generates noise patterns for two text prompts and then interpolates between them, gradually transforming from the first to the second.
With a large number of perturbation steps the transformation is very gradual and makes a nice animation.
</p>
<p style="font-weight: bold;">
Additional Controls:
</p>
<p style="font-weight: bold;">
Number of Interpolation Steps:
</p>
<p>
Defines the number of intermediate images to generate between the two prompts.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/interpolate.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
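A minimal sketch of the idea, assuming the get_text_embeddings, generate_latents, and generate_images helpers from this commit's src/util behave as their names suggest (they are used the same way in src/pipelines/circular.py below); the linear blend is one plausible reading, not necessarily the repo's exact interpolation.

emb_a = get_text_embeddings("a beautiful man with golden hair")
emb_b = get_text_embeddings("a beautiful woman with golden hair")
latents = generate_latents(14)  # same starting noise for every frame

steps = 5
for i in range(steps):
    t = i / (steps - 1)
    # Linearly interpolate between the two prompt embeddings in CLIP space.
    emb = (1 - t) * emb_a + t * emb_b
    image = generate_images(latents, emb, 8)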
html/negative.html
ADDED
@@ -0,0 +1,15 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Negative prompts steer images away from unwanted features.
For example, “red” as a negative prompt makes the generated image unlikely to have reddish hues.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/negative.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
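Mechanically, a negative prompt replaces the empty unconditional prompt in classifier-free guidance, so guidance pushes the image away from it. A minimal sketch with diffusers' stock pipeline (illustrative, not this repo's src/pipelines/negative.py):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# The "red" embedding takes the unconditional slot, so the guided update
# steers the denoising trajectory away from reddish features.
image = pipe(
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    negative_prompt="red",
    num_inference_steps=8,
).images[0]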
html/perturbations.html
ADDED
@@ -0,0 +1,35 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Perturbations enables the exploration of the latent space around a seed.
Perturbing the noise from an initial seed towards the noise from a different seed illustrates the variations in images obtainable from a local region of latent space.
Using a small perturbation size produces target images that closely resemble the one from the initial seed.
Larger perturbations traverse more distance in latent space towards the second seed, resulting in greater variation in the generated images.
</p>
<p style="font-weight: bold;">
Additional Controls:
</p>
<p style="font-weight: bold;">
Number of Perturbations:
</p>
<p>
Specify the number of perturbations to create, i.e., the number of seeds to use. More perturbations produce more images.
</p>
<p style="font-weight: bold;">
Perturbation Size:
</p>
<p>
Controls the perturbation magnitude, ranging from 0 to 1.
With a value of 0, all images will match the one from the initial seed.
With a value of 1, images will have no connection to the initial seed.
A value such as 0.1 is recommended.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/perturbations.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
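A minimal sketch of the idea, again assuming the src/util helpers used elsewhere in this commit; the blend below is one plausible reading of "perturbation size" (0 keeps the base noise, 1 replaces it entirely), not necessarily this repo's exact formula.

base = generate_latents(14)  # the initial seed's noise
size = 0.1
for i in range(5):
    other = generate_latents(1000 + i)          # a different seed's noise
    perturbed = (1 - size) * base + size * other  # step toward the other seed
    image = generate_images(perturbed, text_embeddings, 8)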
html/poke.html
ADDED
@@ -0,0 +1,21 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Poke explores how perturbations in a local region of the initial latent noise impact the target image.
A small perturbation to the initial latent noise gets carried through the denoising process, demonstrating the global effect it can produce.
</p>
<p style="font-weight: bold;">
Additional Controls:
</p>
<p>
You can adjust the perturbation through the X, Y, height, and width controls.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/poke.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
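A minimal sketch of what a poke could look like: overwrite a rectangle of the initial latent noise and denoise both versions. This assumes (4, 64, 64)-shaped latents (the run.py sliders below range over the 64x64 latent grid of a 512px SD v1 image) and center-based coordinates; both are assumptions, not this repo's exact visualize_poke logic.

latents = generate_latents(14)          # assumed shape (4, 64, 64)
poked = latents.clone()
x, y, h, w = 32, 32, 8, 8               # poke center and extent
poked[:, y - h // 2 : y + h // 2, x - w // 2 : x + w // 2] = \
    generate_latents(99)[:, y - h // 2 : y + h // 2, x - w // 2 : x + w // 2]
# Denoising latents vs. poked with the same prompt shows the global effect.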
html/seeds.html
ADDED
@@ -0,0 +1,25 @@
<details open>
<summary style="background-color: #CE6400; padding-left: 10px;">
About
</summary>
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
<div style="flex: 1;">
<p style="margin-top: 10px">
Seeds create the initial noise that gets refined into the target image.
Different seeds produce different noise patterns, hence the target image will differ even when prompted by the same text.
This tab produces multiple target images from the same text prompt to showcase how changing the seed changes the target image.
</p>
<p style="font-weight: bold;">
Additional Controls:
</p>
<p style="font-weight: bold;">
Number of Seeds:
</p>
<p>
Specify how many seed values to use.
</p>
</div>
<div style="flex: 1; align-content: center;">
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/seeds.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
</div>
</div>
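Concretely, a seed just fixes the initial Gaussian noise, so the same prompt denoises to a different image per seed. A minimal, self-contained sketch (the (1, 4, 64, 64) shape is the SD v1 latent size for a 512px image):

import torch

def latents_for_seed(seed: int) -> torch.Tensor:
    g = torch.Generator().manual_seed(seed)
    return torch.randn(1, 4, 64, 64, generator=g)

assert torch.equal(latents_for_seed(14), latents_for_seed(14))      # reproducible
assert not torch.equal(latents_for_seed(14), latents_for_seed(15))  # differs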
images/circular.gif
ADDED
Git LFS Details
images/circular.png
ADDED
images/denoising.png
ADDED
images/guidance.png
ADDED
images/inpainting.png
ADDED
images/interpolate.gif
ADDED
Git LFS Details
images/interpolate.png
ADDED
images/negative.png
ADDED
images/perturbations.png
ADDED
images/poke.png
ADDED
images/seeds.png
ADDED
run.py
ADDED
@@ -0,0 +1,1029 @@
import base64
import gradio as gr
from PIL import Image
from src.util import *
from io import BytesIO
from src.pipelines import *
from threading import Thread
from dash import Dash, dcc, html, Input, Output, no_update, callback

app = Dash(__name__)

app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(
            id="graph", figure=fig, clear_on_unhover=True, style={"height": "90vh"}
        ),
        dcc.Tooltip(id="tooltip"),
        html.Div(id="word-emb-txt", style={"background-color": "white"}),
        html.Div(id="word-emb-vis"),
        html.Div(
            [
                html.Button(id="btn-download-image", hidden=True),
                dcc.Download(id="download-image"),
            ]
        ),
    ],
)


@callback(
    Output("tooltip", "show"),
    Output("tooltip", "bbox"),
    Output("tooltip", "children"),
    Output("tooltip", "direction"),
    Output("word-emb-txt", "children"),
    Output("word-emb-vis", "children"),
    Input("graph", "hoverData"),
)
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update, no_update, no_update, no_update

    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    direction = "left"
    index = hover_data["pointNumber"]

    children = [
        html.Img(
            src=images[index],
            style={"width": "250px"},
        ),
        html.P(
            hover_data["text"],
            style={
                "color": "black",
                "font-size": "20px",
                "text-align": "center",
                "background-color": "white",
                "margin": "5px",
            },
        ),
    ]

    emb_children = [
        html.Img(
            src=generate_word_emb_vis(hover_data["text"]),
            style={"width": "100%", "height": "25px"},
        ),
    ]

    return True, bbox, children, direction, hover_data["text"], emb_children


@callback(
    Output("download-image", "data"),
    Input("graph", "clickData"),
)
def download_image(clickData):

    if clickData is None:
        return no_update

    click_data = clickData["points"][0]
    index = click_data["pointNumber"]
    txt = click_data["text"]

    img_encoded = images[index]
    img_decoded = base64.b64decode(img_encoded.split(",")[1])
    img = Image.open(BytesIO(img_decoded))
    img.save(f"{txt}.png")
    return dcc.send_file(f"{txt}.png")


with gr.Blocks() as demo:
    gr.Markdown("## Stable Diffusion Demo")

    with gr.Tab("Latent Space"):

        with gr.TabItem("Denoising"):
            gr.Markdown("Observe the intermediate images during denoising.")
            gr.HTML(read_html("DiffusionDemo/html/denoising.html"))

            with gr.Row():
                with gr.Column():
                    prompt_denoise = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_inference_steps_denoise = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps",
                    )

                    with gr.Row():
                        seed_denoise = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_denoise = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_denoise = gr.Button("Generate Images")

                with gr.Column():
                    images_output_denoise = gr.Gallery(label="Images", selected_index=0)
                    gif_denoise = gr.Image(label="GIF")
                    zip_output_denoise = gr.File(label="Download ZIP")

            @generate_images_button_denoise.click(
                inputs=[prompt_denoise, seed_denoise, num_inference_steps_denoise],
                outputs=[images_output_denoise, gif_denoise, zip_output_denoise],
            )
            def generate_images_wrapper(
                prompt, seed, num_inference_steps, progress=gr.Progress()
            ):
                images, _ = display_poke_images(
                    prompt, seed, num_inference_steps, poke=False, intermediate=True
                )
                fname = "denoising"
                tab_config = {
                    "Tab": "Denoising",
                    "Prompt": prompt,
                    "Number of Inference Steps": num_inference_steps,
                    "Seed": seed,
                }
                export_as_zip(images, fname, tab_config)
                progress(1, desc="Exporting as gif")
                export_as_gif(images, filename="denoising.gif")
                return images, "outputs/denoising.gif", f"outputs/{fname}.zip"

            seed_denoise.change(
                fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
            )

        with gr.TabItem("Seeds"):
            gr.Markdown(
                "Understand how different starting points in latent space can lead to different images."
            )
            gr.HTML(read_html("DiffusionDemo/html/seeds.html"))

            with gr.Row():
                with gr.Column():
                    prompt_seed = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_seed = gr.Slider(
                        minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
                    )
                    num_inference_steps_seed = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    generate_images_button_seed = gr.Button("Generate Images")

                with gr.Column():
                    images_output_seed = gr.Gallery(label="Images", selected_index=0)
                    zip_output_seed = gr.File(label="Download ZIP")

            generate_images_button_seed.click(
                fn=display_seed_images,
                inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
                outputs=[images_output_seed, zip_output_seed],
            )

        with gr.TabItem("Perturbations"):
            gr.Markdown("Explore different perturbations from a point in latent space.")
            gr.HTML(read_html("DiffusionDemo/html/perturbations.html"))

            with gr.Row():
                with gr.Column():
                    prompt_perturb = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_perturb = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Perturbations",
                    )
                    perturbation_size_perturb = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.1,
                        label="Perturbation Size",
                    )
                    num_inference_steps_perturb = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_perturb = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_perturb = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_perturb = gr.Button("Generate Images")

                with gr.Column():
                    images_output_perturb = gr.Gallery(label="Image", selected_index=0)
                    zip_output_perturb = gr.File(label="Download ZIP")

            generate_images_button_perturb.click(
                fn=display_perturb_images,
                inputs=[
                    prompt_perturb,
                    seed_perturb,
                    num_inference_steps_perturb,
                    num_images_perturb,
                    perturbation_size_perturb,
                ],
                outputs=[images_output_perturb, zip_output_perturb],
            )
            seed_perturb.change(
                fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
            )

        with gr.TabItem("Circular"):
            gr.Markdown(
                "Generate a circular path in latent space and observe how the images vary along the path."
            )
            gr.HTML(read_html("DiffusionDemo/html/circular.html"))

            with gr.Row():
                with gr.Column():
                    prompt_circular = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_images_circular = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Steps around the Circle",
                    )

                    with gr.Row():
                        degree_circular = gr.Slider(
                            minimum=0,
                            maximum=360,
                            step=1,
                            value=360,
                            label="Proportion of Circle",
                            info="Enter the value in degrees",
                        )
                        step_size_circular = gr.Textbox(
                            label="Step Size", value=360 / 5
                        )

                    num_inference_steps_circular = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_circular = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_circular = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_circular = gr.Button("Generate Images")

                with gr.Column():
                    images_output_circular = gr.Gallery(label="Image", selected_index=0)
                    gif_circular = gr.Image(label="GIF")
                    zip_output_circular = gr.File(label="Download ZIP")

            num_images_circular.change(
                fn=calculate_step_size,
                inputs=[num_images_circular, degree_circular],
                outputs=[step_size_circular],
            )
            degree_circular.change(
                fn=calculate_step_size,
                inputs=[num_images_circular, degree_circular],
                outputs=[step_size_circular],
            )
            generate_images_button_circular.click(
                fn=display_circular_images,
                inputs=[
                    prompt_circular,
                    seed_circular,
                    num_inference_steps_circular,
                    num_images_circular,
                    degree_circular,
                ],
                outputs=[images_output_circular, gif_circular, zip_output_circular],
            )
            seed_circular.change(
                fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
            )

        with gr.TabItem("Poke"):
            gr.Markdown("Perturb a region in the image and observe the effect.")
            gr.HTML(read_html("DiffusionDemo/html/poke.html"))

            with gr.Row():
                with gr.Column():
                    prompt_poke = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_inference_steps_poke = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_poke = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_poke = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    pokeX = gr.Slider(
                        label="pokeX",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="X coordinate of poke center",
                    )
                    pokeY = gr.Slider(
                        label="pokeY",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="Y coordinate of poke center",
                    )
                    pokeHeight = gr.Slider(
                        label="pokeHeight",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Height of the poke",
                    )
                    pokeWidth = gr.Slider(
                        label="pokeWidth",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Width of the poke",
                    )

                    generate_images_button_poke = gr.Button("Generate Images")

                with gr.Column():
                    original_images_output_poke = gr.Image(
                        value=visualize_poke(32, 32, 8, 8)[0], label="Original Image"
                    )
                    poked_images_output_poke = gr.Image(
                        value=visualize_poke(32, 32, 8, 8)[1], label="Poked Image"
                    )
                    zip_output_poke = gr.File(label="Download ZIP")

            pokeX.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeY.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeHeight.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            pokeWidth.change(
                visualize_poke,
                inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                outputs=[original_images_output_poke, poked_images_output_poke],
            )
            seed_poke.change(
                fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
            )

            @generate_images_button_poke.click(
                inputs=[
                    prompt_poke,
                    seed_poke,
                    num_inference_steps_poke,
                    pokeX,
                    pokeY,
                    pokeHeight,
                    pokeWidth,
                ],
                outputs=[
                    original_images_output_poke,
                    poked_images_output_poke,
                    zip_output_poke,
                ],
            )
            def generate_images_wrapper(
                prompt,
                seed,
                num_inference_steps,
                pokeX=pokeX,
                pokeY=pokeY,
                pokeHeight=pokeHeight,
                pokeWidth=pokeWidth,
            ):
                _, _ = display_poke_images(
                    prompt,
                    seed,
                    num_inference_steps,
                    poke=True,
                    pokeX=pokeX,
                    pokeY=pokeY,
                    pokeHeight=pokeHeight,
                    pokeWidth=pokeWidth,
                    intermediate=False,
                )
                images, modImages = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)
                fname = "poke"
                tab_config = {
                    "Tab": "Poke",
                    "Prompt": prompt,
                    "Number of Inference Steps per Image": num_inference_steps,
                    "Seed": seed,
                    "PokeX": pokeX,
                    "PokeY": pokeY,
                    "PokeHeight": pokeHeight,
                    "PokeWidth": pokeWidth,
                }
                imgs_list = []
                imgs_list.append((images, "Original Image"))
                imgs_list.append((modImages, "Poked Image"))

                export_as_zip(imgs_list, fname, tab_config)
                return images, modImages, f"outputs/{fname}.zip"

        with gr.TabItem("Guidance"):
            gr.Markdown("Observe the effect of different guidance scales.")
            gr.HTML(read_html("DiffusionDemo/html/guidance.html"))

            with gr.Row():
                with gr.Column():
                    prompt_guidance = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    num_inference_steps_guidance = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    guidance_scale_values = gr.Textbox(
                        lines=1, value="1, 8, 20, 30", label="Guidance Scale Values"
                    )

                    with gr.Row():
                        seed_guidance = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_guidance = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_guidance = gr.Button("Generate Images")

                with gr.Column():
                    images_output_guidance = gr.Gallery(
                        label="Images", selected_index=0
                    )
                    zip_output_guidance = gr.File(label="Download ZIP")

            generate_images_button_guidance.click(
                fn=display_guidance_images,
                inputs=[
                    prompt_guidance,
                    seed_guidance,
                    num_inference_steps_guidance,
                    guidance_scale_values,
                ],
                outputs=[images_output_guidance, zip_output_guidance],
            )
            seed_guidance.change(
                fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance]
            )

        with gr.TabItem("Inpainting"):
            gr.Markdown("Inpaint the image based on the prompt.")
            gr.HTML(read_html("DiffusionDemo/html/inpainting.html"))

            with gr.Row():
                with gr.Column():
                    uploaded_img_inpaint = gr.Image(
                        source="upload", tool="sketch", type="pil", label="Upload"
                    )
                    prompt_inpaint = gr.Textbox(
                        lines=1, label="Prompt", value="sunglasses"
                    )
                    num_inference_steps_inpaint = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_inpaint = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_inpaint = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    inpaint_button = gr.Button("Inpaint")

                with gr.Column():
                    images_output_inpaint = gr.Image(label="Output")
                    zip_output_inpaint = gr.File(label="Download ZIP")

            inpaint_button.click(
                fn=inpaint,
                inputs=[
                    uploaded_img_inpaint,
                    num_inference_steps_inpaint,
                    seed_inpaint,
                    prompt_inpaint,
                ],
                outputs=[images_output_inpaint, zip_output_inpaint],
            )
            seed_inpaint.change(
                fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint]
            )

    with gr.Tab("CLIP Space"):

        with gr.TabItem("Embeddings"):
            gr.Markdown(
                "Visualize text embedding space in 3D with input texts and output images based on the chosen axis."
            )
            gr.HTML(read_html("DiffusionDemo/html/embeddings.html"))

            with gr.Row():
                output = gr.HTML(
                    f"""
                    <iframe id="html" src="{dash_tunnel}" style="width:100%; height:700px;"></iframe>
                    """
                )
            with gr.Row():
                word2add_rem = gr.Textbox(lines=1, label="Add/Remove word")
                word2change = gr.Textbox(lines=1, label="Change image for word")
                clear_words_button = gr.Button(value="Clear words")

            with gr.Accordion("Custom Semantic Dimensions", open=False):
                with gr.Row():
                    axis_name_1 = gr.Textbox(label="Axis name", value="gender")
                    which_axis_1 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_1"],
                        label="Axis direction",
                    )
                    from_words_1 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="prince husband father son uncle",
                    )
                    to_words_1 = gr.Textbox(
                        lines=1,
                        label="Negative",
                        value="princess wife mother daughter aunt",
                    )
                    submit_1 = gr.Button("Submit")

                with gr.Row():
                    axis_name_2 = gr.Textbox(label="Axis name", value="age")
                    which_axis_2 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_2"],
                        label="Axis direction",
                    )
                    from_words_2 = gr.Textbox(
                        lines=1, label="Positive", value="man woman king queen father"
                    )
                    to_words_2 = gr.Textbox(
                        lines=1, label="Negative", value="boy girl prince princess son"
                    )
                    submit_2 = gr.Button("Submit")

                with gr.Row():
                    axis_name_3 = gr.Textbox(label="Axis name", value="residual")
                    which_axis_3 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_3"],
                        label="Axis direction",
                    )
                    from_words_3 = gr.Textbox(lines=1, label="Positive")
                    to_words_3 = gr.Textbox(lines=1, label="Negative")
                    submit_3 = gr.Button("Submit")

                with gr.Row():
                    axis_name_4 = gr.Textbox(label="Axis name", value="number")
                    which_axis_4 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_4"],
                        label="Axis direction",
                    )
                    from_words_4 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="boys girls cats puppies computers",
                    )
                    to_words_4 = gr.Textbox(
                        lines=1, label="Negative", value="boy girl cat puppy computer"
                    )
                    submit_4 = gr.Button("Submit")

                with gr.Row():
                    axis_name_5 = gr.Textbox(label="Axis name", value="royalty")
                    which_axis_5 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_5"],
                        label="Axis direction",
                    )
                    from_words_5 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="king queen prince princess duchess",
                    )
                    to_words_5 = gr.Textbox(
                        lines=1, label="Negative", value="man woman boy girl woman"
                    )
                    submit_5 = gr.Button("Submit")

                with gr.Row():
                    axis_name_6 = gr.Textbox(label="Axis name")
                    which_axis_6 = gr.Dropdown(
                        choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
                        value=whichAxisMap["which_axis_6"],
                        label="Axis direction",
                    )
                    from_words_6 = gr.Textbox(lines=1, label="Positive")
                    to_words_6 = gr.Textbox(lines=1, label="Negative")
                    submit_6 = gr.Button("Submit")

            @word2add_rem.submit(inputs=[word2add_rem], outputs=[output, word2add_rem])
            def add_rem_word_and_clear(words):
                return add_rem_word(words), ""

            @word2change.submit(inputs=[word2change], outputs=[output, word2change])
            def change_word_and_clear(word):
                return change_word(word), ""

            clear_words_button.click(fn=clear_words, outputs=[output])

            @submit_1.click(
                inputs=[axis_name_1, which_axis_1, from_words_1, to_words_1],
                outputs=[
                    output,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_1"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )

            @submit_2.click(
                inputs=[axis_name_2, which_axis_2, from_words_2, to_words_2],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_2"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )

            @submit_3.click(
                inputs=[axis_name_3, which_axis_3, from_words_3, to_words_3],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_4,
                    which_axis_5,
                    which_axis_6,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_3"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )

            @submit_4.click(
                inputs=[axis_name_4, which_axis_4, from_words_4, to_words_4],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_5,
                    which_axis_6,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_4"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_5"],
                    whichAxisMap["which_axis_6"],
                )

            @submit_5.click(
                inputs=[axis_name_5, which_axis_5, from_words_5, to_words_5],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_6,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_5"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_6"],
                )

            @submit_6.click(
                inputs=[axis_name_6, which_axis_6, from_words_6, to_words_6],
                outputs=[
                    output,
                    which_axis_1,
                    which_axis_2,
                    which_axis_3,
                    which_axis_4,
                    which_axis_5,
                ],
            )
            def set_axis_wrapper(axis_name, which_axis, from_words, to_words):

                for ax in whichAxisMap:
                    if whichAxisMap[ax] == which_axis:
                        whichAxisMap[ax] = "---"

                whichAxisMap["which_axis_6"] = which_axis
                return (
                    set_axis(axis_name, which_axis, from_words, to_words),
                    whichAxisMap["which_axis_1"],
                    whichAxisMap["which_axis_2"],
                    whichAxisMap["which_axis_3"],
                    whichAxisMap["which_axis_4"],
                    whichAxisMap["which_axis_5"],
                )

        with gr.TabItem("Interpolate"):
            gr.Markdown(
                "Interpolate between the first and the second prompt, and observe how the output changes."
            )
            gr.HTML(read_html("DiffusionDemo/html/interpolate.html"))

            with gr.Row():
                with gr.Column():
                    promptA = gr.Textbox(
                        lines=1,
                        label="First Prompt",
                        value="Self-portrait oil painting, a beautiful man with golden hair, 8k",
                    )
                    promptB = gr.Textbox(
                        lines=1,
                        label="Second Prompt",
                        value="Self-portrait oil painting, a beautiful woman with golden hair, 8k",
                    )
                    num_images_interpolate = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Interpolation Steps",
                    )
                    num_inference_steps_interpolate = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_interpolate = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_interpolate = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_interpolate = gr.Button("Generate Images")

                with gr.Column():
                    images_output_interpolate = gr.Gallery(
                        label="Interpolated Images", selected_index=0
                    )
                    gif_interpolate = gr.Image(label="GIF")
                    zip_output_interpolate = gr.File(label="Download ZIP")

            generate_images_button_interpolate.click(
                fn=display_interpolate_images,
                inputs=[
                    seed_interpolate,
                    promptA,
                    promptB,
                    num_inference_steps_interpolate,
                    num_images_interpolate,
                ],
                outputs=[
                    images_output_interpolate,
                    gif_interpolate,
                    zip_output_interpolate,
                ],
            )
            seed_interpolate.change(
                fn=generate_seed_vis,
                inputs=[seed_interpolate],
                outputs=[seed_vis_interpolate],
            )

        with gr.TabItem("Negative"):
            gr.Markdown("Observe the effect of negative prompts.")
            gr.HTML(read_html("DiffusionDemo/html/negative.html"))

            with gr.Row():
                with gr.Column():
                    prompt_negative = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    neg_prompt = gr.Textbox(
                        lines=1, label="Negative Prompt", value="Yellow"
                    )
                    num_inference_steps_negative = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_negative = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_negative = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_negative = gr.Button("Generate Images")

                with gr.Column():
                    images_output_negative = gr.Image(
                        label="Image without Negative Prompt"
                    )
                    images_neg_output_negative = gr.Image(
                        label="Image with Negative Prompt"
                    )
                    zip_output_negative = gr.File(label="Download ZIP")

            seed_negative.change(
                fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative]
            )
            generate_images_button_negative.click(
                fn=display_negative_images,
                inputs=[
                    prompt_negative,
                    seed_negative,
                    num_inference_steps_negative,
                    neg_prompt,
                ],
                outputs=[
                    images_output_negative,
                    images_neg_output_negative,
                    zip_output_negative,
                ],
            )

    with gr.Tab("Credits"):
        gr.Markdown("""
        Author: Adithya Kameswara Rao, Carnegie Mellon University.

        Advisor: David S. Touretzky, Carnegie Mellon University.

        This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633.
        """)


def run_dash():
    app.run(host="127.0.0.1", port="8000")


def run_gradio():
    demo.queue()
    _, _, public_url = demo.launch(share=True)
    return public_url


# if __name__ == "__main__":
#     thread = Thread(target=run_dash)
#     thread.daemon = True
#     thread.start()
#     try:
#         run_gradio()
#     except KeyboardInterrupt:
#         print("Server closed")
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import util
from . import pipelines
src/pipelines/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .circular import *
from .embeddings import *
from .interpolate import *
from .poke import *
from .seed import *
from .perturbations import *
from .negative import *
from .guidance import *
from .inpainting import *
src/pipelines/circular.py
ADDED
@@ -0,0 +1,52 @@
import torch
import numpy as np
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_circular_images(
    prompt, seed, num_inference_steps, num_images, degree, progress=gr.Progress()
):
    np.random.seed(seed)
    text_embeddings = get_text_embeddings(prompt)

    latents_x = generate_latents(seed)
    latents_y = generate_latents(seed * np.random.randint(0, 100000))

    scale_x = torch.cos(
        torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
    ).to(torch_device)
    scale_y = torch.sin(
        torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
    ).to(torch_device)

    noise_x = torch.tensordot(scale_x, latents_x, dims=0)
    noise_y = torch.tensordot(scale_y, latents_y, dims=0)

    noise = noise_x + noise_y

    progress(0)
    images = []
    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(noise[i], text_embeddings, num_inference_steps)
        images.append((image, "{}".format(i)))

    progress(1, desc="Exporting as gif")
    export_as_gif(images, filename="circular.gif")

    fname = "circular"
    tab_config = {
        "Tab": "Circular",
        "Prompt": prompt,
        "Number of Steps around the Circle": num_images,
        "Proportion of Circle": degree,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, "outputs/circular.gif", f"outputs/{fname}.zip"


__all__ = ["display_circular_images"]
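Why this traces a loop: frame i is cos(θ_i)·x + sin(θ_i)·y for two independent latent draws x and y, so a full 360-degree sweep ends where it began. A minimal sanity check of that identity with plain tensors (no diffusion model assumed):

import torch

num_images, degree = 8, 360
theta = torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)

x = torch.randn(4, 64, 64)  # stand-ins for the two latent draws
y = torch.randn(4, 64, 64)

# Same outer-product construction circular.py builds with tensordot:
frames = torch.tensordot(torch.cos(theta), x, dims=0) + torch.tensordot(
    torch.sin(theta), y, dims=0
)

print(torch.allclose(frames[0], x))              # True: the path starts at x
print(torch.allclose(frames[-1], x, atol=1e-5))  # True when degree == 360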
src/pipelines/embeddings.py
ADDED
@@ -0,0 +1,196 @@
import random
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline

import base64
from io import BytesIO
import plotly.express as px

from src.util.base import *
from src.util.params import *
from src.util.clip_config import *

age = get_axis_embeddings(young, old)
gender = get_axis_embeddings(masculine, feminine)
royalty = get_axis_embeddings(common, elite)

images = []
for example in examples:
    image = pipe(
        prompt=example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded_image)

axis = np.vstack([gender, royalty, age])
axis[1] = calculate_residual(axis, axis_names)

coords = get_concat_embeddings(examples) @ axis.T
coords[:, 1] = 5 * (1.0 - coords[:, 1])


def update_fig():
    global coords, examples, fig
    fig.data[0].x = coords[:, 0]
    fig.data[0].y = coords[:, 1]
    fig.data[0].z = coords[:, 2]
    fig.data[0].text = examples

    return f"""
    <script>
        document.getElementById("html").src += "?rand={random.random()}"
    </script>
    <iframe id="html" src={dash_tunnel} style="width:100%; height:725px;"></iframe>
    """


def add_word(new_example):
    global coords, images, examples
    new_coord = get_concat_embeddings([new_example]) @ axis.T
    new_coord[:, 1] = 5 * (1.0 - new_coord[:, 1])
    coords = np.vstack([coords, new_coord])

    image = pipe(
        prompt=new_example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded_image)
    examples.append(new_example)
    return update_fig()


def remove_word(new_example):
    global coords, images, examples
    examplesMap = {example: index for index, example in enumerate(examples)}
    index = examplesMap[new_example]

    coords = np.delete(coords, index, 0)
    images.pop(index)
    examples.pop(index)
    return update_fig()


def add_rem_word(new_examples):
    global examples
    new_examples = new_examples.replace(",", " ").split()

    for new_example in new_examples:
        if new_example in examples:
            remove_word(new_example)
            gr.Info("Removed {}".format(new_example))
        else:
            tokens = tokenizer.encode(new_example)
            if len(tokens) != 3:
                gr.Warning(f"{new_example} not found in embeddings")
            else:
                add_word(new_example)
                gr.Info("Added {}".format(new_example))

    return update_fig()


def set_axis(axis_name, which_axis, from_words, to_words):
    global coords, examples, fig, axis_names

    if axis_name != "residual":
        from_words, to_words = (
            from_words.replace(",", " ").split(),
            to_words.replace(",", " ").split(),
        )
        axis_emb = get_axis_embeddings(from_words, to_words)
        axis[axisMap[which_axis]] = axis_emb
        axis_names[axisMap[which_axis]] = axis_name

        for i, name in enumerate(axis_names):
            if name == "residual":
                axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
                axis_names[i] = "residual"
    else:
        residual = calculate_residual(
            axis, axis_names, residual_axis=axisMap[which_axis]
        )
        axis[axisMap[which_axis]] = residual
        axis_names[axisMap[which_axis]] = axis_name

    coords = get_concat_embeddings(examples) @ axis.T
    coords[:, 1] = 5 * (1.0 - coords[:, 1])

    fig.update_layout(
        scene=dict(
            xaxis_title=axis_names[0],
            yaxis_title=axis_names[1],
            zaxis_title=axis_names[2],
        )
    )
    return update_fig()


def change_word(examples):
    examples = examples.replace(",", " ").split()

    for example in examples:
        remove_word(example)
        add_word(example)
        gr.Info("Changed image for {}".format(example))

    return update_fig()


def clear_words():
    while examples:
        remove_word(examples[-1])
    return update_fig()


def generate_word_emb_vis(prompt):
    buf = BytesIO()
    emb = get_word_embeddings(prompt).reshape(77, 768)[1]
    plt.imsave(buf, [emb], cmap="inferno")
    img = "data:image/jpeg;base64, " + base64.b64encode(buf.getvalue()).decode("utf-8")
    return img


fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    labels={
        "x": axis_names[0],
        "y": axis_names[1],
        "z": axis_names[2],
    },
    text=examples,
    height=750,
)

fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0), scene_camera=dict(eye=dict(x=2, y=2, z=0.1))
)

fig.update_traces(hoverinfo="none", hovertemplate=None)

__all__ = [
    "fig",
    "update_fig",
    "coords",
    "images",
    "examples",
    "add_word",
    "remove_word",
    "add_rem_word",
    "change_word",
    "clear_words",
    "generate_word_emb_vis",
    "set_axis",
    "axis",
]
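Each scatter-plot coordinate is a dot product between a word's normalized CLIP embedding and a semantic axis built from averaged word-pair differences; the y column is then rescaled as 5·(1 − y). A toy projection with made-up 4-d vectors standing in for the real 77×768 embeddings (all numbers hypothetical):

import numpy as np

# Hypothetical unit vectors standing in for get_word_embeddings(...)
king = np.array([0.9, 0.1, 0.3, 0.1])
king /= np.linalg.norm(king)
queen = np.array([0.1, 0.9, 0.3, 0.1])
queen /= np.linalg.norm(queen)

# An axis is the average of difference vectors; a single pair here for brevity.
gender_axis = king - queen

# coords = embeddings @ axis.T gives one coordinate per axis:
print(king @ gender_axis, queen @ gender_axis)  # king scores high, queen low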
src/pipelines/guidance.py
ADDED
@@ -0,0 +1,39 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_guidance_images(
    prompt, seed, num_inference_steps, guidance_values, progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)
    latents = generate_latents(seed)

    progress(0)
    images = []
    guidance_values = guidance_values.replace(",", " ").split()
    num_images = len(guidance_values)

    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(
            latents,
            text_embeddings,
            num_inference_steps,
            guidance_scale=int(guidance_values[i]),
        )
        images.append((image, "{}".format(int(guidance_values[i]))))

    fname = "guidance"
    tab_config = {
        "Tab": "Guidance",
        "Prompt": prompt,
        "Guidance Scale Values": guidance_values,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, f"outputs/{fname}.zip"


__all__ = ["display_guidance_images"]
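The guidance scales arrive as a single text-box string; the replace/split above accepts commas, spaces, or both. A quick check of that parsing:

guidance_values = "1, 8, 20"  # the default from src/util/params.py
parsed = guidance_values.replace(",", " ").split()
print([int(v) for v in parsed])  # [1, 8, 20]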
src/pipelines/inpainting.py
ADDED
@@ -0,0 +1,41 @@
import torch
import gradio as gr
from src.util.base import *
from src.util.params import *
from diffusers import AutoPipelineForInpainting

# inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model_path).to(torch_device)
inpaint_pipe = AutoPipelineForInpainting.from_pipe(pipe).to(torch_device)


def inpaint(dict, num_inference_steps, seed, prompt="", progress=gr.Progress()):
    progress(0)
    mask = dict["mask"].convert("RGB").resize((imageHeight, imageWidth))
    init_image = dict["image"].convert("RGB").resize((imageHeight, imageWidth))
    output = inpaint_pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator().manual_seed(seed),
    )
    progress(1)

    fname = "inpainting"
    tab_config = {
        "Tab": "Inpainting",
        "Prompt": prompt,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }

    imgs_list = []
    imgs_list.append((output.images[0], "Inpainted Image"))
    imgs_list.append((mask, "Mask"))

    export_as_zip(imgs_list, fname, tab_config)
    return output.images[0], f"outputs/{fname}.zip"


__all__ = ["inpaint"]
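The `dict` argument mirrors the value a Gradio sketch-image component produces: a mapping with PIL "image" and "mask" entries, which `inpaint` resizes to imageHeight × imageWidth before the pipeline call. A sketch of that structure (path and rectangle purely hypothetical):

from PIL import Image, ImageDraw

base = Image.open("photo.png")       # picture to edit (hypothetical path)
mask = Image.new("L", base.size, 0)  # black = keep
ImageDraw.Draw(mask).rectangle([100, 100, 220, 220], fill=255)  # white = repaint

gradio_value = {"image": base, "mask": mask}  # what the sketch tool passes to inpaint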
src/pipelines/interpolate.py
ADDED
@@ -0,0 +1,51 @@
import torch
import gradio as gr
from src.util.base import *
from src.util.params import *


def interpolate_prompts(promptA, promptB, num_interpolation_steps):
    text_embeddingsA = get_text_embeddings(promptA)
    text_embeddingsB = get_text_embeddings(promptB)

    interpolated_embeddings = []

    for i in range(num_interpolation_steps):
        alpha = i / num_interpolation_steps
        interpolated_embedding = torch.lerp(text_embeddingsA, text_embeddingsB, alpha)
        interpolated_embeddings.append(interpolated_embedding)

    return interpolated_embeddings


def display_interpolate_images(
    seed, promptA, promptB, num_inference_steps, num_images, progress=gr.Progress()
):
    latents = generate_latents(seed)
    num_images = num_images + 2  # add 2 for first and last image
    text_embeddings = interpolate_prompts(promptA, promptB, num_images)
    images = []
    progress(0)

    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(latents, text_embeddings[i], num_inference_steps)
        images.append((image, "{}".format(i + 1)))

    progress(1, desc="Exporting as gif")
    export_as_gif(images, filename="interpolate.gif", reverse=True)

    fname = "interpolate"
    tab_config = {
        "Tab": "Interpolate",
        "First Prompt": promptA,
        "Second Prompt": promptB,
        "Number of Interpolation Steps": num_images,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, "outputs/interpolate.gif", f"outputs/{fname}.zip"


__all__ = ["display_interpolate_images"]
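torch.lerp blends the two prompt embeddings linearly, lerp(a, b, α) = a + α·(b − a). A minimal check of that identity, which also shows why the final frame stops just short of the second prompt (α = i/num_steps never reaches 1):

import torch

a, b = torch.randn(5), torch.randn(5)
alpha = 0.25
assert torch.allclose(torch.lerp(a, b, alpha), a + alpha * (b - a))

alphas = [i / 4 for i in range(4)]  # 0.0, 0.25, 0.5, 0.75 -- never 1.0
print(alphas)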
src/pipelines/negative.py
ADDED
@@ -0,0 +1,37 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_negative_images(
    prompt, seed, num_inference_steps, negative_prompt="", progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)
    text_embeddings_neg = get_text_embeddings(prompt, negative_prompt=negative_prompt)

    latents = generate_latents(seed)

    progress(0)
    images = generate_images(latents, text_embeddings, num_inference_steps)

    progress(0.5)
    images_neg = generate_images(latents, text_embeddings_neg, num_inference_steps)

    fname = "negative"
    tab_config = {
        "Tab": "Negative",
        "Prompt": prompt,
        "Negative Prompt": negative_prompt,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }

    imgs_list = []
    imgs_list.append((images, "Without Negative Prompt"))
    imgs_list.append((images_neg, "With Negative Prompt"))
    export_as_zip(imgs_list, fname, tab_config)

    return images, images_neg, f"outputs/{fname}.zip"


__all__ = ["display_negative_images"]
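In get_text_embeddings (src/util/base.py) the negative prompt simply replaces the empty string that is normally encoded as the unconditional half of the batch, so classifier-free guidance pushes the sample away from it. Schematically (a sketch; assumes the models in src/util/params.py are loaded):

from src.util.base import get_text_embeddings

emb_plain = get_text_embeddings("a cat")                          # uncond slot = ""
emb_neg = get_text_embeddings("a cat", negative_prompt="blurry")  # uncond slot = "blurry"
# Both stack [uncond, cond] along dim 0; generate_images steers toward cond
# and away from uncond, hence away from "blurry".
print(emb_plain.shape, emb_neg.shape)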
src/pipelines/perturbations.py
ADDED
@@ -0,0 +1,62 @@
import torch
import numpy as np
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_perturb_images(
    prompt,
    seed,
    num_inference_steps,
    num_images,
    perturbation_size,
    progress=gr.Progress(),
):
    text_embeddings = get_text_embeddings(prompt)

    latents_x = generate_latents(seed)
    scale_x = torch.cos(
        torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
    ).to(torch_device)
    noise_x = torch.tensordot(scale_x, latents_x, dims=0)

    progress(0)
    images = []
    images.append(
        (
            generate_images(latents_x, text_embeddings, num_inference_steps),
            "{}".format(1),
        )
    )

    for i in range(num_images):
        np.random.seed(i)
        progress(i / (num_images))
        latents_y = generate_latents(np.random.randint(0, 100000))
        scale_y = torch.sin(
            torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
        ).to(torch_device)
        noise_y = torch.tensordot(scale_y, latents_y, dims=0)

        noise = noise_x + noise_y
        image = generate_images(
            noise[num_images - 1], text_embeddings, num_inference_steps
        )
        images.append((image, "{}".format(i + 2)))

    fname = "perturbations"
    tab_config = {
        "Tab": "Perturbations",
        "Prompt": prompt,
        "Number of Perturbations": num_images,
        "Perturbation Size": perturbation_size,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)

    return images, f"outputs/{fname}.zip"


__all__ = ["display_perturb_images"]
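Since cos²θ + sin²θ = 1, the mix cos(θ)·x + sin(θ)·y keeps roughly unit variance while tilting the base latent x toward a fresh random direction y; perturbation_size sets the angle. A quick statistical check:

import torch

theta = torch.tensor(2 * torch.pi * 0.1 / 4)  # final-step angle for size 0.1
x, y = torch.randn(100_000), torch.randn(100_000)
mixed = torch.cos(theta) * x + torch.sin(theta) * y
print(x.var().item(), mixed.var().item())  # both close to 1.0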
src/pipelines/poke.py
ADDED
@@ -0,0 +1,83 @@
import os
import gradio as gr
from src.util.base import *
from src.util.params import *
from PIL import Image, ImageDraw


def visualize_poke(
    pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth
):
    if (
        (pokeX - pokeWidth // 2 < 0)
        or (pokeX + pokeWidth // 2 > imageWidth // 8)
        or (pokeY - pokeHeight // 2 < 0)
        or (pokeY + pokeHeight // 2 > imageHeight // 8)
    ):
        gr.Warning("Modification outside image")
    shape = [
        (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
        (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
    ]

    blank = Image.new("RGB", (imageWidth, imageHeight))

    if os.path.exists("outputs/original.png"):
        oImg = Image.open("outputs/original.png")
        pImg = Image.open("outputs/poked.png")
    else:
        oImg = blank
        pImg = blank

    oRec = ImageDraw.Draw(oImg)
    pRec = ImageDraw.Draw(pImg)

    oRec.rectangle(shape, outline="white")
    pRec.rectangle(shape, outline="white")

    return oImg, pImg


def display_poke_images(
    prompt,
    seed,
    num_inference_steps,
    poke=False,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    intermediate=False,
    progress=gr.Progress(),
):
    text_embeddings = get_text_embeddings(prompt)
    latents, modified_latents = generate_modified_latents(
        poke, seed, pokeX, pokeY, pokeHeight, pokeWidth
    )

    progress(0)
    images = generate_images(
        latents, text_embeddings, num_inference_steps, intermediate=intermediate
    )

    if not intermediate:
        images.save("outputs/original.png")

    if poke:
        progress(0.5)
        modImages = generate_images(
            modified_latents,
            text_embeddings,
            num_inference_steps,
            intermediate=intermediate,
        )

        if not intermediate:
            modImages.save("outputs/poked.png")
    else:
        modImages = None

    return images, modImages


__all__ = ["display_poke_images", "visualize_poke"]
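pokeX/pokeY/pokeHeight/pokeWidth are latent-grid coordinates; the VAE downsamples by a factor of 8, so visualize_poke multiplies everything by 8 to draw on the full-size image and bounds-checks against imageWidth // 8. The arithmetic, with hypothetical values:

imageWidth = 512                    # -> 64-column latent grid (512 // 8)
pokeX, pokeWidth = 32, 16           # hypothetical latent center and width
latent_span = (pokeX - pokeWidth // 2, pokeX + pokeWidth // 2)  # (24, 40)
pixel_span = tuple(8 * v for v in latent_span)                  # (192, 320)
print(latent_span, pixel_span)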
src/pipelines/seed.py
ADDED
@@ -0,0 +1,32 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_seed_images(
    prompt, num_inference_steps, num_images, progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)

    images = []
    progress(0)

    for i in range(num_images):
        progress(i / num_images)
        latents = generate_latents(i)
        image = generate_images(latents, text_embeddings, num_inference_steps)
        images.append((image, "{}".format(i + 1)))

    fname = "seeds"
    tab_config = {
        "Tab": "Seeds",
        "Prompt": prompt,
        "Number of Seeds": num_images,
        "Number of Inference Steps per Image": num_inference_steps,
    }
    export_as_zip(images, fname, tab_config)

    return images, f"outputs/{fname}.zip"


__all__ = ["display_seed_images"]
src/util/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base import *
from .params import *
from .clip_config import *
src/util/base.py
ADDED
@@ -0,0 +1,304 @@
import io
import os
import torch
import zipfile
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt


def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings


def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
    generator = torch.Generator().manual_seed(int(seed))

    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)

    return latents


def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if poke:
        np.random.seed(seed)
        poke_latents = generate_latents(
            np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
        )

        x_origin = pokeX - pokeWidth // 2
        y_origin = pokeY - pokeHeight // 2

        modified_latents = original_latents.clone()
        modified_latents[
            :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
        ] = poke_latents
    else:
        modified_latents = None

    return original_latents, modified_latents


def convert_to_pil_image(image):
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]


def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    images = []
    i = 1

    for t in tqdm(scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        if intermediate:
            progress(((1000 - t) / 1000))
            Latents = 1 / 0.18215 * latents
            with torch.no_grad():
                image = vae.decode(Latents).sample
            images.append((convert_to_pil_image(image), "{}".format(i)))

        latents = scheduler.step(noise_pred, t, latents).prev_sample
        i += 1

    if not intermediate:
        Latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(Latents).sample
        images = convert_to_pil_image(image)

    return images


def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)

    text_embeddings = text_embeddings.cpu().numpy()
    return text_embeddings / np.linalg.norm(text_embeddings)


def get_concat_embeddings(names, merge=False):
    embeddings = []

    for name in names:
        embedding = get_word_embeddings(name)
        embeddings.append(embedding)

    embeddings = np.vstack(embeddings)

    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)

    return embeddings


def get_axis_embeddings(A, B):
    emb = []

    for a, b in zip(A, B):
        e = get_word_embeddings(a) - get_word_embeddings(b)
        emb.append(e)

    emb = np.vstack(emb)
    ax = np.average(emb, axis=0).reshape(1, -1)

    return ax


def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T

    partial_residual = fembeddings - (fprojections.reshape(-1, 1) * fembeddings)
    residual = partial_residual - (sprojections.reshape(-1, 1) * sembeddings)

    return residual


def calculate_step_size(num_images, differentiation):
    return differentiation / (num_images - 1)


def generate_seed_vis(seed):
    np.random.seed(seed)
    emb = np.random.rand(15)
    plt.close()
    plt.switch_backend("agg")
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt


def export_as_gif(images, filename, frames_per_second=2, reverse=False):
    imgs = [img[0] for img in images]

    if reverse:
        imgs += imgs[2:-1][::-1]

    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


def export_as_zip(images, fname, tab_config=None):
    if not os.path.exists(f"outputs/{fname}.zip"):
        os.makedirs("outputs", exist_ok=True)

    with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
        if tab_config:
            with open("outputs/config.txt", "w") as f:
                for key, value in tab_config.items():
                    f.write(f"{key}: {value}\n")
                f.close()

            img_zip.write("outputs/config.txt", "config.txt")

        for idx, img in enumerate(images):
            buff = io.BytesIO()
            img[0].save(buff, format="PNG")
            buff = buff.getvalue()
            max_num = len(images)
            num_leading_zeros = len(str(max_num))
            img_name = f"{{:0{num_leading_zeros}}}.png"
            img_zip.writestr(img_name.format(idx + 1), buff)


def read_html(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return content


__all__ = [
    "get_text_embeddings",
    "generate_latents",
    "generate_modified_latents",
    "generate_images",
    "get_word_embeddings",
    "get_concat_embeddings",
    "get_axis_embeddings",
    "calculate_residual",
    "calculate_step_size",
    "generate_seed_vis",
    "export_as_gif",
    "export_as_zip",
    "read_html",
]
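generate_images doubles the batch so one UNet pass yields both the unconditional and text-conditioned noise predictions, then applies classifier-free guidance: pred = uncond + s·(text − uncond). A toy check that the formula behaves as expected at the boundary values:

import torch

uncond, text = torch.randn(4), torch.randn(4)
# s = 1 recovers the pure text-conditioned prediction;
# s = 0 would ignore the prompt entirely.
assert torch.allclose(uncond + 1.0 * (text - uncond), text)
assert torch.allclose(uncond + 0.0 * (text - uncond), uncond)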
src/util/clip_config.py
ADDED
@@ -0,0 +1,114 @@
masculine = [
    "man",
    "king",
    "prince",
    "husband",
    "father",
]

feminine = [
    "woman",
    "queen",
    "princess",
    "wife",
    "mother",
]

young = [
    "man",
    "woman",
    "king",
    "queen",
    "father",
]

old = [
    "boy",
    "girl",
    "prince",
    "princess",
    "son",
]

common = [
    "man",
    "woman",
    "boy",
    "girl",
    "woman",
]

elite = [
    "king",
    "queen",
    "prince",
    "princess",
    "duchess",
]

singular = [
    "boy",
    "girl",
    "cat",
    "puppy",
    "computer",
]

plural = [
    "boys",
    "girls",
    "cats",
    "puppies",
    "computers",
]

examples = [
    "king",
    "queen",
    "man",
    "woman",
    "boys",
    "girls",
    "apple",
    "orange",
]

axis_names = ["gender", "residual", "age"]

axis_combinations = {
    "age": young + old,
    "gender": masculine + feminine,
    "royalty": common + elite,
    "number": singular + plural,
}

axisMap = {
    "X - Axis": 0,
    "Y - Axis": 1,
    "Z - Axis": 2,
}

whichAxisMap = {
    "which_axis_1": "X - Axis",
    "which_axis_2": "Z - Axis",
    "which_axis_3": "Y - Axis",
    "which_axis_4": "---",
    "which_axis_5": "---",
    "which_axis_6": "---",
}

__all__ = [
    "axisMap",
    "whichAxisMap",
    "axis_names",
    "axis_combinations",
    "examples",
    "masculine",
    "feminine",
    "young",
    "old",
    "common",
    "elite",
    "singular",
    "plural",
]
src/util/params.py
ADDED
@@ -0,0 +1,96 @@
import torch
import secrets
from gradio.networking import setup_tunnel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    LCMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
)

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

isLCM = False
HF_ACCESS_TOKEN = ""

model_path = "segmind/small-sd"
inpaint_model_path = "Lykon/dreamshaper-8-inpainting"
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
promptA = "Self-portrait oil painting, a beautiful man with golden hair, 8k"
promptB = "Self-portrait oil painting, a beautiful woman with golden hair, 8k"
negative_prompt = "a photo frame"

num_images = 5
degree = 360
perturbation_size = 0.1
num_inference_steps = 8
seed = 69420

guidance_scale = 8
guidance_values = "1, 8, 20"
intermediate = True
pokeX, pokeY = 256, 256
pokeHeight, pokeWidth = 128, 128
imageHeight, imageWidth = 512, 512

tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(
    torch_device
)

if isLCM:
    scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
else:
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")

unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(
    torch_device
)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(torch_device)

pipe = StableDiffusionPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    scheduler=scheduler,
    vae=vae,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(torch_device)

dash_tunnel = setup_tunnel("0.0.0.0", 8000, secrets.token_urlsafe(32))

__all__ = [
    "prompt",
    "negative_prompt",
    "num_images",
    "degree",
    "perturbation_size",
    "num_inference_steps",
    "seed",
    "intermediate",
    "pokeX",
    "pokeY",
    "pokeHeight",
    "pokeWidth",
    "promptA",
    "promptB",
    "tokenizer",
    "text_encoder",
    "scheduler",
    "unet",
    "vae",
    "torch_device",
    "imageHeight",
    "imageWidth",
    "guidance_scale",
    "guidance_values",
    "HF_ACCESS_TOKEN",
    "model_path",
    "inpaint_model_path",
    "dash_tunnel",
    "pipe",
]
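Every tab reads its models and defaults from this module, so trying a different checkpoint is a one-line change, provided the repo ships the standard diffusers subfolders (tokenizer, text_encoder, unet, vae, scheduler). A sketch, with an illustrative checkpoint name not endorsed by this repo:

model_path = "runwayml/stable-diffusion-v1-5"  # hypothetical swap-in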