Spaces:
Paused
Paused
| import os | |
| import spaces | |
| import gradio as gr | |
| from src.util.base import * | |
| from src.util.params import * | |
| from PIL import Image, ImageDraw | |
| def visualize_poke( | |
| pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth | |
| ): | |
| if ( | |
| (pokeX - pokeWidth // 2 < 0) | |
| or (pokeX + pokeWidth // 2 > imageWidth // 8) | |
| or (pokeY - pokeHeight // 2 < 0) | |
| or (pokeY + pokeHeight // 2 > imageHeight // 8) | |
| ): | |
| gr.Warning("Modification outside image") | |
| shape = [ | |
| (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2), | |
| (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2), | |
| ] | |
| blank = Image.new("RGB", (imageWidth, imageHeight)) | |
| if os.path.exists("outputs/original.png"): | |
| oImg = Image.open("outputs/original.png") | |
| pImg = Image.open("outputs/poked.png") | |
| else: | |
| oImg = blank | |
| pImg = blank | |
| oRec = ImageDraw.Draw(oImg) | |
| pRec = ImageDraw.Draw(pImg) | |
| oRec.rectangle(shape, outline="white") | |
| pRec.rectangle(shape, outline="white") | |
| return oImg, pImg | |
| def display_poke_images( | |
| prompt, | |
| seed, | |
| num_inference_steps, | |
| poke=False, | |
| pokeX=None, | |
| pokeY=None, | |
| pokeHeight=None, | |
| pokeWidth=None, | |
| intermediate=False, | |
| progress=gr.Progress(), | |
| ): | |
| text_embeddings = get_text_embeddings(prompt) | |
| latents, modified_latents = generate_modified_latents( | |
| poke, seed, pokeX, pokeY, pokeHeight, pokeWidth | |
| ) | |
| progress(0) | |
| images = generate_images( | |
| latents, text_embeddings, num_inference_steps, intermediate=intermediate | |
| ) | |
| if not intermediate: | |
| images.save("outputs/original.png") | |
| if poke: | |
| progress(0.5) | |
| modImages = generate_images( | |
| modified_latents, | |
| text_embeddings, | |
| num_inference_steps, | |
| intermediate=intermediate, | |
| ) | |
| if not intermediate: | |
| modImages.save("outputs/poked.png") | |
| else: | |
| modImages = None | |
| return images, modImages | |
| __all__ = ["display_poke_images", "visualize_poke"] | |