# Source: tryemoji / photon / main.py
# Last commit 89682f8 by yadongxie: "feat: add web"
from contextlib import nullcontext
from io import BytesIO
import os
import threading
from typing import Optional, Union
import warnings
from compel import Compel
from fastapi.responses import StreamingResponse
from loguru import logger
from PIL import Image
import torch
from leptonai.photon import Photon, FileParam, get_file_content, HTTPException
EXAMPLE_IMAGE_BASE64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQEBANDxIQEA8PDw8PDxUPEg8NDxUPFRIWFhURFRYYHSggGBolGxUVITEhJSkrLi4uFx8zODMsNygtLisBCgoKDg0OGBAQFysfHx8tKy4tKy0tKystLS0rKy0tLSstNy4tLy0tLS0tKy0tLSsrLS0rLS0tLS0tLS0rKzctK//AABEIAOEA4QMBEQACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQMCBAYHBf/EAEAQAQACAQIBCAUIBwkBAAAAAAABAgMEETEFBhIhQXGRoVFhgbHBBxMyQ1JyktEVIkJic4LhJFNjk6KywuLwFP/EABoBAQEAAwEBAAAAAAAAAAAAAAABAgMFBAb/xAAtEQEAAgIBAgMIAQUBAAAAAAAAAQIDEQQSUSFBkQUTIjFCUmFxMiMzgaHBFP/aAAwDAQACEQMRAD8A9uBIJBIAAAAAAAAAAAAAAAAAAAAAAAAAMAZQACQAAAAAAAAAAAAAAAAAAAAAAAAAYgmASAAAAAAAAAAAAAAAAAAAAAAAAAACASAAAAAAAAAAAAAAAAAAAAAAAAAAACIBIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI3BIAAAI3BIAAAAAAAAAAAAAAAAAAOd5b546XSV6V5m28zEdHqrMx2RPb7Gi2esfLxerHxL2+fg4/V/K12YcEd95mfyaLcq3lD2U9n087S+ff5S9bb6PzdO6sT792qeTk7t9eBh7f7Uzz311/rZjuise6Guc2T7m6OHhj6YTHOXWW458vstaPixnJefOfVn/AObFH0x6M45Z1E8cuSe+1p+KdVu7L3NPtj0ZRypm/vLeMnVPc91TtHon9KZvt28U6p7r7qnZH6YzxwvbxlOue6+5p2j0ZRzg1NeGXJ+O8fFfeWjzljPHxT9Mei7Hzy1dP25mPXMW98M45GSPNrtwsM/S+7yTz1yXibZOj0YjebXjoV7otHVu9FOTfzeLLwMcfKdO20Oqrmx0zV+jkrFo7pe+s7jblXr02mOy9WIAAAAAAAAAAAAAADxzVx0t8d60yUi07VyVi8Rt2xvwn1uRaZiZ0+kpWJiNvnX5E00/VTTf7F7x5WmYhh1tnR+WMc3sP7Ns0d847f8AGE6mWpW15CrHDJf21ifieC7lbXkiI+sn8H/Y8DcrY0ER+3P4P6htP/yx9qfCPzDaJwR+95QmoNyqtjr6J9sx+RqF3Km+32Y9s3/M0m5a+TLaOG0d1axPjtuyhjLXyTaZibTNp/emZlshqs9t5qzvotN/Bp7nTx/xhwM/9y37fVZtQAAAAAAAAAAAAAADyLW02y5I9GS8eFpcjJHjL6TDO6R+mFatTcsrHqFZxHqETt6lETt6BFdkVXZRReBWvkgRq3hlEMZa8x1s4arPbeasf2LS/wAGnudPH/GHBz/3Lft9Vm1AAAAAAAAAAAAAAAPKeWqdHU56+jNkn2TaZj3uVljV5fQ8ad46/pr0lpelbWUGSiYkRjaQVyiqrqKbg1sqjVuyhjKiI62cNcvcuQMfQ0umrPGMGLfv6EbunSNVh89knd5n8t9kwAAAAAAAAAAAAAAAcZzv5vTM31uOY22i2Ws9XCIjpVnu26p9c7vJnwb+KHR4fK6dY7f4ch1xxie/jHjweGay68XiWdLwx0yWxYDcETIMJkFdpXRtr3tBo2pms24RM90TK9MsZlVOnt27V75jfw4s4rLGbQ6fmLzew6i98mXpXrh6G0fRpa1t+qe2Yjbh1cXqwY4nxlzuZntTVa+b02IexykgAAAAAAAiASAAAAAACrV4IyY74p4Xpak91omPikxu
NLWdTEvGsmG0TtO8THVPfDl2nT6KmpjwZRW/pme/rY9TZ0soi3q8ITa6Zb29Eea7TUotafRH+r802uvywm0+rzNmmFrT6vwx8V2aYTktHbt3bR7jZ0qMl7TxmZ79zadMKpiZ4yyiWMxp6f8AJzp+jpLX7cma0x3RER74l78EfC4vNtvJrs6tueQAAAAAAABiCYBIAAAAAAPMucOl6GqzV7JvN47r/rfFzc0avLu8S3Viq0a1aHsZxRA6AMJoKwnGIptSFVTesKjXuqSq7WUMLPYeamDoaLT19OOL/jmbfF0scarD5/PbqyWl9Zm1AAAAAAAAMATAJBIAAAAAOJ584Ns2PJ9vH0fbWfytHg8XKr4xLq+z7fDMdnOQ8bpwziUVEyCJBhbYFGSYUa95Ua2SVSWOKs2tFY4zMRHfPVDOsbnTVedRMvccGKKUrSOFK1rHdEbOpD52Z3O1ggAAAAAAADCATAJgEgAAAAA57ntp+lp63jjjyRv923V7+i8/Irum3s4NtZNd3DOdLt1TFkZG67XTGZQYWlYFNga2SVGveVhjLf5sYPnNZp6f4tbT3V/Wnyq3Yo3aHk5VunHZ7M6LhAAAAAAAAAMIBIJgEgAAAAA0+WNP85gy4+2cduj96OuvnEMbxusw2YrdN4l5jXrcqYfRVkmrBmgUkFdwUXkhWtkVGteWUMZdL8nmKJ1nTnhixXt7Z2r7rS9XGj4nO59v6eu8vUIyw9rkMotAJAAAAAAABWCYkE7gkEgAAAAiQeW8paf5rPlx8IrktEfd36vLZy8katMPoePbqpWVUW9TU3onZGSJQV2lRr5BWtkWEa2RnDCXU8xabRmyemaUj2bzPvh7ONHzlyudPjEOux5p9L1Q5+mzj1NlRtY9QiNmmUFkWBIAAAAKwSCQSCQAAAAAcFz20/R1EZNurLSJ/mr+rPlFXh5Nfi33dj2ffdNdnway8bpMoBEgrsCi6jVyKNXIyhhLtuaeLo6as/bte8+O0eVYdDDGquLyrbyT+H3sdW6HlbGOqo2cdRi2KQguqC2ASAAADAE7AAkEgAAAAA5vnxpelgrljjiv1/dt1T5xVo5Fd132e3g36cmu7hYlzZd2GcSiomQYWkGvlso1MsqjTy2/ozrG2q9tRt6byLp9sWOkcK0rXwjrdSldQ4OS27TL7OLTMmqZbNNObYr64kGcUBnFQZAAAAAxAgEgkAAAAAAFWow1vW1LxFq2ia2ie2JNbWJ1O4cbr+Zt95nT5KzHZTNvWY/nrE7+2Pa8t+LE/wAZ06eL2hMeF42+Vl5B1dOOC0x6aTTJ5RO/k888XJD1152GfPTUyaPPHHBqP8jNMf7WucN/tlujkYZ+uPVRbT5uzDqJ7sGaZ8qnub/bK+/xffHqxryVq7/R02o/mxXx+d4iGcYMk+TXPMwx9X/V+Hmdyhk+rx4fXmy14emIx9Lw6m2vFt5vPf2hjj5RMvucj/J5Sloy6nLOe9Z3ita/N4Yn09HeZn2z7IevHhpTxeDLy75PDydpg0laxtENm3l2vikIidgSAAAAAACN/wD3UBsCQAAAAAAAAQCJgU2U2bAAAbAmIREgAAAAAAAAAx6/V4gyAAAAAAAAAAAABGwGwGwJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB/9k="
class JPEGResponse(StreamingResponse):
    """Streaming response that labels its payload as a JPEG image."""

    media_type: str = "image/jpeg"
class ImgPilot(Photon):
    """Photon serving image-to-image generation with a diffusers pipeline.

    The pipeline is loaded once in ``init`` and shared by every ``run``
    request. Runtime configuration is read from the environment variables
    listed in ``deployment_template["env"]``, whose values also serve as
    defaults when a variable is not set.
    """

    requirement_dependency = [
        "torch",
        "diffusers",
        "invisible-watermark",
        "compel",
        "Pillow",
    ]

    # By default, we use gpu.a10 as the computation resource shape. This should
    # be fast enough.
    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {
            "MODEL": "SimianLuo/LCM_Dreamshaper_v7",
            "USE_TORCH_COMPILE": "false",
            "WIDTH": "768",
            "HEIGHT": "768",
            "PRINT_PROMPT": "false",
        },
    }

    # Maximum number of requests handled at once. Values > 1 let IO interleave
    # with compute (an A10 might sustain ~8; this is not tuned), but pipeline
    # calls are still serialized by ``base_lock`` and torch.compile is turned
    # off because it does not support multithreading.
    handler_max_concurrency = 1

    def init(self):
        """Load the image-to-image pipeline and read runtime configuration.

        Reads MODEL, PRINT_PROMPT, USE_TORCH_COMPILE and, when compiling,
        WIDTH and HEIGHT. Each variable falls back to its value in
        ``deployment_template["env"]``.
        """
        from diffusers import AutoPipelineForImage2Image  # type: ignore

        env_defaults = self.deployment_template["env"]

        def read_env(name: str) -> str:
            # Fall back to the deployment template so the photon can also be
            # launched locally without exporting every variable.
            return os.environ.get(name, env_defaults[name])

        def read_env_flag(name: str) -> bool:
            return read_env(name).strip().lower() in {"true", "t", "1", "yes", "y"}

        model_name = read_env("MODEL")
        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.base = AutoPipelineForImage2Image.from_pretrained(
            model_name,
            # float16 only on GPU; fall back to float32 on CPU.
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        self.base.safety_checker = None
        self.base.requires_safety_checker = False

        # Only pay for a real lock when requests can actually overlap.
        if self.handler_max_concurrency > 1:
            self.base_lock = threading.Lock()
        else:
            self.base_lock = nullcontext()

        self.print_prompt = read_env_flag("PRINT_PROMPT")
        logger.info(f"print_prompt: {self.print_prompt}")

        if cuda_available:
            self.base.to(self.device)

        self.use_torch_compile = read_env_flag("USE_TORCH_COMPILE")
        if self.use_torch_compile and self.handler_max_concurrency > 1:
            warnings.warn(
                "torch compile does not support multithreading, so we will"
                " disable torch compile since handler_max_concurrency > 1."
            )
            # Actually disable it: ``run`` reads self.width/self.height whenever
            # use_torch_compile is true, and those are only set below.
            self.use_torch_compile = False
        if self.use_torch_compile:
            self.width = int(read_env("WIDTH"))
            self.height = int(read_env("HEIGHT"))
            logger.info(
                "Compiling model with torch.compile. Note that with torch"
                " compile, your first invocation will be slow, but subsequent"
                " invocations will be faster."
            )
            self.base.unet = torch.compile(
                self.base.unet, mode="reduce-overhead", fullgraph=True
            )

        # Compel builds embeddings for prompts longer than the tokenizer limit,
        # which the pipeline itself would truncate.
        self.compel_proc = Compel(
            tokenizer=self.base.tokenizer,
            text_encoder=self.base.text_encoder,
            truncate_long_prompts=False,
        )  # type: ignore
        logger.info(f"Initialized model {model_name}. cuda: {cuda_available}.")

    @Photon.handler(
        "run",
        example={
            "prompt": (
                "Portrait of The Terminator, glare pose, detailed, intricate, full of"
                " colour, cinematic lighting, trending on artstation, 8k,"
                " hyperrealistic, focused, extreme details, unreal engine 5, cinematic,"
                " masterpiece"
            ),
            "seed": 2159232,
            "strength": 0.5,
            "steps": 4,
            "guidance_scale": 8.0,
            "width": 512,
            "height": 512,
            "lcm_steps": 50,
            "input_image": EXAMPLE_IMAGE_BASE64,
        },
    )
    def run(
        self,
        prompt: str,
        seed: int,
        strength: float,
        steps: int,
        guidance_scale: float,
        width: int,
        height: int,
        lcm_steps: int,
        input_image: Optional[Union[str, FileParam]],
    ) -> JPEGResponse:
        """Generate a JPEG image from ``input_image`` guided by ``prompt``.

        Args:
            prompt: Text prompt. Prompts longer than the tokenizer's maximum
                length are encoded with Compel instead of being truncated.
            seed: Seed passed to ``torch.manual_seed`` before generation.
            strength: Forwarded to the pipeline as ``strength``.
            steps: Forwarded to the pipeline as ``num_inference_steps``.
            guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
            width: Requested output width. Must equal WIDTH when
                torch.compile is enabled.
            height: Requested output height. Must equal HEIGHT when
                torch.compile is enabled.
            lcm_steps: Forwarded to the pipeline as ``lcm_origin_steps``.
            input_image: Source image as a string (e.g. base64, as in the
                example) or an uploaded ``FileParam``. It is required.

        Returns:
            A ``JPEGResponse`` streaming the first generated image.

        Raises:
            HTTPException: 400 when ``input_image`` is missing, when
                width/height do not match the compiled size, or when the
                pipeline flags NSFW content.
        """
        from diffusers.utils import load_image  # type: ignore
        import time

        start = time.perf_counter()
        if self.print_prompt:
            logger.info(f"Prompt: {prompt}")

        # This is an image-to-image pipeline: fail fast with a client error
        # instead of letting the pipeline crash on a missing image.
        if input_image is None:
            raise HTTPException(status_code=400, detail="input_image is required.")

        # diffusers truncates prompts to tokenizer.model_max_length; for longer
        # prompts, build the embeddings with Compel instead (slower, but no
        # tokens are dropped).
        tokens = self.base.tokenizer(prompt, return_tensors="pt")
        if tokens.input_ids.shape[1] > self.base.tokenizer.model_max_length:
            prompt_embeds = self.compel_proc(prompt)
            prompt = None
        else:
            prompt_embeds = None

        image_file = get_file_content(input_image, return_file=True)
        pil_image = Image.open(image_file, formats=["JPEG", "PNG", "GIF", "BMP"])
        if self.use_torch_compile:
            # Only the configured size is accepted with torch.compile
            # (presumably to avoid recompiling the UNet for new shapes), and
            # the input image is resized to match it.
            if width != self.width or height != self.height:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"width and height must be {self.width} and"
                        f" {self.height} when use_torch_compile is true."
                    ),
                )
            # PIL reports size as (width, height).
            if pil_image.size != (self.width, self.height):
                pil_image = pil_image.resize(
                    (self.width, self.height), Image.BILINEAR
                )
        source_image = load_image(pil_image).convert("RGB")

        with self.base_lock:
            # torch.manual_seed reseeds the global RNG, so seeding and the
            # pipeline call must happen under the same lock.
            generator = torch.manual_seed(seed)
            output = self.base(
                prompt=prompt,
                prompt_embeds=prompt_embeds,
                generator=generator,
                image=source_image,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                lcm_origin_steps=lcm_steps,
                output_type="pil",
            )  # type: ignore

        # The post-processing below only touches this request's output, so it
        # runs outside the lock.
        nsfw_flags = getattr(output, "nsfw_content_detected", None)
        if nsfw_flags and nsfw_flags[0]:
            raise HTTPException(status_code=400, detail="nsfw content detected")

        img_io = BytesIO()
        output.images[0].save(img_io, format="JPEG")  # type: ignore
        img_io.seek(0)
        logger.info(f"Produced output in {time.perf_counter() - start} seconds.")
        return JPEGResponse(img_io)
# Script entry point: build the photon and start serving it.
if __name__ == "__main__":
    ImgPilot().launch()