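"""Gradio demo: stream intermediate previews while a diffusion pipeline
generates an image. A background thread runs the pipeline and publishes a
cheap RGB approximation of the latents at a configurable step interval;
the foreground generator yields those previews to the UI."""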
import gradio as gr
import spaces
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
import threading
from PIL import Image
MODEL_ID = "cagliostrolab/animagine-xl-3.1"
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
else:
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, use_safetensors=True)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1536
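
# Cheap latent -> RGB preview: a fixed linear projection (per-channel
# weights plus biases) approximates what the VAE decoder would produce,
# without the cost of a full decode. The coefficients below are the
# commonly used approximation for the SDXL latent space.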
def latents_to_rgb(latents):
    weights = (
        (60, -60, 25, -70),
        (60, -5, 15, -50),
        (60, 10, -5, -35),
    )
    weights_tensor = torch.tensor(weights, dtype=latents.dtype, device=latents.device).T
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype, device=latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.view(-1, 1, 1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # CHW -> HWC for PIL
    pil_image = Image.fromarray(image_array)
    # Upscale the low-resolution latent preview 2x for display
    resized_image = pil_image.resize((pil_image.size[0] * 2, pil_image.size[1] * 2), Image.LANCZOS)
    return resized_image
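
# Producer/consumer bridge between the diffusion pipeline and Gradio:
# the pipeline runs in a worker thread and publishes preview images via
# threading.Event flags, while show_images() yields them to the
# streaming output until generation completes.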
class BaseGenerator:
    def __init__(self, pipe):
        self.pipe = pipe
        self.image = None
        self.new_image_event = threading.Event()
        self.generation_finished = threading.Event()
        self.intermediate_image_concurrency(3)

    def intermediate_image_concurrency(self, concurrency):
        # Decode a preview every `concurrency` denoising steps
        self.concurrency = concurrency

    def decode_tensors(self, pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]
        if step % self.concurrency == 0:
            print(f"Preview at step {step}")
            self.image = latents_to_rgb(latents)
            self.new_image_event.set()  # Signal that a new image is available
        return callback_kwargs
    def show_images(self):
        while not self.generation_finished.is_set() or self.new_image_event.is_set():
            self.new_image_event.wait()  # Wait for a new image
            self.new_image_event.clear()  # Clear the event flag
            if self.image is not None:
                yield self.image  # Yield the new image
    def generate_images(self, **kwargs):
        if kwargs.get("randomize_seed", False):
            kwargs["seed"] = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(int(kwargs["seed"]))
        self.image = None
        self.image = self.pipe(
            height=kwargs["height"],
            width=kwargs["width"],
            prompt=kwargs["prompt"],
            negative_prompt=kwargs["negative_prompt"],
            guidance_scale=kwargs["guidance_scale"],
            num_inference_steps=kwargs["num_inference_steps"],
            generator=generator,
            callback_on_step_end=self.decode_tensors,
            callback_on_step_end_tensor_inputs=["latents"],
        ).images[0]
        print("Generation finished")
        self.new_image_event.set()  # Publish the final image
        self.generation_finished.set()  # Signal that generation is finished
    def stream(self, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
        self.generation_finished.clear()
        threading.Thread(target=self.generate_images, kwargs=dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            randomize_seed=randomize_seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )).start()
        return self.show_images()
image_generator = BaseGenerator(pipe)
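
# @spaces.GPU requests a GPU for the duration of the call when the app
# runs on Hugging Face Spaces (ZeroGPU); it has no effect elsewhere.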
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, concurrency):
    image_generator.intermediate_image_concurrency(concurrency)
    stream = image_generator.stream(
        prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
    )
    yield None  # Clear the previous result while the new run starts
    for image in stream:
        yield image
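
# Gradio UI: prompt inputs, streaming result image, and advanced sliders.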
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
if torch.cuda.is_available():
    power_device = "GPU"
else:
    power_device = "CPU"
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Text-to-Image: Display each generation step
        Gradio template for displaying preview images during the generation steps.
        Currently running on {power_device}.
        """)
        prompt = gr.Text(
            label="Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
            container=False,
            value="1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer, outdoors, night",
        )
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=True,
            value="nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
        )
        with gr.Row():
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=832,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1216,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=76,
                )
            concurrency_gui = gr.Slider(
                label="Number of steps between preview images",
                minimum=1,
                maximum=20,
                step=1,
                value=3,
            )
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, concurrency_gui],
        outputs=[result],
        show_progress="minimal",
    )

demo.queue().launch()