|
from contextlib import nullcontext |
|
from io import BytesIO |
|
import os |
|
import threading |
|
from typing import Optional, Union |
|
import warnings |
|
|
|
from compel import Compel |
|
from fastapi.responses import StreamingResponse |
|
from loguru import logger |
|
from PIL import Image |
|
import torch |
|
|
|
from leptonai.photon import Photon, FileParam, get_file_content, HTTPException |
|
|
|
|
|
EXAMPLE_IMAGE_BASE64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQEBANDxIQEA8PDw8PDxUPEg8NDxUPFRIWFhURFRYYHSggGBolGxUVITEhJSkrLi4uFx8zODMsNygtLisBCgoKDg0OGBAQFysfHx8tKy4tKy0tKystLS0rKy0tLSstNy4tLy0tLS0tKy0tLSsrLS0rLS0tLS0tLS0rKzctK//AABEIAOEA4QMBEQACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQMCBAYHBf/EAEAQAQACAQIBCAUIBwkBAAAAAAABAgMEETEFBhIhQXGRoVFhgbHBBxMyQ1JyktEVIkJic4LhJFNjk6KywuLwFP/EABoBAQEAAwEBAAAAAAAAAAAAAAABAgMFBAb/xAAtEQEAAgIBAgMIAQUBAAAAAAAAAQIDEQQSUSFBkQUTIjFCUmFxMiMzgaHBFP/aAAwDAQACEQMRAD8A9uBIJBIAAAAAAAAAAAAAAAAAAAAAAAAAMAZQACQAAAAAAAAAAAAAAAAAAAAAAAAAYgmASAAAAAAAAAAAAAAAAAAAAAAAAAACASAAAAAAAAAAAAAAAAAAAAAAAAAAACIBIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI3BIAAAI3BIAAAAAAAAAAAAAAAAAAOd5b546XSV6V5m28zEdHqrMx2RPb7Gi2esfLxerHxL2+fg4/V/K12YcEd95mfyaLcq3lD2U9n087S+ff5S9bb6PzdO6sT792qeTk7t9eBh7f7Uzz311/rZjuise6Guc2T7m6OHhj6YTHOXWW458vstaPixnJefOfVn/AObFH0x6M45Z1E8cuSe+1p+KdVu7L3NPtj0ZRypm/vLeMnVPc91TtHon9KZvt28U6p7r7qnZH6YzxwvbxlOue6+5p2j0ZRzg1NeGXJ+O8fFfeWjzljPHxT9Mei7Hzy1dP25mPXMW98M45GSPNrtwsM/S+7yTz1yXibZOj0YjebXjoV7otHVu9FOTfzeLLwMcfKdO20Oqrmx0zV+jkrFo7pe+s7jblXr02mOy9WIAAAAAAAAAAAAAADxzVx0t8d60yUi07VyVi8Rt2xvwn1uRaZiZ0+kpWJiNvnX5E00/VTTf7F7x5WmYhh1tnR+WMc3sP7Ns0d847f8AGE6mWpW15CrHDJf21ifieC7lbXkiI+sn8H/Y8DcrY0ER+3P4P6htP/yx9qfCPzDaJwR+95QmoNyqtjr6J9sx+RqF3Km+32Y9s3/M0m5a+TLaOG0d1axPjtuyhjLXyTaZibTNp/emZlshqs9t5qzvotN/Bp7nTx/xhwM/9y37fVZtQAAAAAAAAAAAAAADyLW02y5I9GS8eFpcjJHjL6TDO6R+mFatTcsrHqFZxHqETt6lETt6BFdkVXZRReBWvkgRq3hlEMZa8x1s4arPbeasf2LS/wAGnudPH/GHBz/3Lft9Vm1AAAAAAAAAAAAAAAPKeWqdHU56+jNkn2TaZj3uVljV5fQ8ad46/pr0lpelbWUGSiYkRjaQVyiqrqKbg1sqjVuyhjKiI62cNcvcuQMfQ0umrPGMGLfv6EbunSNVh89knd5n8t9kwAAAAAAAAAAAAAAAcZzv5vTM31uOY22i2Ws9XCIjpVnu26p9c7vJnwb+KHR4fK6dY7f4ch1xxie/jHjweGay68XiWdLwx0yWxYDcETIMJkFdpXRtr3tBo2pms24RM90TK9MsZlVOnt27V75jfw4s4rLGbQ6fmLzew6i98mXpXrh6G0fRpa1t+qe2Yjbh1cXqwY4nxlzuZntTVa+b02IexykgAAAAAAAiASAAAAAACrV4IyY74p4Xpak91omPikxu
NLWdTEvGsmG0TtO8THVPfDl2nT6KmpjwZRW/pme/rY9TZ0soi3q8ITa6Zb29Eea7TUotafRH+r802uvywm0+rzNmmFrT6vwx8V2aYTktHbt3bR7jZ0qMl7TxmZ79zadMKpiZ4yyiWMxp6f8AJzp+jpLX7cma0x3RER74l78EfC4vNtvJrs6tueQAAAAAAABiCYBIAAAAAAPMucOl6GqzV7JvN47r/rfFzc0avLu8S3Viq0a1aHsZxRA6AMJoKwnGIptSFVTesKjXuqSq7WUMLPYeamDoaLT19OOL/jmbfF0scarD5/PbqyWl9Zm1AAAAAAAAMATAJBIAAAAAOJ584Ns2PJ9vH0fbWfytHg8XKr4xLq+z7fDMdnOQ8bpwziUVEyCJBhbYFGSYUa95Ua2SVSWOKs2tFY4zMRHfPVDOsbnTVedRMvccGKKUrSOFK1rHdEbOpD52Z3O1ggAAAAAAADCATAJgEgAAAAA57ntp+lp63jjjyRv923V7+i8/Irum3s4NtZNd3DOdLt1TFkZG67XTGZQYWlYFNga2SVGveVhjLf5sYPnNZp6f4tbT3V/Wnyq3Yo3aHk5VunHZ7M6LhAAAAAAAAAMIBIJgEgAAAAA0+WNP85gy4+2cduj96OuvnEMbxusw2YrdN4l5jXrcqYfRVkmrBmgUkFdwUXkhWtkVGteWUMZdL8nmKJ1nTnhixXt7Z2r7rS9XGj4nO59v6eu8vUIyw9rkMotAJAAAAAAABWCYkE7gkEgAAAAiQeW8paf5rPlx8IrktEfd36vLZy8katMPoePbqpWVUW9TU3onZGSJQV2lRr5BWtkWEa2RnDCXU8xabRmyemaUj2bzPvh7ONHzlyudPjEOux5p9L1Q5+mzj1NlRtY9QiNmmUFkWBIAAAAKwSCQSCQAAAAAcFz20/R1EZNurLSJ/mr+rPlFXh5Nfi33dj2ffdNdnway8bpMoBEgrsCi6jVyKNXIyhhLtuaeLo6as/bte8+O0eVYdDDGquLyrbyT+H3sdW6HlbGOqo2cdRi2KQguqC2ASAAADAE7AAkEgAAAAA5vnxpelgrljjiv1/dt1T5xVo5Fd132e3g36cmu7hYlzZd2GcSiomQYWkGvlso1MsqjTy2/ozrG2q9tRt6byLp9sWOkcK0rXwjrdSldQ4OS27TL7OLTMmqZbNNObYr64kGcUBnFQZAAAAAxAgEgkAAAAAAFWow1vW1LxFq2ia2ie2JNbWJ1O4cbr+Zt95nT5KzHZTNvWY/nrE7+2Pa8t+LE/wAZ06eL2hMeF42+Vl5B1dOOC0x6aTTJ5RO/k888XJD1152GfPTUyaPPHHBqP8jNMf7WucN/tlujkYZ+uPVRbT5uzDqJ7sGaZ8qnub/bK+/xffHqxryVq7/R02o/mxXx+d4iGcYMk+TXPMwx9X/V+Hmdyhk+rx4fXmy14emIx9Lw6m2vFt5vPf2hjj5RMvucj/J5Sloy6nLOe9Z3ita/N4Yn09HeZn2z7IevHhpTxeDLy75PDydpg0laxtENm3l2vikIidgSAAAAAACN/wD3UBsCQAAAAAAAAQCJgU2U2bAAAbAmIREgAAAAAAAAAx6/V4gyAAAAAAAAAAAABGwGwGwJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB/9k=" |
|
|
|
|
|
class JPEGResponse(StreamingResponse):
    """Streaming HTTP response labelled with the ``image/jpeg`` media type.

    Only the class-level ``media_type`` default is overridden; construction and
    streaming behaviour are inherited unchanged from ``StreamingResponse``.
    """

    media_type = "image/jpeg"
|
|
|
|
|
class ImgPilot(Photon):
    """Image-to-image generation photon backed by a diffusers pipeline.

    Configuration comes from environment variables (MODEL, USE_TORCH_COMPILE,
    WIDTH, HEIGHT, PRINT_PROMPT). When a variable is not set, the value declared
    in ``deployment_template["env"]`` is used. The ``run`` handler takes a
    prompt plus a source image and streams back the generated image as JPEG.
    """

    requirement_dependency = [
        "torch",
        "diffusers",
        "invisible-watermark",
        "compel",
        "Pillow",
    ]

    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {
            "MODEL": "SimianLuo/LCM_Dreamshaper_v7",
            "USE_TORCH_COMPILE": "false",
            "WIDTH": "768",
            "HEIGHT": "768",
            "PRINT_PROMPT": "false",
        },
    }

    # When > 1, ``init`` guards the pipeline with a real threading.Lock and
    # refuses to enable torch.compile.
    handler_max_concurrency = 1

    def init(self):
        """Load the pipeline and prepare per-process state.

        Sets ``device``, ``base`` (the pipeline), ``base_lock``, ``print_prompt``,
        ``use_torch_compile``, ``compel_proc`` and, when torch.compile is active,
        the fixed ``width``/``height`` the compiled model is used with.
        """
        from diffusers import AutoPipelineForImage2Image

        env_defaults = self.deployment_template.get("env", {})

        def env(name: str) -> str:
            # The real environment wins; the template value is the fallback so a
            # plain local launch works. Still raises KeyError if neither has it.
            if name in os.environ:
                return os.environ[name]
            return env_defaults[name]

        def env_flag(name: str) -> bool:
            # Case-insensitive boolean parsing of an environment variable.
            return env(name).lower() in ("true", "t", "1", "yes", "y")

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")
        model_id = env("MODEL")

        self.base = AutoPipelineForImage2Image.from_pretrained(
            model_id,
            # float16 on GPU, float32 on CPU.
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        # Safety checking is disabled for this pipeline.
        self.base.safety_checker = None
        self.base.requires_safety_checker = False

        # Serialize pipeline calls only when requests may run concurrently;
        # otherwise a no-op context manager keeps ``run`` uniform.
        if self.handler_max_concurrency > 1:
            self.base_lock = threading.Lock()
        else:
            self.base_lock = nullcontext()

        self.print_prompt = env_flag("PRINT_PROMPT")
        logger.info(f"print_prompt: {self.print_prompt}")

        # torch.compile is only considered when CUDA is available.
        self.use_torch_compile = False
        if cuda_available:
            self.base.to(self.device)
            self.use_torch_compile = env_flag("USE_TORCH_COMPILE")
            if self.use_torch_compile and self.handler_max_concurrency > 1:
                warnings.warn(
                    "torch compile does not support multithreading, so we will"
                    " disable torch compile since handler_max_concurrency > 1."
                )
                # Actually honour the warning: otherwise ``run`` would take the
                # compiled path and read width/height, which are never set here.
                self.use_torch_compile = False
            elif self.use_torch_compile:
                self.width = int(env("WIDTH"))
                self.height = int(env("HEIGHT"))
                logger.info(
                    "Compiling model with torch.compile. Note that with torch"
                    " compile, your first invocation will be slow, but subsequent"
                    " invocations will be faster."
                )
                self.base.unet = torch.compile(
                    self.base.unet, mode="reduce-overhead", fullgraph=True
                )

        # Compel builds prompt embeddings without truncating long prompts; used
        # by ``run`` for prompts that exceed the tokenizer limit.
        self.compel_proc = Compel(
            tokenizer=self.base.tokenizer,
            text_encoder=self.base.text_encoder,
            truncate_long_prompts=False,
        )

        logger.info(f"Initialized model {model_id}. cuda: {cuda_available}.")

    @Photon.handler(
        "run",
        example={
            "prompt": (
                "Portrait of The Terminator, glare pose, detailed, intricate, full of"
                " colour, cinematic lighting, trending on artstation, 8k,"
                " hyperrealistic, focused, extreme details, unreal engine 5, cinematic,"
                " masterpiece"
            ),
            "seed": 2159232,
            "strength": 0.5,
            "steps": 4,
            "guidance_scale": 8.0,
            "width": 512,
            "height": 512,
            "lcm_steps": 50,
            "input_image": EXAMPLE_IMAGE_BASE64,
        },
    )
    def run(
        self,
        prompt: str,
        seed: int,
        strength: float,
        steps: int,
        guidance_scale: float,
        width: int,
        height: int,
        lcm_steps: int,
        input_image: Optional[Union[str, FileParam]],
    ) -> JPEGResponse:
        """Transform ``input_image`` according to ``prompt``.

        Args:
            prompt: Text prompt. Prompts longer than the tokenizer's
                ``model_max_length`` are embedded with Compel instead of being
                truncated.
            seed: Passed to ``torch.manual_seed`` before sampling.
            strength: Forwarded to the pipeline as ``strength``.
            steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``.
            width: Output width. It must equal WIDTH while torch.compile is on.
            height: Output height. It must equal HEIGHT while torch.compile is on.
            lcm_steps: Forwarded as ``lcm_origin_steps``.
            input_image: Source image, either as a string (the example passes
                base64) or as an uploaded ``FileParam``. It must be JPEG, PNG,
                GIF or BMP.

        Returns:
            A JPEG stream of the first generated image.

        Raises:
            HTTPException: 400 when ``input_image`` is missing or in an
                unsupported format, when width/height differ from the compiled
                resolution, or when the pipeline reports NSFW content.
        """
        from diffusers.utils import load_image
        import time

        start = time.time()

        if self.print_prompt:
            logger.info(f"Prompt: {prompt}")

        if input_image is None:
            # The img2img pipeline needs a source image; report a client error
            # instead of failing somewhere inside diffusers.
            raise HTTPException(status_code=400, detail="input_image is required.")

        if self.use_torch_compile and (width != self.width or height != self.height):
            # Presumably the compiled (fullgraph) UNet is specialised to one
            # resolution, so only the configured WIDTH/HEIGHT are accepted.
            raise HTTPException(
                status_code=400,
                detail=(
                    f"width and height must be {self.width} and"
                    f" {self.height} when use_torch_compile is true."
                ),
            )

        # Prompts that would be truncated by the tokenizer are converted to
        # embeddings with Compel; the pipeline then gets prompt_embeds only.
        tokenizer = self.base.tokenizer
        token_count = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
        if token_count > tokenizer.model_max_length:
            prompt_embeds = self.compel_proc(prompt)
            prompt = None
        else:
            prompt_embeds = None

        image_file = get_file_content(input_image, return_file=True)
        try:
            pil_image = Image.open(image_file, formats=["JPEG", "PNG", "GIF", "BMP"])
        except Image.UnidentifiedImageError as err:
            raise HTTPException(
                status_code=400,
                detail="input_image must be a JPEG, PNG, GIF or BMP image.",
            ) from err
        if self.use_torch_compile and pil_image.size != (self.width, self.height):
            pil_image = pil_image.resize((self.width, self.height), Image.BILINEAR)
        source_image = load_image(pil_image).convert("RGB")

        with self.base_lock:
            # Seeding and sampling happen under the same lock so concurrent
            # requests cannot reseed the global RNG in between.
            generator = torch.manual_seed(seed)
            result = self.base(
                prompt=prompt,
                prompt_embeds=prompt_embeds,
                generator=generator,
                image=source_image,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                lcm_origin_steps=lcm_steps,
                output_type="pil",
            )

        # The pipeline output may not contain this field (e.g. with the safety
        # checker disabled in ``init``), hence the membership test.
        nsfw_content_detected = (
            result.nsfw_content_detected[0]
            if "nsfw_content_detected" in result
            else False
        )
        if nsfw_content_detected:
            raise HTTPException(status_code=400, detail="nsfw content detected")

        img_io = BytesIO()
        result.images[0].save(img_io, format="JPEG")
        img_io.seek(0)
        logger.info(f"Produced output in {time.time() - start} seconds.")
        return JPEGResponse(img_io)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the photon and start it via its launcher.
    photon = ImgPilot()
    photon.launch()
|
|