Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,117 Bytes
04ef268 7fd88e0 04ef268 4fe7eec 04ef268 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import spaces
import gradio as gr
from src.util.base import *
from src.util.params import *
from PIL import Image, ImageDraw
def _open_or_blank(path, size):
    """Return the image stored at ``path``, or a new blank RGB image of ``size``.

    Only a missing file triggers the fallback; any other error raised while
    opening the file (e.g. a corrupt image) still propagates to the caller.
    """
    try:
        return Image.open(path)
    except FileNotFoundError:
        return Image.new("RGB", size)


def visualize_poke(
    pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth
):
    """Outline the poke region on the original and poked output images.

    The poke centre and size are scaled by 8 to obtain pixel coordinates
    (presumably they are given in latent-space cells -- confirm against the
    caller). A Gradio warning is raised in the UI when the region extends
    beyond the image, but the rectangle is still drawn.

    Args:
        pokeX: Horizontal centre of the poke region (1/8-pixel units).
        pokeY: Vertical centre of the poke region (1/8-pixel units).
        pokeHeight: Height of the poke region (1/8-pixel units).
        pokeWidth: Width of the poke region (1/8-pixel units).
        imageHeight: Image height in pixels; used for the bounds check and
            for the blank fallback canvas.
        imageWidth: Image width in pixels; used for the bounds check and
            for the blank fallback canvas.

    Returns:
        A ``(original, poked)`` tuple of PIL images, each with the poke
        region outlined in white. An image whose file under ``outputs/`` does
        not exist is replaced by a blank canvas of the configured size.
    """
    if (
        (pokeX - pokeWidth // 2 < 0)
        or (pokeX + pokeWidth // 2 > imageWidth // 8)
        or (pokeY - pokeHeight // 2 < 0)
        or (pokeY + pokeHeight // 2 > imageHeight // 8)
    ):
        gr.Warning("Modification outside image")

    # Top-left and bottom-right corners of the rectangle, in pixels.
    shape = [
        (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
        (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
    ]

    # Load each file independently: "outputs/poked.png" is only written when a
    # poke was actually rendered, so the existence of "outputs/original.png"
    # says nothing about it. Each fallback is its own canvas, so the two
    # returned images never alias each other.
    size = (imageWidth, imageHeight)
    oImg = _open_or_blank("outputs/original.png", size)
    pImg = _open_or_blank("outputs/poked.png", size)

    for img in (oImg, pImg):
        ImageDraw.Draw(img).rectangle(shape, outline="white")

    return oImg, pImg
@spaces.GPU()
def display_poke_images(
    prompt,
    seed,
    num_inference_steps,
    poke=False,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    intermediate=False,
    progress=gr.Progress(),
):
    """Render the image(s) for a prompt, optionally with poked latents.

    The unmodified latents are always rendered. When ``poke`` is truthy the
    modified latents returned by ``generate_modified_latents`` are rendered
    as well. Unless ``intermediate`` is truthy, each result is written to
    ``outputs/original.png`` / ``outputs/poked.png`` via its ``save`` method.

    Args:
        prompt: Text passed to ``get_text_embeddings``.
        seed: Seed forwarded to ``generate_modified_latents``.
        num_inference_steps: Step count forwarded to ``generate_images``.
        poke: Whether to also render the modified latents.
        pokeX, pokeY, pokeHeight, pokeWidth: Poke region, forwarded
            unchanged to ``generate_modified_latents``.
        intermediate: Forwarded to ``generate_images``; when truthy nothing
            is saved to disk.
        progress: Progress callback, called with 0 before the first render
            and with 0.5 before the second one.

    Returns:
        ``(original, poked)`` as returned by ``generate_images``; ``poked``
        is ``None`` when ``poke`` is falsy.
    """
    embeddings = get_text_embeddings(prompt)
    base_latents, poked_latents = generate_modified_latents(
        poke, seed, pokeX, pokeY, pokeHeight, pokeWidth
    )

    # First pass: the unmodified latents.
    progress(0)
    original = generate_images(
        base_latents, embeddings, num_inference_steps, intermediate=intermediate
    )
    if not intermediate:
        original.save("outputs/original.png")

    # Without a poke there is nothing more to render.
    if not poke:
        return original, None

    # Second pass: the poked latents.
    progress(0.5)
    poked = generate_images(
        poked_latents,
        embeddings,
        num_inference_steps,
        intermediate=intermediate,
    )
    if not intermediate:
        poked.save("outputs/poked.png")

    return original, poked
# Public API: the names exported by a star-import of this module.
__all__ = ["display_poke_images", "visualize_poke"]
|