import torch
import spaces
import numpy as np
import gradio as gr
from src.util.base import *
from src.util.params import *

@spaces.GPU()
def display_perturb_images(
    prompt,
    seed,
    num_inference_steps,
    num_images,
    perturbation_size,
    progress=gr.Progress(),
):
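    """Generate a base image and `num_images` perturbed variants of it.

    The base latents come from `seed`; each variant blends them with a fresh
    random latent direction, with `perturbation_size` controlling how far the
    blend moves from the base. Returns the (image, label) pairs for the gallery
    together with the path of a ZIP export of the images and run settings.
    """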
    text_embeddings = get_text_embeddings(prompt)

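    # Base latents for the user seed, pre-scaled by cos(theta) over a range of
    # perturbation angles theta in [0, pi * perturbation_size / 2].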
    latents_x = generate_latents(seed)
    scale_x = torch.cos(
        torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
    ).to(torch_device)
    noise_x = torch.tensordot(scale_x, latents_x, dims=0)

    progress(0)
    images = []
    # The unperturbed base image goes first in the gallery, labeled "1".
    images.append(
        (generate_images(latents_x, text_embeddings, num_inference_steps), "1")
    )

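    # Each perturbation blends the base latents with a fresh random direction,
    #   cos(theta) * latents_x + sin(theta) * latents_y,
    # evaluated at the largest angle theta = pi * perturbation_size / 2.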
    for i in range(num_images):
        # Seed NumPy with the loop index so the random directions are reproducible.
        np.random.seed(i)
        progress(i / num_images)
        latents_y = generate_latents(np.random.randint(0, 100000))
        scale_y = torch.sin(
            torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
        ).to(torch_device)
        noise_y = torch.tensordot(scale_y, latents_y, dims=0)

        # Decode the blended latents at the largest perturbation angle.
        noise = noise_x + noise_y
        image = generate_images(
            noise[num_images - 1], text_embeddings, num_inference_steps
        )
        images.append((image, str(i + 2)))

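    # Bundle the gallery images together with the run settings into a ZIP
    # archive that the UI can offer for download.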
    fname = "perturbations"
    tab_config = {
        "Tab": "Perturbations",
        "Prompt": prompt,
        "Number of Perturbations": num_images,
        "Perturbation Size": perturbation_size,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)

    return images, f"outputs/{fname}.zip"


__all__ = ["display_perturb_images"]
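

# Minimal usage sketch for local testing. This wiring is an illustrative
# assumption: the component labels and default values below are hypothetical,
# and the real app assembles its own Gradio layout elsewhere.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=display_perturb_images,
        inputs=[
            gr.Textbox(label="Prompt", value="a photograph of an astronaut riding a horse"),
            gr.Number(label="Seed", value=42, precision=0),
            gr.Number(label="Number of Inference Steps per Image", value=20, precision=0),
            gr.Number(label="Number of Perturbations", value=4, precision=0),
            gr.Slider(0.0, 1.0, value=0.1, step=0.05, label="Perturbation Size"),
        ],
        outputs=[
            gr.Gallery(label="Perturbed Images"),
            gr.File(label="Download as ZIP"),
        ],
    )
    demo.launch()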