akameswa committed on
Commit 04ef268 · verified · 1 Parent(s): f9d9229

Upload 37 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/circular.gif filter=lfs diff=lfs merge=lfs -text
37
+ images/interpolate.gif filter=lfs diff=lfs merge=lfs -text
html/circular.html ADDED
@@ -0,0 +1,32 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ This tab generates a circular trajectory through latent space that begins and ends with the same image.
9
+ If we specify a large number of steps around the circle, the successive images will be closely related, resulting in a gradual deformation that produces a nice animation.
10
+ </p>
11
+ <p style="font-weight: bold;">
12
+ Additional Controls:
13
+ </p>
14
+ <p style="font-weight: bold;">
15
+ Number of Steps around the Circle:
16
+ </p>
17
+ <p>
18
+ Specify the number of images to produce along the circular path.
19
+ </p>
20
+ <p style="font-weight: bold;">
21
+ Proportion of Circle:
22
+ </p>
23
+ <p>
24
+ Sets the proportion of the circle to cover during image generation.
25
+ Ranges from 0 to 360 degrees.
26
+ Using a high step count with a small number of degrees allows you to explore very subtle image transformations.
27
+ </p>
28
+ </div>
29
+ <div style="flex: 1; align-content: center;">
30
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/circular.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
31
+ </div>
32
+ </div>
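A minimal sketch of the blend behind this tab (the full implementation is in src/pipelines/circular.py later in this commit); `latents_x` and `latents_y` stand for two independently seeded noise tensors, and the function name is illustrative only:

```python
import torch

# Two independent noise tensors span a plane in latent space; sweeping an angle
# around (a proportion of) the circle blends them with cos/sin weights, so a full
# 360-degree sweep begins and ends at the same point, hence the same image.
def circular_noise(latents_x, latents_y, num_steps, degrees=360.0):
    thetas = torch.linspace(0, 2 * torch.pi * degrees / 360.0, num_steps)
    return [torch.cos(t) * latents_x + torch.sin(t) * latents_y for t in thetas]
```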
html/denoising.html ADDED
@@ -0,0 +1,16 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ This tab displays the intermediate images generated during the denoising process.
9
+ Seeing these intermediate images provides insight into how the diffusion model progressively adds detail at each step.
10
+ </p>
11
+ </div>
12
+ <div style="flex: 1; align-content: center;">
13
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/denoising.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
14
+ </div>
15
+ </div>
16
+ </details>
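The demo's own helpers for capturing intermediates live in src/util (not shown in this commit). As a rough sketch of the underlying idea only, assuming a standard Stable Diffusion v1.x pipeline `pipe`, each intermediate latent can be decoded to an image like this:

```python
import torch
import numpy as np
from PIL import Image

def latents_to_pil(pipe, latents):
    # 0.18215 is the latent scaling factor used by Stable Diffusion v1.x VAEs.
    with torch.no_grad():
        decoded = pipe.vae.decode(latents / 0.18215).sample
    img = (decoded / 2 + 0.5).clamp(0, 1)                  # map [-1, 1] to [0, 1]
    img = (img[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
    return Image.fromarray(img)
```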
html/embeddings.html ADDED
@@ -0,0 +1,75 @@
1
+ <head>
2
+ <link rel="stylesheet" type="text/css" href="styles.css">
3
+ </head>
4
+
5
+ <details open>
6
+ <summary style="background-color: #CE6400; padding-left: 10px;">
7
+ About
8
+ </summary>
9
+ <div style="background-color: #D87F2B; padding-left: 10px;">
10
+ <p style="font-weight: bold;">
11
+ Basic Exploration
12
+ </p>
13
+ The top part of the Embeddings tab shows a 3D plot of the semantic feature space.
14
+ At the bottom of the tab, expandable panels can be opened to reveal more advanced features.
15
+
16
+ <ul>
17
+ <li>
18
+ <strong>
19
+ Explore the 3D semantic feature space:
20
+ </strong>
21
+ Click and drag in the 3D semantic feature space to rotate the view.
22
+ Use the scroll wheel to zoom in and out.
23
+ Hold down the control key and click and drag to pan the view.
24
+ </li>
25
+ <li>
26
+ <strong>
27
+ Find the generated image:
28
+ </strong>
29
+ Hover over a point in the semantic feature space, and a window will pop up showing a generated image from this one-word prompt.
30
+ Left-clicking the point downloads the image.
31
+ </li>
32
+ <li>
33
+ <strong>
34
+ Find the embedding vector display:
35
+ </strong>
36
+ Hover over a word in the 3D semantic feature space, and an embedding vector display at the bottom of the tab shows the corresponding embedding vector.
37
+ </li>
38
+ <li>
39
+ <strong>
40
+ Add/remove words from the 3D plot:
41
+ </strong>
42
+ Type a word in the Add/Remove word text box below the 3D plot to add a word to the plot, or if the word is already present, remove it from the plot.
43
+ You can also type multiple words separated by spaces or commas.
44
+ </li>
45
+ <li>
46
+ <strong>
47
+ Change image for word in the 3D plot:
48
+ </strong>
49
+ Type a word in the Change image for word text box below the 3D plot to generate a new image for the corresponding word in the plot.
50
+ </li>
51
+ </ul>
52
+
53
+ <p style="font-weight: bold; margin-top: 10px;">
54
+ Semantic Dimensions
55
+ </p>
56
+ <ul>
57
+ <li>
58
+ <strong>Select a different semantic dimension.</strong><br>
59
+ Open the Custom Semantic Dimensions panel and choose another dimension for the X, Y, or Z axis.
60
+ See how the display changes.
61
+ </li>
62
+ <li>
63
+ <strong>Alter a semantic dimension.</strong><br>
64
+ Examine the positive and negative word pairs used to define the semantic dimension.
65
+ You can change these pairs to alter the semantic dimension.
66
+ </li>
67
+ <li>
68
+ <strong>Define a new semantic dimension.</strong><br>
69
+ Pick a new semantic dimension that you can define using pairs of opposed words.
70
+ For example, you could define a "tense" dimension with pairs such as eat/ate, go/went, see/saw, and is/was to contrast present and past tense forms of verbs.
71
+ </li>
72
+ </ul>
73
+ </div>
74
+ </details>
75
+
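The sketch below is an assumption about how get_axis_embeddings in src/pipelines/embeddings.py (later in this commit) turns the positive/negative word lists into a semantic dimension; the projection step matches the `coords = get_concat_embeddings(examples) @ axis.T` call in that file. `embed` is a placeholder for the CLIP text-encoder lookup used by the demo:

```python
import numpy as np

def make_axis(embed, positive_words, negative_words):
    # One plausible construction: the axis points from the mean embedding of the
    # negative words toward the mean embedding of the positive words.
    pos = np.mean([embed(w) for w in positive_words], axis=0)
    neg = np.mean([embed(w) for w in negative_words], axis=0)
    return pos - neg

# Each word is then plotted at its projection onto the three chosen axes:
#   coords = word_embeddings @ np.vstack([x_axis, y_axis, z_axis]).T
```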
html/guidance.html ADDED
@@ -0,0 +1,17 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ Guidance controls how strongly the target image adheres to the prompt.
9
+ A higher guidance scale enforces adherence more strictly, whereas a lower value relaxes it.
10
+ For example, a guidance scale of 1 produces a distorted grayscale image, whereas 50 produces a distorted, oversaturated image.
11
+ The default value of 8 produces normal-looking images that reasonably adhere to the prompt.
12
+ </p>
13
+ </div>
14
+ <div style="flex: 1; align-content: center;">
15
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/guidance.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
16
+ </div>
17
+ </div>
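Numerically, the guidance scale is the weight in the standard classifier-free guidance step applied at every denoising iteration (this is the textbook formulation, not code copied from this repository):

```python
def apply_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    # Scale 1 reduces to the plain prompt-conditioned prediction; larger scales
    # amplify the difference between the prompt-conditioned and unconditional predictions.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```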
html/inpainting.html ADDED
@@ -0,0 +1,14 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ Unlike Poke, which globally alters the target image via a perturbation in the initial latent noise, inpainting alters only the masked region and lets us specify the change we want to make.
9
+ </p>
10
+ </div>
11
+ <div style="flex: 1; align-content: center;">
12
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/inpainting.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
13
+ </div>
14
+ </div>
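This mirrors the call made in src/pipelines/inpainting.py later in this commit; `pipe`, `init_image`, and `mask` are assumed to be the already-loaded Stable Diffusion pipeline, the uploaded image, and the sketched mask:

```python
import torch
from diffusers import AutoPipelineForInpainting

inpaint_pipe = AutoPipelineForInpainting.from_pipe(pipe)   # reuse the loaded SD weights
result = inpaint_pipe(
    prompt="sunglasses",
    image=init_image,          # image to edit
    mask_image=mask,           # white where the change should be applied
    num_inference_steps=8,
    generator=torch.Generator().manual_seed(14),
).images[0]
```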
html/interpolate.html ADDED
@@ -0,0 +1,24 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ This tab generates noise patterns for two text prompts and then interpolates between them, gradually transforming from the first to the second.
9
+ With a large number of interpolation steps, the transformation is very gradual and makes a nice animation.
10
+ </p>
11
+ <p style="font-weight: bold;">
12
+ Additional Controls:
13
+ </p>
14
+ <p style="font-weight: bold;">
15
+ Number of Interpolation Steps:
16
+ </p>
17
+ <p>
18
+ Defines the number of intermediate images to generate between the two prompts.
19
+ </p>
20
+ </div>
21
+ <div style="flex: 1; align-content: center;">
22
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/interpolate.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
23
+ </div>
24
+ </div>
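A minimal sketch of the interpolation (cf. src/pipelines/interpolate.py later in this commit): the latent noise from the seed is held fixed while the two prompts' text embeddings are blended linearly, so only the conditioning changes along the path. This version includes both endpoints; the demo's helper differs slightly in how it counts steps:

```python
import torch

def interpolate_embeddings(emb_a, emb_b, num_steps):
    # Linearly blend the two prompt embeddings, including both endpoints (num_steps >= 2).
    return [torch.lerp(emb_a, emb_b, i / (num_steps - 1)) for i in range(num_steps)]
```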
html/negative.html ADDED
@@ -0,0 +1,15 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ Negative prompts steer images away from unwanted features.
9
+ For example, “red” as a negative prompt makes the generated image unlikely to have reddish hues.
10
+ </p>
11
+ </div>
12
+ <div style="flex: 1; align-content: center;">
13
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/negative.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
14
+ </div>
15
+ </div>
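The demo builds the embeddings itself (see the get_text_embeddings call in src/pipelines/negative.py later in this commit), but the effect is equivalent in spirit to passing `negative_prompt` to a standard diffusers pipeline; `pipe` is assumed to be a loaded StableDiffusionPipeline:

```python
image = pipe(
    prompt="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    negative_prompt="Yellow",        # steer the result away from yellowish hues
    num_inference_steps=8,
).images[0]
```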
html/perturbations.html ADDED
@@ -0,0 +1,35 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ The Perturbations tab lets you explore the latent space around a seed.
9
+ Perturbing the noise from an initial seed towards the noise from a different seed illustrates the variations in images obtainable from a local region of latent space.
10
+ Using a small perturbation size produces target images that closely resemble the one from the initial seed.
11
+ Larger perturbations traverse more distance in latent space towards the second seed, resulting in greater variation in the generated images.
12
+ </p>
13
+ <p style="font-weight: bold;">
14
+ Additional Controls:
15
+ </p>
16
+ <p style="font-weight: bold;">
17
+ Number of Perturbations:
18
+ </p>
19
+ <p>
20
+ Specify the number of perturbations to create, i.e., the number of seeds to use. More perturbations produce more images.
21
+ </p>
22
+ <p style="font-weight: bold;">
23
+ Perturbation Size:
24
+ </p>
25
+ <p>
26
+ Controls the perturbation magnitude, ranging from 0 to 1.
27
+ With a value of 0, all images will match the one from the initial seed.
28
+ With a value of 1, images will have no connection to the initial seed.
29
+ A value such as 0.1 is recommended.
30
+ </p>
31
+ </div>
32
+ <div style="flex: 1; align-content: center;">
33
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/perturbations.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
34
+ </div>
35
+ </div>
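src/pipelines/perturbations.py is truncated in this view, so the following is only one plausible reading of the control described above, not the repository's exact code: blend the base noise with noise from a second seed, weighted by the perturbation size.

```python
import torch

def perturb_latents(base_latents, other_latents, size):
    # size = 0 reproduces the initial seed's noise; size = 1 uses the other seed's noise.
    return (1.0 - size) * base_latents + size * other_latents
```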
html/poke.html ADDED
@@ -0,0 +1,21 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ Poke explores how perturbations in a local region of the initial latent noise impact the target image.
9
+ A small perturbation to the initial latent noise gets carried through the denoising process, demonstrating the global effect it can produce.
10
+ </p>
11
+ <p style="font-weight: bold;">
12
+ Additional Controls:
13
+ </p>
14
+ <p>
15
+ You can adjust the perturbation through the X, Y, height, and width controls.
16
+ </p>
17
+ </div>
18
+ <div style="flex: 1; align-content: center;">
19
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/poke.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
20
+ </div>
21
+ </div>
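visualize_poke and display_poke_images live in src/util, which is not shown in this commit view; as a hedged illustration only, a "poke" can be thought of as overwriting a small rectangle of the 64x64 initial latent noise before denoising, with the rectangle given by the pokeX/pokeY/pokeHeight/pokeWidth sliders in run.py:

```python
import torch

def poke_latents(latents, poke_x, poke_y, poke_h, poke_w, seed=0):
    # latents has shape (1, 4, 64, 64) for a 512x512 Stable Diffusion v1.x image.
    poked = latents.clone()
    y0, y1 = poke_y - poke_h // 2, poke_y + poke_h // 2
    x0, x1 = poke_x - poke_w // 2, poke_x + poke_w // 2
    generator = torch.Generator().manual_seed(seed)
    poked[:, :, y0:y1, x0:x1] = torch.randn(
        poked[:, :, y0:y1, x0:x1].shape, generator=generator
    )
    return poked
```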
html/seeds.html ADDED
@@ -0,0 +1,25 @@
1
+ <details open>
2
+ <summary style="background-color: #CE6400; padding-left: 10px;">
3
+ About
4
+ </summary>
5
+ <div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
6
+ <div style="flex: 1;">
7
+ <p style="margin-top: 10px">
8
+ Seeds create the initial noise that gets refined into the target image.
9
+ Different seeds produce different noise patterns, so the target image differs even when prompted with the same text.
10
+ This tab produces multiple target images from the same text prompt to showcase how changing the seed changes the target image.
11
+ </p>
12
+ <p style="font-weight: bold;">
13
+ Additional Controls:
14
+ </p>
15
+ <p style="font-weight: bold;">
16
+ Number of Seeds:
17
+ </p>
18
+ <p>
19
+ Specify how many seed values to use.
20
+ </p>
21
+ </div>
22
+ <div style="flex: 1; align-content: center;">
23
+ <img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/seeds.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
24
+ </div>
25
+ </div>
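What a seed does here, as a small self-contained sketch (the latent shape assumes 512x512 Stable Diffusion v1.x): the seed fixes the starting Gaussian noise, and everything downstream is deterministic given the same prompt and settings.

```python
import torch

def latents_from_seed(seed, shape=(1, 4, 64, 64)):
    generator = torch.Generator().manual_seed(seed)
    return torch.randn(shape, generator=generator)

# Same prompt, different seeds -> different starting noise -> different target images.
```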
images/circular.gif ADDED

Git LFS Details

  • SHA256: c6f6c81e8b01b3c68ea55d06bbb7769a3748f3e60ba1fdbcb11b18bb8bee38ec
  • Pointer size: 133 Bytes
  • Size of remote file: 18.6 MB
images/circular.png ADDED
images/denoising.png ADDED
images/guidance.png ADDED
images/inpainting.png ADDED
images/interpolate.gif ADDED

Git LFS Details

  • SHA256: 01e94a4b4272c7ef705957e59e6640afcad0f2ba675fc9051c87317c24785e2f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.06 MB
images/interpolate.png ADDED
images/negative.png ADDED
images/perturbations.png ADDED
images/poke.png ADDED
images/seeds.png ADDED
run.py ADDED
@@ -0,0 +1,1029 @@
1
+ import base64
2
+ import gradio as gr
3
+ from PIL import Image
4
+ from src.util import *
5
+ from io import BytesIO
6
+ from src.pipelines import *
7
+ from threading import Thread
8
+ from dash import Dash, dcc, html, Input, Output, no_update, callback
9
+
10
+ app = Dash(__name__)
11
+
12
+ app.layout = html.Div(
13
+ className="container",
14
+ children=[
15
+ dcc.Graph(
16
+ id="graph", figure=fig, clear_on_unhover=True, style={"height": "90vh"}
17
+ ),
18
+ dcc.Tooltip(id="tooltip"),
19
+ html.Div(id="word-emb-txt", style={"background-color": "white"}),
20
+ html.Div(id="word-emb-vis"),
21
+ html.Div(
22
+ [
23
+ html.Button(id="btn-download-image", hidden=True),
24
+ dcc.Download(id="download-image"),
25
+ ]
26
+ ),
27
+ ],
28
+ )
29
+
30
+
31
+ @callback(
32
+ Output("tooltip", "show"),
33
+ Output("tooltip", "bbox"),
34
+ Output("tooltip", "children"),
35
+ Output("tooltip", "direction"),
36
+ Output("word-emb-txt", "children"),
37
+ Output("word-emb-vis", "children"),
38
+ Input("graph", "hoverData"),
39
+ )
40
+ def display_hover(hoverData):
41
+ if hoverData is None:
42
+ return False, no_update, no_update, no_update, no_update, no_update
43
+
44
+ hover_data = hoverData["points"][0]
45
+ bbox = hover_data["bbox"]
46
+ direction = "left"
47
+ index = hover_data["pointNumber"]
48
+
49
+ children = [
50
+ html.Img(
51
+ src=images[index],
52
+ style={"width": "250px"},
53
+ ),
54
+ html.P(
55
+ hover_data["text"],
56
+ style={
57
+ "color": "black",
58
+ "font-size": "20px",
59
+ "text-align": "center",
60
+ "background-color": "white",
61
+ "margin": "5px",
62
+ },
63
+ ),
64
+ ]
65
+
66
+ emb_children = [
67
+ html.Img(
68
+ src=generate_word_emb_vis(hover_data["text"]),
69
+ style={"width": "100%", "height": "25px"},
70
+ ),
71
+ ]
72
+
73
+ return True, bbox, children, direction, hover_data["text"], emb_children
74
+
75
+
76
+ @callback(
77
+ Output("download-image", "data"),
78
+ Input("graph", "clickData"),
79
+ )
80
+ def download_image(clickData):
81
+
82
+ if clickData is None:
83
+ return no_update
84
+
85
+ click_data = clickData["points"][0]
86
+ index = click_data["pointNumber"]
87
+ txt = click_data["text"]
88
+
89
+ img_encoded = images[index]
90
+ img_decoded = base64.b64decode(img_encoded.split(",")[1])
91
+ img = Image.open(BytesIO(img_decoded))
92
+ img.save(f"{txt}.png")
93
+ return dcc.send_file(f"{txt}.png")
94
+
95
+
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("## Stable Diffusion Demo")
98
+
99
+ with gr.Tab("Latent Space"):
100
+
101
+ with gr.TabItem("Denoising"):
102
+ gr.Markdown("Observe the intermediate images during denoising.")
103
+ gr.HTML(read_html("DiffusionDemo/html/denoising.html"))
104
+
105
+ with gr.Row():
106
+ with gr.Column():
107
+ prompt_denoise = gr.Textbox(
108
+ lines=1,
109
+ label="Prompt",
110
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
111
+ )
112
+ num_inference_steps_denoise = gr.Slider(
113
+ minimum=2,
114
+ maximum=100,
115
+ step=1,
116
+ value=8,
117
+ label="Number of Inference Steps",
118
+ )
119
+
120
+ with gr.Row():
121
+ seed_denoise = gr.Slider(
122
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
123
+ )
124
+ seed_vis_denoise = gr.Plot(
125
+ value=generate_seed_vis(14), label="Seed"
126
+ )
127
+
128
+ generate_images_button_denoise = gr.Button("Generate Images")
129
+
130
+ with gr.Column():
131
+ images_output_denoise = gr.Gallery(label="Images", selected_index=0)
132
+ gif_denoise = gr.Image(label="GIF")
133
+ zip_output_denoise = gr.File(label="Download ZIP")
134
+
135
+ @generate_images_button_denoise.click(
136
+ inputs=[prompt_denoise, seed_denoise, num_inference_steps_denoise],
137
+ outputs=[images_output_denoise, gif_denoise, zip_output_denoise],
138
+ )
139
+ def generate_images_wrapper(
140
+ prompt, seed, num_inference_steps, progress=gr.Progress()
141
+ ):
142
+ images, _ = display_poke_images(
143
+ prompt, seed, num_inference_steps, poke=False, intermediate=True
144
+ )
145
+ fname = "denoising"
146
+ tab_config = {
147
+ "Tab": "Denoising",
148
+ "Prompt": prompt,
149
+ "Number of Inference Steps": num_inference_steps,
150
+ "Seed": seed,
151
+ }
152
+ export_as_zip(images, fname, tab_config)
153
+ progress(1, desc="Exporting as gif")
154
+ export_as_gif(images, filename="denoising.gif")
155
+ return images, "outputs/denoising.gif", f"outputs/{fname}.zip"
156
+
157
+ seed_denoise.change(
158
+ fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
159
+ )
160
+
161
+ with gr.TabItem("Seeds"):
162
+ gr.Markdown(
163
+ "Understand how different starting points in latent space can lead to different images."
164
+ )
165
+ gr.HTML(read_html("DiffusionDemo/html/seeds.html"))
166
+
167
+ with gr.Row():
168
+ with gr.Column():
169
+ prompt_seed = gr.Textbox(
170
+ lines=1,
171
+ label="Prompt",
172
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
173
+ )
174
+ num_images_seed = gr.Slider(
175
+ minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
176
+ )
177
+ num_inference_steps_seed = gr.Slider(
178
+ minimum=2,
179
+ maximum=100,
180
+ step=1,
181
+ value=8,
182
+ label="Number of Inference Steps per Image",
183
+ )
184
+ generate_images_button_seed = gr.Button("Generate Images")
185
+
186
+ with gr.Column():
187
+ images_output_seed = gr.Gallery(label="Images", selected_index=0)
188
+ zip_output_seed = gr.File(label="Download ZIP")
189
+
190
+ generate_images_button_seed.click(
191
+ fn=display_seed_images,
192
+ inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
193
+ outputs=[images_output_seed, zip_output_seed],
194
+ )
195
+
196
+ with gr.TabItem("Perturbations"):
197
+ gr.Markdown("Explore different perturbations from a point in latent space.")
198
+ gr.HTML(read_html("DiffusionDemo/html/perturbations.html"))
199
+
200
+ with gr.Row():
201
+ with gr.Column():
202
+ prompt_perturb = gr.Textbox(
203
+ lines=1,
204
+ label="Prompt",
205
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
206
+ )
207
+ num_images_perturb = gr.Slider(
208
+ minimum=0,
209
+ maximum=100,
210
+ step=1,
211
+ value=5,
212
+ label="Number of Perturbations",
213
+ )
214
+ perturbation_size_perturb = gr.Slider(
215
+ minimum=0,
216
+ maximum=1,
217
+ step=0.1,
218
+ value=0.1,
219
+ label="Perturbation Size",
220
+ )
221
+ num_inference_steps_perturb = gr.Slider(
222
+ minimum=2,
223
+ maximum=100,
224
+ step=1,
225
+ value=8,
226
+ label="Number of Inference Steps per Image",
227
+ )
228
+
229
+ with gr.Row():
230
+ seed_perturb = gr.Slider(
231
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
232
+ )
233
+ seed_vis_perturb = gr.Plot(
234
+ value=generate_seed_vis(14), label="Seed"
235
+ )
236
+
237
+ generate_images_button_perturb = gr.Button("Generate Images")
238
+
239
+ with gr.Column():
240
+ images_output_perturb = gr.Gallery(label="Image", selected_index=0)
241
+ zip_output_perturb = gr.File(label="Download ZIP")
242
+
243
+ generate_images_button_perturb.click(
244
+ fn=display_perturb_images,
245
+ inputs=[
246
+ prompt_perturb,
247
+ seed_perturb,
248
+ num_inference_steps_perturb,
249
+ num_images_perturb,
250
+ perturbation_size_perturb,
251
+ ],
252
+ outputs=[images_output_perturb, zip_output_perturb],
253
+ )
254
+ seed_perturb.change(
255
+ fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
256
+ )
257
+
258
+ with gr.TabItem("Circular"):
259
+ gr.Markdown(
260
+ "Generate a circular path in latent space and observe how the images vary along the path."
261
+ )
262
+ gr.HTML(read_html("DiffusionDemo/html/circular.html"))
263
+
264
+ with gr.Row():
265
+ with gr.Column():
266
+ prompt_circular = gr.Textbox(
267
+ lines=1,
268
+ label="Prompt",
269
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
270
+ )
271
+ num_images_circular = gr.Slider(
272
+ minimum=2,
273
+ maximum=100,
274
+ step=1,
275
+ value=5,
276
+ label="Number of Steps around the Circle",
277
+ )
278
+
279
+ with gr.Row():
280
+ degree_circular = gr.Slider(
281
+ minimum=0,
282
+ maximum=360,
283
+ step=1,
284
+ value=360,
285
+ label="Proportion of Circle",
286
+ info="Enter the value in degrees",
287
+ )
288
+ step_size_circular = gr.Textbox(
289
+ label="Step Size", value=360 / 5
290
+ )
291
+
292
+ num_inference_steps_circular = gr.Slider(
293
+ minimum=2,
294
+ maximum=100,
295
+ step=1,
296
+ value=8,
297
+ label="Number of Inference Steps per Image",
298
+ )
299
+
300
+ with gr.Row():
301
+ seed_circular = gr.Slider(
302
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
303
+ )
304
+ seed_vis_circular = gr.Plot(
305
+ value=generate_seed_vis(14), label="Seed"
306
+ )
307
+
308
+ generate_images_button_circular = gr.Button("Generate Images")
309
+
310
+ with gr.Column():
311
+ images_output_circular = gr.Gallery(label="Image", selected_index=0)
312
+ gif_circular = gr.Image(label="GIF")
313
+ zip_output_circular = gr.File(label="Download ZIP")
314
+
315
+ num_images_circular.change(
316
+ fn=calculate_step_size,
317
+ inputs=[num_images_circular, degree_circular],
318
+ outputs=[step_size_circular],
319
+ )
320
+ degree_circular.change(
321
+ fn=calculate_step_size,
322
+ inputs=[num_images_circular, degree_circular],
323
+ outputs=[step_size_circular],
324
+ )
325
+ generate_images_button_circular.click(
326
+ fn=display_circular_images,
327
+ inputs=[
328
+ prompt_circular,
329
+ seed_circular,
330
+ num_inference_steps_circular,
331
+ num_images_circular,
332
+ degree_circular,
333
+ ],
334
+ outputs=[images_output_circular, gif_circular, zip_output_circular],
335
+ )
336
+ seed_circular.change(
337
+ fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
338
+ )
339
+
340
+ with gr.TabItem("Poke"):
341
+ gr.Markdown("Perturb a region in the image and observe the effect.")
342
+ gr.HTML(read_html("DiffusionDemo/html/poke.html"))
343
+
344
+ with gr.Row():
345
+ with gr.Column():
346
+ prompt_poke = gr.Textbox(
347
+ lines=1,
348
+ label="Prompt",
349
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
350
+ )
351
+ num_inference_steps_poke = gr.Slider(
352
+ minimum=2,
353
+ maximum=100,
354
+ step=1,
355
+ value=8,
356
+ label="Number of Inference Steps per Image",
357
+ )
358
+
359
+ with gr.Row():
360
+ seed_poke = gr.Slider(
361
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
362
+ )
363
+ seed_vis_poke = gr.Plot(
364
+ value=generate_seed_vis(14), label="Seed"
365
+ )
366
+
367
+ pokeX = gr.Slider(
368
+ label="pokeX",
369
+ minimum=0,
370
+ maximum=64,
371
+ step=1,
372
+ value=32,
373
+ info="X coordinate of poke center",
374
+ )
375
+ pokeY = gr.Slider(
376
+ label="pokeY",
377
+ minimum=0,
378
+ maximum=64,
379
+ step=1,
380
+ value=32,
381
+ info="Y coordinate of poke center",
382
+ )
383
+ pokeHeight = gr.Slider(
384
+ label="pokeHeight",
385
+ minimum=0,
386
+ maximum=64,
387
+ step=1,
388
+ value=8,
389
+ info="Height of the poke",
390
+ )
391
+ pokeWidth = gr.Slider(
392
+ label="pokeWidth",
393
+ minimum=0,
394
+ maximum=64,
395
+ step=1,
396
+ value=8,
397
+ info="Width of the poke",
398
+ )
399
+
400
+ generate_images_button_poke = gr.Button("Generate Images")
401
+
402
+ with gr.Column():
403
+ original_images_output_poke = gr.Image(
404
+ value=visualize_poke(32, 32, 8, 8)[0], label="Original Image"
405
+ )
406
+ poked_images_output_poke = gr.Image(
407
+ value=visualize_poke(32, 32, 8, 8)[1], label="Poked Image"
408
+ )
409
+ zip_output_poke = gr.File(label="Download ZIP")
410
+
411
+ pokeX.change(
412
+ visualize_poke,
413
+ inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
414
+ outputs=[original_images_output_poke, poked_images_output_poke],
415
+ )
416
+ pokeY.change(
417
+ visualize_poke,
418
+ inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
419
+ outputs=[original_images_output_poke, poked_images_output_poke],
420
+ )
421
+ pokeHeight.change(
422
+ visualize_poke,
423
+ inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
424
+ outputs=[original_images_output_poke, poked_images_output_poke],
425
+ )
426
+ pokeWidth.change(
427
+ visualize_poke,
428
+ inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
429
+ outputs=[original_images_output_poke, poked_images_output_poke],
430
+ )
431
+ seed_poke.change(
432
+ fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
433
+ )
434
+
435
+ @generate_images_button_poke.click(
436
+ inputs=[
437
+ prompt_poke,
438
+ seed_poke,
439
+ num_inference_steps_poke,
440
+ pokeX,
441
+ pokeY,
442
+ pokeHeight,
443
+ pokeWidth,
444
+ ],
445
+ outputs=[
446
+ original_images_output_poke,
447
+ poked_images_output_poke,
448
+ zip_output_poke,
449
+ ],
450
+ )
451
+ def generate_images_wrapper(
452
+ prompt,
453
+ seed,
454
+ num_inference_steps,
455
+ pokeX=pokeX,
456
+ pokeY=pokeY,
457
+ pokeHeight=pokeHeight,
458
+ pokeWidth=pokeWidth,
459
+ ):
460
+ _, _ = display_poke_images(
461
+ prompt,
462
+ seed,
463
+ num_inference_steps,
464
+ poke=True,
465
+ pokeX=pokeX,
466
+ pokeY=pokeY,
467
+ pokeHeight=pokeHeight,
468
+ pokeWidth=pokeWidth,
469
+ intermediate=False,
470
+ )
471
+ images, modImages = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)
472
+ fname = "poke"
473
+ tab_config = {
474
+ "Tab": "Poke",
475
+ "Prompt": prompt,
476
+ "Number of Inference Steps per Image": num_inference_steps,
477
+ "Seed": seed,
478
+ "PokeX": pokeX,
479
+ "PokeY": pokeY,
480
+ "PokeHeight": pokeHeight,
481
+ "PokeWidth": pokeWidth,
482
+ }
483
+ imgs_list = []
484
+ imgs_list.append((images, "Original Image"))
485
+ imgs_list.append((modImages, "Poked Image"))
486
+
487
+ export_as_zip(imgs_list, fname, tab_config)
488
+ return images, modImages, f"outputs/{fname}.zip"
489
+
490
+ with gr.TabItem("Guidance"):
491
+ gr.Markdown("Observe the effect of different guidance scales.")
492
+ gr.HTML(read_html("DiffusionDemo/html/guidance.html"))
493
+
494
+ with gr.Row():
495
+ with gr.Column():
496
+ prompt_guidance = gr.Textbox(
497
+ lines=1,
498
+ label="Prompt",
499
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
500
+ )
501
+ num_inference_steps_guidance = gr.Slider(
502
+ minimum=2,
503
+ maximum=100,
504
+ step=1,
505
+ value=8,
506
+ label="Number of Inference Steps per Image",
507
+ )
508
+ guidance_scale_values = gr.Textbox(
509
+ lines=1, value="1, 8, 20, 30", label="Guidance Scale Values"
510
+ )
511
+
512
+ with gr.Row():
513
+ seed_guidance = gr.Slider(
514
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
515
+ )
516
+ seed_vis_guidance = gr.Plot(
517
+ value=generate_seed_vis(14), label="Seed"
518
+ )
519
+
520
+ generate_images_button_guidance = gr.Button("Generate Images")
521
+
522
+ with gr.Column():
523
+ images_output_guidance = gr.Gallery(
524
+ label="Images", selected_index=0
525
+ )
526
+ zip_output_guidance = gr.File(label="Download ZIP")
527
+
528
+ generate_images_button_guidance.click(
529
+ fn=display_guidance_images,
530
+ inputs=[
531
+ prompt_guidance,
532
+ seed_guidance,
533
+ num_inference_steps_guidance,
534
+ guidance_scale_values,
535
+ ],
536
+ outputs=[images_output_guidance, zip_output_guidance],
537
+ )
538
+ seed_guidance.change(
539
+ fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance]
540
+ )
541
+
542
+ with gr.TabItem("Inpainting"):
543
+ gr.Markdown("Inpaint the image based on the prompt.")
544
+ gr.HTML(read_html("DiffusionDemo/html/inpainting.html"))
545
+
546
+ with gr.Row():
547
+ with gr.Column():
548
+ uploaded_img_inpaint = gr.Image(
549
+ source="upload", tool="sketch", type="pil", label="Upload"
550
+ )
551
+ prompt_inpaint = gr.Textbox(
552
+ lines=1, label="Prompt", value="sunglasses"
553
+ )
554
+ num_inference_steps_inpaint = gr.Slider(
555
+ minimum=2,
556
+ maximum=100,
557
+ step=1,
558
+ value=8,
559
+ label="Number of Inference Steps per Image",
560
+ )
561
+
562
+ with gr.Row():
563
+ seed_inpaint = gr.Slider(
564
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
565
+ )
566
+ seed_vis_inpaint = gr.Plot(
567
+ value=generate_seed_vis(14), label="Seed"
568
+ )
569
+
570
+ inpaint_button = gr.Button("Inpaint")
571
+
572
+ with gr.Column():
573
+ images_output_inpaint = gr.Image(label="Output")
574
+ zip_output_inpaint = gr.File(label="Download ZIP")
575
+
576
+ inpaint_button.click(
577
+ fn=inpaint,
578
+ inputs=[
579
+ uploaded_img_inpaint,
580
+ num_inference_steps_inpaint,
581
+ seed_inpaint,
582
+ prompt_inpaint,
583
+ ],
584
+ outputs=[images_output_inpaint, zip_output_inpaint],
585
+ )
586
+ seed_inpaint.change(
587
+ fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint]
588
+ )
589
+
590
+ with gr.Tab("CLIP Space"):
591
+
592
+ with gr.TabItem("Embeddings"):
593
+ gr.Markdown(
594
+ "Visualize text embedding space in 3D with input texts and output images based on the chosen axis."
595
+ )
596
+ gr.HTML(read_html("DiffusionDemo/html/embeddings.html"))
597
+
598
+ with gr.Row():
599
+ output = gr.HTML(
600
+ f"""
601
+ <iframe id="html" src="{dash_tunnel}" style="width:100%; height:700px;"></iframe>
602
+ """
603
+ )
604
+ with gr.Row():
605
+ word2add_rem = gr.Textbox(lines=1, label="Add/Remove word")
606
+ word2change = gr.Textbox(lines=1, label="Change image for word")
607
+ clear_words_button = gr.Button(value="Clear words")
608
+
609
+ with gr.Accordion("Custom Semantic Dimensions", open=False):
610
+ with gr.Row():
611
+ axis_name_1 = gr.Textbox(label="Axis name", value="gender")
612
+ which_axis_1 = gr.Dropdown(
613
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
614
+ value=whichAxisMap["which_axis_1"],
615
+ label="Axis direction",
616
+ )
617
+ from_words_1 = gr.Textbox(
618
+ lines=1,
619
+ label="Positive",
620
+ value="prince husband father son uncle",
621
+ )
622
+ to_words_1 = gr.Textbox(
623
+ lines=1,
624
+ label="Negative",
625
+ value="princess wife mother daughter aunt",
626
+ )
627
+ submit_1 = gr.Button("Submit")
628
+
629
+ with gr.Row():
630
+ axis_name_2 = gr.Textbox(label="Axis name", value="age")
631
+ which_axis_2 = gr.Dropdown(
632
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
633
+ value=whichAxisMap["which_axis_2"],
634
+ label="Axis direction",
635
+ )
636
+ from_words_2 = gr.Textbox(
637
+ lines=1, label="Positive", value="man woman king queen father"
638
+ )
639
+ to_words_2 = gr.Textbox(
640
+ lines=1, label="Negative", value="boy girl prince princess son"
641
+ )
642
+ submit_2 = gr.Button("Submit")
643
+
644
+ with gr.Row():
645
+ axis_name_3 = gr.Textbox(label="Axis name", value="residual")
646
+ which_axis_3 = gr.Dropdown(
647
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
648
+ value=whichAxisMap["which_axis_3"],
649
+ label="Axis direction",
650
+ )
651
+ from_words_3 = gr.Textbox(lines=1, label="Positive")
652
+ to_words_3 = gr.Textbox(lines=1, label="Negative")
653
+ submit_3 = gr.Button("Submit")
654
+
655
+ with gr.Row():
656
+ axis_name_4 = gr.Textbox(label="Axis name", value="number")
657
+ which_axis_4 = gr.Dropdown(
658
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
659
+ value=whichAxisMap["which_axis_4"],
660
+ label="Axis direction",
661
+ )
662
+ from_words_4 = gr.Textbox(
663
+ lines=1,
664
+ label="Positive",
665
+ value="boys girls cats puppies computers",
666
+ )
667
+ to_words_4 = gr.Textbox(
668
+ lines=1, label="Negative", value="boy girl cat puppy computer"
669
+ )
670
+ submit_4 = gr.Button("Submit")
671
+
672
+ with gr.Row():
673
+ axis_name_5 = gr.Textbox(label="Axis name", value="royalty")
674
+ which_axis_5 = gr.Dropdown(
675
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
676
+ value=whichAxisMap["which_axis_5"],
677
+ label="Axis direction",
678
+ )
679
+ from_words_5 = gr.Textbox(
680
+ lines=1,
681
+ label="Positive",
682
+ value="king queen prince princess duchess",
683
+ )
684
+ to_words_5 = gr.Textbox(
685
+ lines=1, label="Negative", value="man woman boy girl woman"
686
+ )
687
+ submit_5 = gr.Button("Submit")
688
+
689
+ with gr.Row():
690
+ axis_name_6 = gr.Textbox(label="Axis name")
691
+ which_axis_6 = gr.Dropdown(
692
+ choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
693
+ value=whichAxisMap["which_axis_6"],
694
+ label="Axis direction",
695
+ )
696
+ from_words_6 = gr.Textbox(lines=1, label="Positive")
697
+ to_words_6 = gr.Textbox(lines=1, label="Negative")
698
+ submit_6 = gr.Button("Submit")
699
+
700
+ @word2add_rem.submit(inputs=[word2add_rem], outputs=[output, word2add_rem])
701
+ def add_rem_word_and_clear(words):
702
+ return add_rem_word(words), ""
703
+
704
+ @word2change.submit(inputs=[word2change], outputs=[output, word2change])
705
+ def change_word_and_clear(word):
706
+ return change_word(word), ""
707
+
708
+ clear_words_button.click(fn=clear_words, outputs=[output])
709
+
710
+ @submit_1.click(
711
+ inputs=[axis_name_1, which_axis_1, from_words_1, to_words_1],
712
+ outputs=[
713
+ output,
714
+ which_axis_2,
715
+ which_axis_3,
716
+ which_axis_4,
717
+ which_axis_5,
718
+ which_axis_6,
719
+ ],
720
+ )
721
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
722
+
723
+ for ax in whichAxisMap:
724
+ if whichAxisMap[ax] == which_axis:
725
+ whichAxisMap[ax] = "---"
726
+
727
+ whichAxisMap["which_axis_1"] = which_axis
728
+ return (
729
+ set_axis(axis_name, which_axis, from_words, to_words),
730
+ whichAxisMap["which_axis_2"],
731
+ whichAxisMap["which_axis_3"],
732
+ whichAxisMap["which_axis_4"],
733
+ whichAxisMap["which_axis_5"],
734
+ whichAxisMap["which_axis_6"],
735
+ )
736
+
737
+ @submit_2.click(
738
+ inputs=[axis_name_2, which_axis_2, from_words_2, to_words_2],
739
+ outputs=[
740
+ output,
741
+ which_axis_1,
742
+ which_axis_3,
743
+ which_axis_4,
744
+ which_axis_5,
745
+ which_axis_6,
746
+ ],
747
+ )
748
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
749
+
750
+ for ax in whichAxisMap:
751
+ if whichAxisMap[ax] == which_axis:
752
+ whichAxisMap[ax] = "---"
753
+
754
+ whichAxisMap["which_axis_2"] = which_axis
755
+ return (
756
+ set_axis(axis_name, which_axis, from_words, to_words),
757
+ whichAxisMap["which_axis_1"],
758
+ whichAxisMap["which_axis_3"],
759
+ whichAxisMap["which_axis_4"],
760
+ whichAxisMap["which_axis_5"],
761
+ whichAxisMap["which_axis_6"],
762
+ )
763
+
764
+ @submit_3.click(
765
+ inputs=[axis_name_3, which_axis_3, from_words_3, to_words_3],
766
+ outputs=[
767
+ output,
768
+ which_axis_1,
769
+ which_axis_2,
770
+ which_axis_4,
771
+ which_axis_5,
772
+ which_axis_6,
773
+ ],
774
+ )
775
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
776
+
777
+ for ax in whichAxisMap:
778
+ if whichAxisMap[ax] == which_axis:
779
+ whichAxisMap[ax] = "---"
780
+
781
+ whichAxisMap["which_axis_3"] = which_axis
782
+ return (
783
+ set_axis(axis_name, which_axis, from_words, to_words),
784
+ whichAxisMap["which_axis_1"],
785
+ whichAxisMap["which_axis_2"],
786
+ whichAxisMap["which_axis_4"],
787
+ whichAxisMap["which_axis_5"],
788
+ whichAxisMap["which_axis_6"],
789
+ )
790
+
791
+ @submit_4.click(
792
+ inputs=[axis_name_4, which_axis_4, from_words_4, to_words_4],
793
+ outputs=[
794
+ output,
795
+ which_axis_1,
796
+ which_axis_2,
797
+ which_axis_3,
798
+ which_axis_5,
799
+ which_axis_6,
800
+ ],
801
+ )
802
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
803
+
804
+ for ax in whichAxisMap:
805
+ if whichAxisMap[ax] == which_axis:
806
+ whichAxisMap[ax] = "---"
807
+
808
+ whichAxisMap["which_axis_4"] = which_axis
809
+ return (
810
+ set_axis(axis_name, which_axis, from_words, to_words),
811
+ whichAxisMap["which_axis_1"],
812
+ whichAxisMap["which_axis_2"],
813
+ whichAxisMap["which_axis_3"],
814
+ whichAxisMap["which_axis_5"],
815
+ whichAxisMap["which_axis_6"],
816
+ )
817
+
818
+ @submit_5.click(
819
+ inputs=[axis_name_5, which_axis_5, from_words_5, to_words_5],
820
+ outputs=[
821
+ output,
822
+ which_axis_1,
823
+ which_axis_2,
824
+ which_axis_3,
825
+ which_axis_4,
826
+ which_axis_6,
827
+ ],
828
+ )
829
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
830
+
831
+ for ax in whichAxisMap:
832
+ if whichAxisMap[ax] == which_axis:
833
+ whichAxisMap[ax] = "---"
834
+
835
+ whichAxisMap["which_axis_5"] = which_axis
836
+ return (
837
+ set_axis(axis_name, which_axis, from_words, to_words),
838
+ whichAxisMap["which_axis_1"],
839
+ whichAxisMap["which_axis_2"],
840
+ whichAxisMap["which_axis_3"],
841
+ whichAxisMap["which_axis_4"],
842
+ whichAxisMap["which_axis_6"],
843
+ )
844
+
845
+ @submit_6.click(
846
+ inputs=[axis_name_6, which_axis_6, from_words_6, to_words_6],
847
+ outputs=[
848
+ output,
849
+ which_axis_1,
850
+ which_axis_2,
851
+ which_axis_3,
852
+ which_axis_4,
853
+ which_axis_5,
854
+ ],
855
+ )
856
+ def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
857
+
858
+ for ax in whichAxisMap:
859
+ if whichAxisMap[ax] == which_axis:
860
+ whichAxisMap[ax] = "---"
861
+
862
+ whichAxisMap["which_axis_6"] = which_axis
863
+ return (
864
+ set_axis(axis_name, which_axis, from_words, to_words),
865
+ whichAxisMap["which_axis_1"],
866
+ whichAxisMap["which_axis_2"],
867
+ whichAxisMap["which_axis_3"],
868
+ whichAxisMap["which_axis_4"],
869
+ whichAxisMap["which_axis_5"],
870
+ )
871
+
872
+ with gr.TabItem("Interpolate"):
873
+ gr.Markdown(
874
+ "Interpolate between the first and the second prompt, and observe how the output changes."
875
+ )
876
+ gr.HTML(read_html("DiffusionDemo/html/interpolate.html"))
877
+
878
+ with gr.Row():
879
+ with gr.Column():
880
+ promptA = gr.Textbox(
881
+ lines=1,
882
+ label="First Prompt",
883
+ value="Self-portrait oil painting, a beautiful man with golden hair, 8k",
884
+ )
885
+ promptB = gr.Textbox(
886
+ lines=1,
887
+ label="Second Prompt",
888
+ value="Self-portrait oil painting, a beautiful woman with golden hair, 8k",
889
+ )
890
+ num_images_interpolate = gr.Slider(
891
+ minimum=0,
892
+ maximum=100,
893
+ step=1,
894
+ value=5,
895
+ label="Number of Interpolation Steps",
896
+ )
897
+ num_inference_steps_interpolate = gr.Slider(
898
+ minimum=2,
899
+ maximum=100,
900
+ step=1,
901
+ value=8,
902
+ label="Number of Inference Steps per Image",
903
+ )
904
+
905
+ with gr.Row():
906
+ seed_interpolate = gr.Slider(
907
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
908
+ )
909
+ seed_vis_interpolate = gr.Plot(
910
+ value=generate_seed_vis(14), label="Seed"
911
+ )
912
+
913
+ generate_images_button_interpolate = gr.Button("Generate Images")
914
+
915
+ with gr.Column():
916
+ images_output_interpolate = gr.Gallery(
917
+ label="Interpolated Images", selected_index=0
918
+ )
919
+ gif_interpolate = gr.Image(label="GIF")
920
+ zip_output_interpolate = gr.File(label="Download ZIP")
921
+
922
+ generate_images_button_interpolate.click(
923
+ fn=display_interpolate_images,
924
+ inputs=[
925
+ seed_interpolate,
926
+ promptA,
927
+ promptB,
928
+ num_inference_steps_interpolate,
929
+ num_images_interpolate,
930
+ ],
931
+ outputs=[
932
+ images_output_interpolate,
933
+ gif_interpolate,
934
+ zip_output_interpolate,
935
+ ],
936
+ )
937
+ seed_interpolate.change(
938
+ fn=generate_seed_vis,
939
+ inputs=[seed_interpolate],
940
+ outputs=[seed_vis_interpolate],
941
+ )
942
+
943
+ with gr.TabItem("Negative"):
944
+ gr.Markdown("Observe the effect of negative prompts.")
945
+ gr.HTML(read_html("DiffusionDemo/html/negative.html"))
946
+
947
+ with gr.Row():
948
+ with gr.Column():
949
+ prompt_negative = gr.Textbox(
950
+ lines=1,
951
+ label="Prompt",
952
+ value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
953
+ )
954
+ neg_prompt = gr.Textbox(
955
+ lines=1, label="Negative Prompt", value="Yellow"
956
+ )
957
+ num_inference_steps_negative = gr.Slider(
958
+ minimum=2,
959
+ maximum=100,
960
+ step=1,
961
+ value=8,
962
+ label="Number of Inference Steps per Image",
963
+ )
964
+
965
+ with gr.Row():
966
+ seed_negative = gr.Slider(
967
+ minimum=0, maximum=100, step=1, value=14, label="Seed"
968
+ )
969
+ seed_vis_negative = gr.Plot(
970
+ value=generate_seed_vis(14), label="Seed"
971
+ )
972
+
973
+ generate_images_button_negative = gr.Button("Generate Images")
974
+
975
+ with gr.Column():
976
+ images_output_negative = gr.Image(
977
+ label="Image without Negative Prompt"
978
+ )
979
+ images_neg_output_negative = gr.Image(
980
+ label="Image with Negative Prompt"
981
+ )
982
+ zip_output_negative = gr.File(label="Download ZIP")
983
+
984
+ seed_negative.change(
985
+ fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative]
986
+ )
987
+ generate_images_button_negative.click(
988
+ fn=display_negative_images,
989
+ inputs=[
990
+ prompt_negative,
991
+ seed_negative,
992
+ num_inference_steps_negative,
993
+ neg_prompt,
994
+ ],
995
+ outputs=[
996
+ images_output_negative,
997
+ images_neg_output_negative,
998
+ zip_output_negative,
999
+ ],
1000
+ )
1001
+
1002
+ with gr.Tab("Credits"):
1003
+ gr.Markdown("""
1004
+ Author: Adithya Kameswara Rao, Carnegie Mellon University.
1005
+
1006
+ Advisor: David S. Touretzky, Carnegie Mellon University.
1007
+
1008
+ This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633.
1009
+ """)
1010
+
1011
+
1012
+ def run_dash():
1013
+ app.run(host="127.0.0.1", port="8000")
1014
+
1015
+
1016
+ def run_gradio():
1017
+ demo.queue()
1018
+ _, _, public_url = demo.launch(share=True)
1019
+ return public_url
1020
+
1021
+
1022
+ # if __name__ == "__main__":
1023
+ # thread = Thread(target=run_dash)
1024
+ # thread.daemon = True
1025
+ # thread.start()
1026
+ # try:
1027
+ # run_gradio()
1028
+ # except KeyboardInterrupt:
1029
+ # print("Server closed")
src/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import util
2
+ from . import pipelines
src/pipelines/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .circular import *
2
+ from .embeddings import *
3
+ from .interpolate import *
4
+ from .poke import *
5
+ from .seed import *
6
+ from .perturbations import *
7
+ from .negative import *
8
+ from .guidance import *
9
+ from .inpainting import *
src/pipelines/circular.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ from src.util.base import *
5
+ from src.util.params import *
6
+
7
+
8
+ def display_circular_images(
9
+ prompt, seed, num_inference_steps, num_images, degree, progress=gr.Progress()
10
+ ):
11
+ np.random.seed(seed)
12
+ text_embeddings = get_text_embeddings(prompt)
13
+
14
+ latents_x = generate_latents(seed)
15
+ latents_y = generate_latents(seed * np.random.randint(0, 100000))
16
+
17
+ scale_x = torch.cos(
18
+ torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
19
+ ).to(torch_device)
20
+ scale_y = torch.sin(
21
+ torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
22
+ ).to(torch_device)
23
+
24
+ noise_x = torch.tensordot(scale_x, latents_x, dims=0)
25
+ noise_y = torch.tensordot(scale_y, latents_y, dims=0)
26
+
27
+ noise = noise_x + noise_y
28
+
29
+ progress(0)
30
+ images = []
31
+ for i in range(num_images):
32
+ progress(i / num_images)
33
+ image = generate_images(noise[i], text_embeddings, num_inference_steps)
34
+ images.append((image, "{}".format(i)))
35
+
36
+ progress(1, desc="Exporting as gif")
37
+ export_as_gif(images, filename="circular.gif")
38
+
39
+ fname = "circular"
40
+ tab_config = {
41
+ "Tab": "Circular",
42
+ "Prompt": prompt,
43
+ "Number of Steps around the Circle": num_images,
44
+ "Proportion of Circle": degree,
45
+ "Number of Inference Steps per Image": num_inference_steps,
46
+ "Seed": seed,
47
+ }
48
+ export_as_zip(images, fname, tab_config)
49
+ return images, "outputs/circular.gif", f"outputs/{fname}.zip"
50
+
51
+
52
+ __all__ = ["display_circular_images"]
src/pipelines/embeddings.py ADDED
@@ -0,0 +1,196 @@
1
+ import random
2
+ import numpy as np
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ from diffusers import StableDiffusionPipeline
6
+
7
+ import base64
8
+ from io import BytesIO
9
+ import plotly.express as px
10
+
11
+ from src.util.base import *
12
+ from src.util.params import *
13
+ from src.util.clip_config import *
14
+
15
+ age = get_axis_embeddings(young, old)
16
+ gender = get_axis_embeddings(masculine, feminine)
17
+ royalty = get_axis_embeddings(common, elite)
18
+
19
+ images = []
20
+ for example in examples:
21
+ image = pipe(
22
+ prompt=example,
23
+ num_inference_steps=num_inference_steps,
24
+ guidance_scale=guidance_scale,
25
+ ).images[0]
26
+ buffer = BytesIO()
27
+ image.save(buffer, format="JPEG")
28
+ encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
29
+ images.append("data:image/jpeg;base64, " + encoded_image)
30
+
31
+ axis = np.vstack([gender, royalty, age])
32
+ axis[1] = calculate_residual(axis, axis_names)
33
+
34
+ coords = get_concat_embeddings(examples) @ axis.T
35
+ coords[:, 1] = 5 * (1.0 - coords[:, 1])
36
+
37
+
38
+ def update_fig():
39
+ global coords, examples, fig
40
+ fig.data[0].x = coords[:, 0]
41
+ fig.data[0].y = coords[:, 1]
42
+ fig.data[0].z = coords[:, 2]
43
+ fig.data[0].text = examples
44
+
45
+ return f"""
46
+ <script>
47
+ document.getElementById("html").src += "?rand={random.random()}"
48
+ </script>
49
+ <iframe id="html" src={dash_tunnel} style="width:100%; height:725px;"></iframe>
50
+ """
51
+
52
+
53
+ def add_word(new_example):
54
+ global coords, images, examples
55
+ new_coord = get_concat_embeddings([new_example]) @ axis.T
56
+ new_coord[:, 1] = 5 * (1.0 - new_coord[:, 1])
57
+ coords = np.vstack([coords, new_coord])
58
+
59
+ image = pipe(
60
+ prompt=new_example,
61
+ num_inference_steps=num_inference_steps,
62
+ guidance_scale=guidance_scale,
63
+ ).images[0]
64
+ buffer = BytesIO()
65
+ image.save(buffer, format="JPEG")
66
+ encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
67
+ images.append("data:image/jpeg;base64, " + encoded_image)
68
+ examples.append(new_example)
69
+ return update_fig()
70
+
71
+
72
+ def remove_word(new_example):
73
+ global coords, images, examples
74
+ examplesMap = {example: index for index, example in enumerate(examples)}
75
+ index = examplesMap[new_example]
76
+
77
+ coords = np.delete(coords, index, 0)
78
+ images.pop(index)
79
+ examples.pop(index)
80
+ return update_fig()
81
+
82
+
83
+ def add_rem_word(new_examples):
84
+ global examples
85
+ new_examples = new_examples.replace(",", " ").split()
86
+
87
+ for new_example in new_examples:
88
+ if new_example in examples:
89
+ remove_word(new_example)
90
+ gr.Info("Removed {}".format(new_example))
91
+ else:
92
+ tokens = tokenizer.encode(new_example)
93
+ if len(tokens) != 3:
94
+ gr.Warning(f"{new_example} not found in embeddings")
95
+ else:
96
+ add_word(new_example)
97
+ gr.Info("Added {}".format(new_example))
98
+
99
+ return update_fig()
100
+
101
+
102
+ def set_axis(axis_name, which_axis, from_words, to_words):
103
+ global coords, examples, fig, axis_names
104
+
105
+ if axis_name != "residual":
106
+ from_words, to_words = (
107
+ from_words.replace(",", " ").split(),
108
+ to_words.replace(",", " ").split(),
109
+ )
110
+ axis_emb = get_axis_embeddings(from_words, to_words)
111
+ axis[axisMap[which_axis]] = axis_emb
112
+ axis_names[axisMap[which_axis]] = axis_name
113
+
114
+ for i, name in enumerate(axis_names):
115
+ if name == "residual":
116
+ axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
117
+ axis_names[i] = "residual"
118
+ else:
119
+ residual = calculate_residual(
120
+ axis, axis_names, residual_axis=axisMap[which_axis]
121
+ )
122
+ axis[axisMap[which_axis]] = residual
123
+ axis_names[axisMap[which_axis]] = axis_name
124
+
125
+ coords = get_concat_embeddings(examples) @ axis.T
126
+ coords[:, 1] = 5 * (1.0 - coords[:, 1])
127
+
128
+ fig.update_layout(
129
+ scene=dict(
130
+ xaxis_title=axis_names[0],
131
+ yaxis_title=axis_names[1],
132
+ zaxis_title=axis_names[2],
133
+ )
134
+ )
135
+ return update_fig()
136
+
137
+
138
+ def change_word(examples):
139
+ examples = examples.replace(",", " ").split()
140
+
141
+ for example in examples:
142
+ remove_word(example)
143
+ add_word(example)
144
+ gr.Info("Changed image for {}".format(example))
145
+
146
+ return update_fig()
147
+
148
+
149
+ def clear_words():
150
+ while examples:
151
+ remove_word(examples[-1])
152
+ return update_fig()
153
+
154
+
155
+ def generate_word_emb_vis(prompt):
156
+ buf = BytesIO()
157
+ emb = get_word_embeddings(prompt).reshape(77, 768)[1]
158
+ plt.imsave(buf, [emb], cmap="inferno")
159
+ img = "data:image/jpeg;base64, " + base64.b64encode(buf.getvalue()).decode("utf-8")
160
+ return img
161
+
162
+
163
+ fig = px.scatter_3d(
164
+ x=coords[:, 0],
165
+ y=coords[:, 1],
166
+ z=coords[:, 2],
167
+ labels={
168
+ "x": axis_names[0],
169
+ "y": axis_names[1],
170
+ "z": axis_names[2],
171
+ },
172
+ text=examples,
173
+ height=750,
174
+ )
175
+
176
+ fig.update_layout(
177
+ margin=dict(l=0, r=0, b=0, t=0), scene_camera=dict(eye=dict(x=2, y=2, z=0.1))
178
+ )
179
+
180
+ fig.update_traces(hoverinfo="none", hovertemplate=None)
181
+
182
+ __all__ = [
183
+ "fig",
184
+ "update_fig",
185
+ "coords",
186
+ "images",
187
+ "examples",
188
+ "add_word",
189
+ "remove_word",
190
+ "add_rem_word",
191
+ "change_word",
192
+ "clear_words",
193
+ "generate_word_emb_vis",
194
+ "set_axis",
195
+ "axis",
196
+ ]
src/pipelines/guidance.py ADDED
@@ -0,0 +1,39 @@
1
+ import gradio as gr
2
+ from src.util.base import *
3
+ from src.util.params import *
4
+
5
+
6
+ def display_guidance_images(
7
+ prompt, seed, num_inference_steps, guidance_values, progress=gr.Progress()
8
+ ):
9
+ text_embeddings = get_text_embeddings(prompt)
10
+ latents = generate_latents(seed)
11
+
12
+ progress(0)
13
+ images = []
14
+ guidance_values = guidance_values.replace(",", " ").split()
15
+ num_images = len(guidance_values)
16
+
17
+ for i in range(num_images):
18
+ progress(i / num_images)
19
+ image = generate_images(
20
+ latents,
21
+ text_embeddings,
22
+ num_inference_steps,
23
+ guidance_scale=int(guidance_values[i]),
24
+ )
25
+ images.append((image, "{}".format(int(guidance_values[i]))))
26
+
27
+ fname = "guidance"
28
+ tab_config = {
29
+ "Tab": "Guidance",
30
+ "Prompt": prompt,
31
+ "Guidance Scale Values": guidance_values,
32
+ "Number of Inference Steps per Image": num_inference_steps,
33
+ "Seed": seed,
34
+ }
35
+ export_as_zip(images, fname, tab_config)
36
+ return images, f"outputs/{fname}.zip"
37
+
38
+
39
+ __all__ = ["display_guidance_images"]
src/pipelines/inpainting.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import gradio as gr
3
+ from src.util.base import *
4
+ from src.util.params import *
5
+ from diffusers import AutoPipelineForInpainting
6
+
7
+ # inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model_path).to(torch_device)
8
+ inpaint_pipe = AutoPipelineForInpainting.from_pipe(pipe).to(torch_device)
9
+
10
+
11
+ def inpaint(dict, num_inference_steps, seed, prompt="", progress=gr.Progress()):
12
+ progress(0)
13
+ mask = dict["mask"].convert("RGB").resize((imageHeight, imageWidth))
14
+ init_image = dict["image"].convert("RGB").resize((imageHeight, imageWidth))
15
+ output = inpaint_pipe(
16
+ prompt=prompt,
17
+ image=init_image,
18
+ mask_image=mask,
19
+ guidance_scale=guidance_scale,
20
+ num_inference_steps=num_inference_steps,
21
+ generator=torch.Generator().manual_seed(seed),
22
+ )
23
+ progress(1)
24
+
25
+ fname = "inpainting"
26
+ tab_config = {
27
+ "Tab": "Inpainting",
28
+ "Prompt": prompt,
29
+ "Number of Inference Steps per Image": num_inference_steps,
30
+ "Seed": seed,
31
+ }
32
+
33
+ imgs_list = []
34
+ imgs_list.append((output.images[0], "Inpainted Image"))
35
+ imgs_list.append((mask, "Mask"))
36
+
37
+ export_as_zip(imgs_list, fname, tab_config)
38
+ return output.images[0], f"outputs/{fname}.zip"
39
+
40
+
41
+ __all__ = ["inpaint"]
src/pipelines/interpolate.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import gradio as gr
+ from src.util.base import *
+ from src.util.params import *
+ 
+ 
+ def interpolate_prompts(promptA, promptB, num_interpolation_steps):
+     text_embeddingsA = get_text_embeddings(promptA)
+     text_embeddingsB = get_text_embeddings(promptB)
+ 
+     interpolated_embeddings = []
+ 
+     for i in range(num_interpolation_steps):
+         alpha = i / num_interpolation_steps
+         interpolated_embedding = torch.lerp(text_embeddingsA, text_embeddingsB, alpha)
+         interpolated_embeddings.append(interpolated_embedding)
+ 
+     return interpolated_embeddings
+ 
+ 
+ def display_interpolate_images(
+     seed, promptA, promptB, num_inference_steps, num_images, progress=gr.Progress()
+ ):
+     latents = generate_latents(seed)
+     num_images = num_images + 2  # add 2 for first and last image
+     text_embeddings = interpolate_prompts(promptA, promptB, num_images)
+     images = []
+     progress(0)
+ 
+     for i in range(num_images):
+         progress(i / num_images)
+         image = generate_images(latents, text_embeddings[i], num_inference_steps)
+         images.append((image, "{}".format(i + 1)))
+ 
+     progress(1, desc="Exporting as gif")
+     export_as_gif(images, filename="interpolate.gif", reverse=True)
+ 
+     fname = "interpolate"
+     tab_config = {
+         "Tab": "Interpolate",
+         "First Prompt": promptA,
+         "Second Prompt": promptB,
+         "Number of Interpolation Steps": num_images,
+         "Number of Inference Steps per Image": num_inference_steps,
+         "Seed": seed,
+     }
+     export_as_zip(images, fname, tab_config)
+     return images, "outputs/interpolate.gif", f"outputs/{fname}.zip"
+ 
+ 
+ __all__ = ["display_interpolate_images"]
src/pipelines/negative.py ADDED
@@ -0,0 +1,37 @@
+ import gradio as gr
+ from src.util.base import *
+ from src.util.params import *
+ 
+ 
+ def display_negative_images(
+     prompt, seed, num_inference_steps, negative_prompt="", progress=gr.Progress()
+ ):
+     text_embeddings = get_text_embeddings(prompt)
+     text_embeddings_neg = get_text_embeddings(prompt, negative_prompt=negative_prompt)
+ 
+     latents = generate_latents(seed)
+ 
+     progress(0)
+     images = generate_images(latents, text_embeddings, num_inference_steps)
+ 
+     progress(0.5)
+     images_neg = generate_images(latents, text_embeddings_neg, num_inference_steps)
+ 
+     fname = "negative"
+     tab_config = {
+         "Tab": "Negative",
+         "Prompt": prompt,
+         "Negative Prompt": negative_prompt,
+         "Number of Inference Steps per Image": num_inference_steps,
+         "Seed": seed,
+     }
+ 
+     imgs_list = []
+     imgs_list.append((images, "Without Negative Prompt"))
+     imgs_list.append((images_neg, "With Negative Prompt"))
+     export_as_zip(imgs_list, fname, tab_config)
+ 
+     return images, images_neg, f"outputs/{fname}.zip"
+ 
+ 
+ __all__ = ["display_negative_images"]
src/pipelines/perturbations.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ from src.util.base import *
+ from src.util.params import *
+ 
+ 
+ def display_perturb_images(
+     prompt,
+     seed,
+     num_inference_steps,
+     num_images,
+     perturbation_size,
+     progress=gr.Progress(),
+ ):
+     text_embeddings = get_text_embeddings(prompt)
+ 
+     latents_x = generate_latents(seed)
+     scale_x = torch.cos(
+         torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
+     ).to(torch_device)
+     noise_x = torch.tensordot(scale_x, latents_x, dims=0)
+ 
+     progress(0)
+     images = []
+     images.append(
+         (
+             generate_images(latents_x, text_embeddings, num_inference_steps),
+             "{}".format(1),
+         )
+     )
+ 
+     for i in range(num_images):
+         np.random.seed(i)
+         progress(i / (num_images))
+         latents_y = generate_latents(np.random.randint(0, 100000))
+         scale_y = torch.sin(
+             torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
+         ).to(torch_device)
+         noise_y = torch.tensordot(scale_y, latents_y, dims=0)
+ 
+         noise = noise_x + noise_y
+         image = generate_images(
+             noise[num_images - 1], text_embeddings, num_inference_steps
+         )
+         images.append((image, "{}".format(i + 2)))
+ 
+     fname = "perturbations"
+     tab_config = {
+         "Tab": "Perturbations",
+         "Prompt": prompt,
+         "Number of Perturbations": num_images,
+         "Perturbation Size": perturbation_size,
+         "Number of Inference Steps per Image": num_inference_steps,
+         "Seed": seed,
+     }
+     export_as_zip(images, fname, tab_config)
+ 
+     return images, f"outputs/{fname}.zip"
+ 
+ 
+ __all__ = ["display_perturb_images"]
src/pipelines/poke.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ import gradio as gr
+ from src.util.base import *
+ from src.util.params import *
+ from PIL import Image, ImageDraw
+ 
+ 
+ def visualize_poke(
+     pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth
+ ):
+     if (
+         (pokeX - pokeWidth // 2 < 0)
+         or (pokeX + pokeWidth // 2 > imageWidth // 8)
+         or (pokeY - pokeHeight // 2 < 0)
+         or (pokeY + pokeHeight // 2 > imageHeight // 8)
+     ):
+         gr.Warning("Modification outside image")
+     shape = [
+         (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
+         (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
+     ]
+ 
+     blank = Image.new("RGB", (imageWidth, imageHeight))
+ 
+     if os.path.exists("outputs/original.png"):
+         oImg = Image.open("outputs/original.png")
+         pImg = Image.open("outputs/poked.png")
+     else:
+         oImg = blank
+         pImg = blank
+ 
+     oRec = ImageDraw.Draw(oImg)
+     pRec = ImageDraw.Draw(pImg)
+ 
+     oRec.rectangle(shape, outline="white")
+     pRec.rectangle(shape, outline="white")
+ 
+     return oImg, pImg
+ 
+ 
+ def display_poke_images(
+     prompt,
+     seed,
+     num_inference_steps,
+     poke=False,
+     pokeX=None,
+     pokeY=None,
+     pokeHeight=None,
+     pokeWidth=None,
+     intermediate=False,
+     progress=gr.Progress(),
+ ):
+     text_embeddings = get_text_embeddings(prompt)
+     latents, modified_latents = generate_modified_latents(
+         poke, seed, pokeX, pokeY, pokeHeight, pokeWidth
+     )
+ 
+     progress(0)
+     images = generate_images(
+         latents, text_embeddings, num_inference_steps, intermediate=intermediate
+     )
+ 
+     if not intermediate:
+         images.save("outputs/original.png")
+ 
+     if poke:
+         progress(0.5)
+         modImages = generate_images(
+             modified_latents,
+             text_embeddings,
+             num_inference_steps,
+             intermediate=intermediate,
+         )
+ 
+         if not intermediate:
+             modImages.save("outputs/poked.png")
+     else:
+         modImages = None
+ 
+     return images, modImages
+ 
+ 
+ __all__ = ["display_poke_images", "visualize_poke"]
src/pipelines/seed.py ADDED
@@ -0,0 +1,32 @@
+ import gradio as gr
+ from src.util.base import *
+ from src.util.params import *
+ 
+ 
+ def display_seed_images(
+     prompt, num_inference_steps, num_images, progress=gr.Progress()
+ ):
+     text_embeddings = get_text_embeddings(prompt)
+ 
+     images = []
+     progress(0)
+ 
+     for i in range(num_images):
+         progress(i / num_images)
+         latents = generate_latents(i)
+         image = generate_images(latents, text_embeddings, num_inference_steps)
+         images.append((image, "{}".format(i + 1)))
+ 
+     fname = "seeds"
+     tab_config = {
+         "Tab": "Seeds",
+         "Prompt": prompt,
+         "Number of Seeds": num_images,
+         "Number of Inference Steps per Image": num_inference_steps,
+     }
+     export_as_zip(images, fname, tab_config)
+ 
+     return images, f"outputs/{fname}.zip"
+ 
+ 
+ __all__ = ["display_seed_images"]
src/util/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .base import *
+ from .params import *
+ from .clip_config import *
src/util/base.py ADDED
@@ -0,0 +1,304 @@
+ import io
+ import os
+ import torch
+ import zipfile
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from src.util.params import *
+ from src.util.clip_config import *
+ import matplotlib.pyplot as plt
+ 
+ 
+ def get_text_embeddings(
+     prompt,
+     tokenizer=tokenizer,
+     text_encoder=text_encoder,
+     torch_device=torch_device,
+     batch_size=1,
+     negative_prompt="",
+ ):
+     text_input = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+ 
+     with torch.no_grad():
+         text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+     max_length = text_input.input_ids.shape[-1]
+     uncond_input = tokenizer(
+         [negative_prompt] * batch_size,
+         padding="max_length",
+         max_length=max_length,
+         return_tensors="pt",
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ 
+     return text_embeddings
+ 
+ 
+ def generate_latents(
+     seed,
+     height=imageHeight,
+     width=imageWidth,
+     torch_device=torch_device,
+     unet=unet,
+     batch_size=1,
+ ):
+     generator = torch.Generator().manual_seed(int(seed))
+ 
+     latents = torch.randn(
+         (batch_size, unet.config.in_channels, height // 8, width // 8),
+         generator=generator,
+     ).to(torch_device)
+ 
+     return latents
+ 
+ 
+ def generate_modified_latents(
+     poke,
+     seed,
+     pokeX=None,
+     pokeY=None,
+     pokeHeight=None,
+     pokeWidth=None,
+     imageHeight=imageHeight,
+     imageWidth=imageWidth,
+ ):
+     original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
+     if poke:
+         np.random.seed(seed)
+         poke_latents = generate_latents(
+             np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
+         )
+ 
+         x_origin = pokeX - pokeWidth // 2
+         y_origin = pokeY - pokeHeight // 2
+ 
+         modified_latents = original_latents.clone()
+         modified_latents[
+             :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
+         ] = poke_latents
+     else:
+         modified_latents = None
+ 
+     return original_latents, modified_latents
+ 
+ 
+ def convert_to_pil_image(image):
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images[0]
+ 
+ 
+ def generate_images(
+     latents,
+     text_embeddings,
+     num_inference_steps,
+     unet=unet,
+     guidance_scale=guidance_scale,
+     vae=vae,
+     scheduler=scheduler,
+     intermediate=False,
+     progress=gr.Progress(),
+ ):
+     scheduler.set_timesteps(num_inference_steps)
+     latents = latents * scheduler.init_noise_sigma
+     images = []
+     i = 1
+ 
+     for t in tqdm(scheduler.timesteps):
+         latent_model_input = torch.cat([latents] * 2)
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+ 
+         with torch.no_grad():
+             noise_pred = unet(
+                 latent_model_input, t, encoder_hidden_states=text_embeddings
+             ).sample
+ 
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (
+             noise_pred_text - noise_pred_uncond
+         )
+ 
+         if intermediate:
+             progress(((1000 - t) / 1000))
+             Latents = 1 / 0.18215 * latents
+             with torch.no_grad():
+                 image = vae.decode(Latents).sample
+             images.append((convert_to_pil_image(image), "{}".format(i)))
+ 
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+         i += 1
+ 
+     if not intermediate:
+         Latents = 1 / 0.18215 * latents
+         with torch.no_grad():
+             image = vae.decode(Latents).sample
+         images = convert_to_pil_image(image)
+ 
+     return images
+ 
+ 
+ def get_word_embeddings(
+     prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
+ ):
+     text_input = tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     ).to(torch_device)
+ 
+     with torch.no_grad():
+         text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)
+ 
+     text_embeddings = text_embeddings.cpu().numpy()
+     return text_embeddings / np.linalg.norm(text_embeddings)
+ 
+ 
+ def get_concat_embeddings(names, merge=False):
+     embeddings = []
+ 
+     for name in names:
+         embedding = get_word_embeddings(name)
+         embeddings.append(embedding)
+ 
+     embeddings = np.vstack(embeddings)
+ 
+     if merge:
+         embeddings = np.average(embeddings, axis=0).reshape(1, -1)
+ 
+     return embeddings
+ 
+ 
+ def get_axis_embeddings(A, B):
+     emb = []
+ 
+     for a, b in zip(A, B):
+         e = get_word_embeddings(a) - get_word_embeddings(b)
+         emb.append(e)
+ 
+     emb = np.vstack(emb)
+     ax = np.average(emb, axis=0).reshape(1, -1)
+ 
+     return ax
+ 
+ 
+ def calculate_residual(
+     axis, axis_names, from_words=None, to_words=None, residual_axis=1
+ ):
+     axis_indices = [0, 1, 2]
+     axis_indices.remove(residual_axis)
+ 
+     if axis_names[axis_indices[0]] in axis_combinations:
+         fembeddings = get_concat_embeddings(
+             axis_combinations[axis_names[axis_indices[0]]], merge=True
+         )
+     else:
+         axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
+         fembeddings = get_concat_embeddings(from_words + to_words, merge=True)
+ 
+     if axis_names[axis_indices[1]] in axis_combinations:
+         sembeddings = get_concat_embeddings(
+             axis_combinations[axis_names[axis_indices[1]]], merge=True
+         )
+     else:
+         axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
+         sembeddings = get_concat_embeddings(from_words + to_words, merge=True)
+ 
+     fprojections = fembeddings @ axis[axis_indices[0]].T
+     sprojections = sembeddings @ axis[axis_indices[1]].T
+ 
+     partial_residual = fembeddings - (fprojections.reshape(-1, 1) * fembeddings)
+     residual = partial_residual - (sprojections.reshape(-1, 1) * sembeddings)
+ 
+     return residual
+ 
+ 
+ def calculate_step_size(num_images, differentiation):
+     return differentiation / (num_images - 1)
+ 
+ 
+ def generate_seed_vis(seed):
+     np.random.seed(seed)
+     emb = np.random.rand(15)
+     plt.close()
+     plt.switch_backend("agg")
+     plt.figure(figsize=(10, 0.5))
+     plt.imshow([emb], cmap="viridis")
+     plt.axis("off")
+     return plt
+ 
+ 
+ def export_as_gif(images, filename, frames_per_second=2, reverse=False):
+     imgs = [img[0] for img in images]
+ 
+     if reverse:
+         imgs += imgs[2:-1][::-1]
+ 
+     imgs[0].save(
+         f"outputs/{filename}",
+         format="GIF",
+         save_all=True,
+         append_images=imgs[1:],
+         duration=1000 // frames_per_second,
+         loop=0,
+     )
+ 
+ 
+ def export_as_zip(images, fname, tab_config=None):
+ 
+     if not os.path.exists(f"outputs/{fname}.zip"):
+         os.makedirs("outputs", exist_ok=True)
+ 
+     with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
+ 
+         if tab_config:
+             with open("outputs/config.txt", "w") as f:
+                 for key, value in tab_config.items():
+                     f.write(f"{key}: {value}\n")
+                 f.close()
+ 
+             img_zip.write("outputs/config.txt", "config.txt")
+ 
+         for idx, img in enumerate(images):
+             buff = io.BytesIO()
+             img[0].save(buff, format="PNG")
+             buff = buff.getvalue()
+             max_num = len(images)
+             num_leading_zeros = len(str(max_num))
+             img_name = f"{{:0{num_leading_zeros}}}.png"
+             img_zip.writestr(img_name.format(idx + 1), buff)
+ 
+ 
+ def read_html(file_path):
+     with open(file_path, "r", encoding="utf-8") as f:
+         content = f.read()
+     return content
+ 
+ 
+ __all__ = [
+     "get_text_embeddings",
+     "generate_latents",
+     "generate_modified_latents",
+     "generate_images",
+     "get_word_embeddings",
+     "get_concat_embeddings",
+     "get_axis_embeddings",
+     "calculate_residual",
+     "calculate_step_size",
+     "generate_seed_vis",
+     "export_as_gif",
+     "export_as_zip",
+     "read_html",
+ ]
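A sketch of the decode path at the end of generate_images: latents are rescaled by 1 / 0.18215 (the Stable Diffusion VAE scaling factor), decoded, then mapped from [-1, 1] to 8-bit RGB exactly as convert_to_pil_image does. A random array stands in for the decoded VAE output so the snippet runs without the model.

```python
import numpy as np
from PIL import Image

decoded = np.random.uniform(-1, 1, size=(1, 3, 512, 512))   # stand-in for vae.decode(latents).sample
img = np.clip(decoded / 2 + 0.5, 0, 1)                      # [-1, 1] -> [0, 1]
img = img.transpose(0, 2, 3, 1)                             # NCHW -> NHWC
img = (img * 255).round().astype("uint8")
Image.fromarray(img[0]).save("decoded_preview.png")
```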
src/util/clip_config.py ADDED
@@ -0,0 +1,114 @@
+ masculine = [
+     "man",
+     "king",
+     "prince",
+     "husband",
+     "father",
+ ]
+ 
+ feminine = [
+     "woman",
+     "queen",
+     "princess",
+     "wife",
+     "mother",
+ ]
+ 
+ young = [
+     "man",
+     "woman",
+     "king",
+     "queen",
+     "father",
+ ]
+ 
+ old = [
+     "boy",
+     "girl",
+     "prince",
+     "princess",
+     "son",
+ ]
+ 
+ common = [
+     "man",
+     "woman",
+     "boy",
+     "girl",
+     "woman",
+ ]
+ 
+ elite = [
+     "king",
+     "queen",
+     "prince",
+     "princess",
+     "duchess",
+ ]
+ 
+ singular = [
+     "boy",
+     "girl",
+     "cat",
+     "puppy",
+     "computer",
+ ]
+ 
+ plural = [
+     "boys",
+     "girls",
+     "cats",
+     "puppies",
+     "computers",
+ ]
+ 
+ examples = [
+     "king",
+     "queen",
+     "man",
+     "woman",
+     "boys",
+     "girls",
+     "apple",
+     "orange",
+ ]
+ 
+ axis_names = ["gender", "residual", "age"]
+ 
+ axis_combinations = {
+     "age": young + old,
+     "gender": masculine + feminine,
+     "royalty": common + elite,
+     "number": singular + plural,
+ }
+ 
+ axisMap = {
+     "X - Axis": 0,
+     "Y - Axis": 1,
+     "Z - Axis": 2,
+ }
+ 
+ whichAxisMap = {
+     "which_axis_1": "X - Axis",
+     "which_axis_2": "Z - Axis",
+     "which_axis_3": "Y - Axis",
+     "which_axis_4": "---",
+     "which_axis_5": "---",
+     "which_axis_6": "---",
+ }
+ 
+ __all__ = [
+     "axisMap",
+     "whichAxisMap",
+     "axis_names",
+     "axis_combinations",
+     "examples",
+     "masculine",
+     "feminine",
+     "young",
+     "old",
+     "common",
+     "elite",
+     "singular",
+     "plural",
+ ]
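A sketch, with random stand-in embeddings, of how these word lists become plot axes: get_axis_embeddings averages the pairwise differences (e.g. "man" minus "woman", "king" minus "queen") into a single direction such as "gender".

```python
import numpy as np

rng = np.random.default_rng(0)
masculine_emb = rng.normal(size=(5, 59136))   # 5 words x (77 * 768) flattened embeddings
feminine_emb = rng.normal(size=(5, 59136))

gender_axis = np.average(masculine_emb - feminine_emb, axis=0).reshape(1, -1)
print(gender_axis.shape)   # (1, 59136): one direction to project words onto
```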
src/util/params.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import secrets
+ from gradio.networking import setup_tunnel
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+     LCMScheduler,
+     DDIMScheduler,
+     StableDiffusionPipeline,
+ )
+ 
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ isLCM = False
+ HF_ACCESS_TOKEN = ""
+ 
+ model_path = "segmind/small-sd"
+ inpaint_model_path = "Lykon/dreamshaper-8-inpainting"
+ prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+ promptA = "Self-portrait oil painting, a beautiful man with golden hair, 8k"
+ promptB = "Self-portrait oil painting, a beautiful woman with golden hair, 8k"
+ negative_prompt = "a photo frame"
+ 
+ num_images = 5
+ degree = 360
+ perturbation_size = 0.1
+ num_inference_steps = 8
+ seed = 69420
+ 
+ guidance_scale = 8
+ guidance_values = "1, 8, 20"
+ intermediate = True
+ pokeX, pokeY = 256, 256
+ pokeHeight, pokeWidth = 128, 128
+ imageHeight, imageWidth = 512, 512
+ 
+ tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(
+     torch_device
+ )
+ 
+ if isLCM:
+     scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
+ else:
+     scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
+ 
+ unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(
+     torch_device
+ )
+ vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(torch_device)
+ 
+ pipe = StableDiffusionPipeline(
+     tokenizer=tokenizer,
+     text_encoder=text_encoder,
+     unet=unet,
+     scheduler=scheduler,
+     vae=vae,
+     safety_checker=None,
+     feature_extractor=None,
+     requires_safety_checker=False,
+ ).to(torch_device)
+ 
+ dash_tunnel = setup_tunnel("0.0.0.0", 8000, secrets.token_urlsafe(32))
+ 
+ __all__ = [
+     "prompt",
+     "negative_prompt",
+     "num_images",
+     "degree",
+     "perturbation_size",
+     "num_inference_steps",
+     "seed",
+     "intermediate",
+     "pokeX",
+     "pokeY",
+     "pokeHeight",
+     "pokeWidth",
+     "promptA",
+     "promptB",
+     "tokenizer",
+     "text_encoder",
+     "scheduler",
+     "unet",
+     "vae",
+     "torch_device",
+     "imageHeight",
+     "imageWidth",
+     "guidance_scale",
+     "guidance_values",
+     "HF_ACCESS_TOKEN",
+     "model_path",
+     "inpaint_model_path",
+     "dash_tunnel",
+     "pipe",
+ ]
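A hedged aside, not part of this commit: the same checkpoint can be loaded in one call with from_pretrained, whereas this file assembles tokenizer, text encoder, UNet, VAE, and scheduler separately because the demo tabs drive those components directly in their own denoising loops. Downloading the segmind/small-sd weights is assumed to be possible.

```python
import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    "segmind/small-sd", safety_checker=None, requires_safety_checker=False
).to(device)

# quick sanity generation with the same defaults as above
image = pipe(
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    num_inference_steps=8,
    guidance_scale=8,
    generator=torch.Generator().manual_seed(69420),
).images[0]
image.save("sanity_check.png")
```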