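# Gradio Space: categorical-conditioning ControlNet on Waifu Diffusion 1.5
# beta 2, running on the Flax/JAX diffusers pipeline. The conditioning image
# is cropped to the training resolution and four samples are generated per
# request.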
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from PIL import Image
import torchvision.transforms as T
#from torchvision.transforms import v2 as T2
output_res = (768, 768)

# Crop (padding reflectively if needed) the conditioning image to the
# resolution the ControlNet was trained at. The crop runs directly on the
# PIL image; pipe.prepare_image_inputs does its own tensor conversion and
# normalization later, so the ToTensor/Normalize/ToPILImage round-trip
# (which corrupted values mapped below zero) is dropped.
conditioning_image_transforms = T.Compose(
    [
        #T2.ScaleJitter(target_size=output_res, scale_range=(0.5, 3.0)),
        T.RandomCrop(size=output_res, pad_if_needed=True, padding_mode="symmetric"),
    ]
)
# Flax from_pretrained returns (model, params): the weights live in a
# separate pytree, not inside the module.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd", dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2",
    controlnet=controlnet,
    dtype=jnp.bfloat16,
)
params["controlnet"] = controlnet_params

# Note: UniPCMultistepScheduler is PyTorch-only, so the pipeline keeps the
# Flax scheduler it was saved with rather than swapping one in.

# Replicate the parameters across all local devices once, up front; the
# jitted pipeline pmaps over the leading device axis.
p_params = replicate(params)

num_samples = 4
rng = jax.random.PRNGKey(0)
# Inference function: prompt, negative prompt, and conditioning image in;
# a batch of num_samples generated images out.
def infer(prompt, negative_prompt, image):
    global rng
    # Advance the PRNG state so repeated calls give different samples,
    # then split one key per device for the pmapped sampler.
    rng, seed = jax.random.split(rng)
    prng_seed = jax.random.split(seed, jax.device_count())

    cond_input = conditioning_image_transforms(Image.fromarray(image))

    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    neg_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    cond_in = pipe.prepare_image_inputs([cond_input] * num_samples)

    # With jit=True the pipeline expects inputs sharded across devices
    # (num_samples must be divisible by jax.device_count()).
    prompt_ids = shard(prompt_ids)
    neg_prompt_ids = shard(neg_prompt_ids)
    cond_in = shard(cond_in)

    output = pipe(
        prompt_ids=prompt_ids,
        image=cond_in,
        params=p_params,
        prng_seed=prng_seed,
        neg_prompt_ids=neg_prompt_ids,
        num_inference_steps=20,
        jit=True,
    ).images

    # (devices, batch/device, H, W, 3) -> flat batch of PIL images.
    output = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(output)
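# Wire the inference function into the Gradio UI: prompt and negative-prompt
# text boxes, an image upload for the conditioning input, and a two-column
# gallery for the generated batch.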
gr.Interface(
infer,
inputs=[
gr.Textbox(
label="Enter prompt",
max_lines=1,
placeholder="1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
),
gr.Textbox(
label="Enter negative prompt",
max_lines=1,
placeholder="low quality",
),
gr.Image(),
],
outputs=gr.Gallery().style(grid=[2], height="auto"),
title="Generate controlled outputs with Categorical Conditioning on Waifu Diffusion 1.5 beta 2.",
description="This Space uses image examples as style conditioning.",
examples=[["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "low quality", "wikipe_cond_1.png"]],
    allow_flagging="never",
).launch(enable_queue=True)