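# Gradio Space: categorical-conditioning ControlNet on Waifu Diffusion 1.5 beta 2,
# served with the Flax/JAX Stable Diffusion ControlNet pipeline from diffusers.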
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import (
    FlaxControlNetModel,
    FlaxDPMSolverMultistepScheduler,
    FlaxStableDiffusionControlNetPipeline,
)
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import torchvision.transforms as T
#from torchvision.transforms import v2 as T2

output_res = (768,768)

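# Conditioning-image preprocessing: random-crop (with symmetric padding if the
# input is smaller) to the working resolution, then tensor conversion and
# [-1, 1] normalization.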
conditioning_image_transforms = T.Compose(
    [
        #T2.ScaleJitter(target_size=output_res, scale_range=(0.5, 3.0))),
        T.RandomCrop(size=output_res, pad_if_needed=True, padding_mode="symmetric"),
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),
    ]
)

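# Flax from_pretrained calls return a (module, params) pair; the ControlNet
# parameters are merged into the pipeline's parameter tree under "controlnet".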
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2",
    controlnet=controlnet,
    dtype=jnp.bfloat16,
)
params["controlnet"] = controlnet_params

# diffusers has no Flax port of UniPCMultistepScheduler; DPM-Solver++ multistep is
# assumed here as the closest Flax equivalent.
pipe.scheduler = FlaxDPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
params["scheduler"] = pipe.scheduler.create_state()

# Replicate the parameters across local devices for pmapped (jit=True) inference.
p_params = replicate(params)

# JAX PRNG key, seeded once at startup.
rng = jax.random.PRNGKey(0)

# Inference: encode the prompt/negative prompt, preprocess the conditioning image,
# and sample a small batch with the pmapped Flax pipeline.
def infer(prompt, negative_prompt, image):
    # One sample per local device keeps the shard() splits even; raise the
    # multiplier to generate more images per device.
    num_samples = jax.device_count()

    inp = Image.fromarray(image)

    cond_input = conditioning_image_transforms(inp)
    # Undo the [-1, 1] normalization before converting back to PIL;
    # prepare_image_inputs expects an un-normalized image.
    cond_input = T.ToPILImage()(cond_input * 0.5 + 0.5)

    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    cond_in = pipe.prepare_image_inputs([cond_input] * num_samples)

    # Advance the global PRNG state so successive requests get fresh noise.
    global rng
    rng, sample_rng = jax.random.split(rng)

    # jit=True runs the pmapped generator, so inputs get a leading device axis
    # and the PRNG key is split once per device.
    rngs = jax.random.split(sample_rng, jax.device_count())
    output = pipe(
        prompt_ids=shard(prompt_ids),
        image=shard(cond_in),
        params=p_params,
        prng_seed=rngs,
        neg_prompt_ids=shard(negative_prompt_ids),
        num_inference_steps=20,
        jit=True,
    ).images

    # Flatten the (devices, per-device batch, H, W, C) array and convert to PIL
    # images for the gallery.
    return pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))

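# Gradio front end: prompt and negative-prompt text boxes plus a conditioning
# image input, with the generated samples shown in a two-column gallery.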
gr.Interface(
    infer,
    inputs=[
        gr.Textbox(
            label="Enter prompt",
            max_lines=1,
            placeholder="1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
        ),
        gr.Textbox(
            label="Enter negative prompt",
            max_lines=1,
            placeholder="low quality",
        ),
        gr.Image(),
    ],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title="Generate controlled outputs with Categorical Conditioning on Waifu Diffusion 1.5 beta 2.",
    description="This Space uses image examples as style conditioning.",
    examples=[["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "low quality", "wikipe_cond_1.png"]],
    allow_flagging="never",
).launch(enable_queue=True)