import os

import jax
import jax.numpy as jnp
import numpy as np
import gradio as gr
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

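# Colab-style TPU setup: ask the TPU VM to switch to the nightly tpu_driver,
# then point JAX's XLA backend at the address in the TPU_NAME variable.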
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME looks like 'grpc://10.0.0.2:8470'; splitting on ':' yields
        # '//10.0.0.2', so this builds 'http://10.0.0.2:8475/requestversion/...'.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. It can be enabled under "Runtime > Change runtime type".')

# Report the accelerator devices visible to JAX.
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")

def sd2_inference(pipeline, prompts, params, seed=42, num_inference_steps=50):
    """Run Stable Diffusion inference data-parallel across all JAX devices."""
    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = pipeline.prepare_inputs(prompts)
    # Copy the parameters to every device, give each device its own PRNG key
    # and shard of the prompt batch, then run the jitted/pmapped pipeline.
    params = replicate(params)
    prng_seed = jax.random.split(prng_seed, jax.device_count())
    prompt_ids = shard(prompt_ids)
    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Merge the (devices, batch_per_device) leading axes back into one batch.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    images = pipeline.numpy_to_pil(images)
    return images
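# Usage sketch (assumes the pipeline/params loaded below). With N devices, a
# batch of N prompts renders one image per device in a single call:
#
#   images = sd2_inference(pipeline, ["apple"] * jax.device_count(), params,
#                          seed=0, num_inference_steps=50)
#   images[0].save("apple.png")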



HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load Model
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=HF_ACCESS_TOKEN,
    revision="bf16",
    dtype=jnp.bfloat16,
)
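# The "bf16" revision ships bfloat16 weights, which halve memory use and match
# TPU-native arithmetic; on CPU/GPU the default revision with jnp.float32 is
# the safer choice.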

# Example prompt batch; the Gradio handler below builds its own per request.
prompts = ["apple"] * 1


# BigGAN helper, not wired into the Gradio app below. It depends on PyTorch
# and, presumably, the pytorch_pretrained_biggan package for the two sampling
# utilities; `gan_model` is assumed to be a BigGAN instance loaded elsewhere.
import torch
from pytorch_pretrained_biggan import truncated_noise_sample, one_hot_from_int

def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector)
    if int_index is not None:
        # Look up the dense class embedding from an ImageNet class index.
        class_vector = one_hot_from_int([int_index], batch_size=1)
        class_vector = torch.from_numpy(class_vector)
        dense_class_vector = gan_model.embeddings(class_vector)
    else:
        if isinstance(dense_class_vector, np.ndarray):
            dense_class_vector = torch.tensor(dense_class_vector)
        dense_class_vector = dense_class_vector.view(1, 128)

    input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)

    # Generate an image
    with torch.no_grad():
        output = gan_model.generator(input_vector, truncation)
    output = output.cpu().numpy()
    output = output.transpose((0, 2, 3, 1))  # NCHW -> NHWC
    # Map [-1, 1] activations to [0, 255] and convert to uint8 pixels.
    output = ((output + 1.0) / 2.0) * 256
    output.clip(0, 255, out=output)
    output = np.asarray(np.uint8(output[0]), dtype=np.uint8)
    return output
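# Hypothetical wiring for the helper above, assuming pytorch_pretrained_biggan:
#
#   from pytorch_pretrained_biggan import BigGAN
#   gan_model = BigGAN.from_pretrained('biggan-deep-128')
#   img = generate_image(int_index=954)  # ImageNet class 954 = banana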



def text_to_image(text):
    # Only 5 denoising steps to keep the demo snappy; raise for better quality.
    images = sd2_inference(pipeline, [text], params, seed=42, num_inference_steps=5)
    return images[0]

examples = ["apple",
            "banana",
            "chocolate"]

if __name__ == '__main__':
    interFace = gr.Interface(fn=text_to_image,
                             inputs=gr.inputs.Textbox(placeholder="Enter the text to encode to an image",
                                                      label="Text query",
                                                      lines=1),
                             outputs=gr.outputs.Image(type="auto", label="Generated Image"),
                             verbose=True,
                             examples=examples,
                             title="Generate Image from Text",
                             description="",
                             theme="huggingface")
    interFace.launch()