File size: 11,096 Bytes
89682f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
from contextlib import nullcontext
from io import BytesIO
import os
import threading
from typing import Optional, Union
import warnings
from compel import Compel
from fastapi.responses import StreamingResponse
from loguru import logger
from PIL import Image
import torch
from leptonai.photon import Photon, FileParam, get_file_content, HTTPException
EXAMPLE_IMAGE_BASE64 = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQEBANDxIQEA8PDw8PDxUPEg8NDxUPFRIWFhURFRYYHSggGBolGxUVITEhJSkrLi4uFx8zODMsNygtLisBCgoKDg0OGBAQFysfHx8tKy4tKy0tKystLS0rKy0tLSstNy4tLy0tLS0tKy0tLSsrLS0rLS0tLS0tLS0rKzctK//AABEIAOEA4QMBEQACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQMCBAYHBf/EAEAQAQACAQIBCAUIBwkBAAAAAAABAgMEETEFBhIhQXGRoVFhgbHBBxMyQ1JyktEVIkJic4LhJFNjk6KywuLwFP/EABoBAQEAAwEBAAAAAAAAAAAAAAABAgMFBAb/xAAtEQEAAgIBAgMIAQUBAAAAAAAAAQIDEQQSUSFBkQUTIjFCUmFxMiMzgaHBFP/aAAwDAQACEQMRAD8A9uBIJBIAAAAAAAAAAAAAAAAAAAAAAAAAMAZQACQAAAAAAAAAAAAAAAAAAAAAAAAAYgmASAAAAAAAAAAAAAAAAAAAAAAAAAACASAAAAAAAAAAAAAAAAAAAAAAAAAAACIBIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI3BIAAAI3BIAAAAAAAAAAAAAAAAAAOd5b546XSV6V5m28zEdHqrMx2RPb7Gi2esfLxerHxL2+fg4/V/K12YcEd95mfyaLcq3lD2U9n087S+ff5S9bb6PzdO6sT792qeTk7t9eBh7f7Uzz311/rZjuise6Guc2T7m6OHhj6YTHOXWW458vstaPixnJefOfVn/AObFH0x6M45Z1E8cuSe+1p+KdVu7L3NPtj0ZRypm/vLeMnVPc91TtHon9KZvt28U6p7r7qnZH6YzxwvbxlOue6+5p2j0ZRzg1NeGXJ+O8fFfeWjzljPHxT9Mei7Hzy1dP25mPXMW98M45GSPNrtwsM/S+7yTz1yXibZOj0YjebXjoV7otHVu9FOTfzeLLwMcfKdO20Oqrmx0zV+jkrFo7pe+s7jblXr02mOy9WIAAAAAAAAAAAAAADxzVx0t8d60yUi07VyVi8Rt2xvwn1uRaZiZ0+kpWJiNvnX5E00/VTTf7F7x5WmYhh1tnR+WMc3sP7Ns0d847f8AGE6mWpW15CrHDJf21ifieC7lbXkiI+sn8H/Y8DcrY0ER+3P4P6htP/yx9qfCPzDaJwR+95QmoNyqtjr6J9sx+RqF3Km+32Y9s3/M0m5a+TLaOG0d1axPjtuyhjLXyTaZibTNp/emZlshqs9t5qzvotN/Bp7nTx/xhwM/9y37fVZtQAAAAAAAAAAAAAADyLW02y5I9GS8eFpcjJHjL6TDO6R+mFatTcsrHqFZxHqETt6lETt6BFdkVXZRReBWvkgRq3hlEMZa8x1s4arPbeasf2LS/wAGnudPH/GHBz/3Lft9Vm1AAAAAAAAAAAAAAAPKeWqdHU56+jNkn2TaZj3uVljV5fQ8ad46/pr0lpelbWUGSiYkRjaQVyiqrqKbg1sqjVuyhjKiI62cNcvcuQMfQ0umrPGMGLfv6EbunSNVh89knd5n8t9kwAAAAAAAAAAAAAAAcZzv5vTM31uOY22i2Ws9XCIjpVnu26p9c7vJnwb+KHR4fK6dY7f4ch1xxie/jHjweGay68XiWdLwx0yWxYDcETIMJkFdpXRtr3tBo2pms24RM90TK9MsZlVOnt27V75jfw4s4rLGbQ6fmLzew6i98mXpXrh6G0fRpa1t+qe2Yjbh1cXqwY4nxlzuZntTVa+b02IexykgAAAAAAAiASAAAAAACrV4IyY74p4Xpak91omPikxu
NLWdTEvGsmG0TtO8THVPfDl2nT6KmpjwZRW/pme/rY9TZ0soi3q8ITa6Zb29Eea7TUotafRH+r802uvywm0+rzNmmFrT6vwx8V2aYTktHbt3bR7jZ0qMl7TxmZ79zadMKpiZ4yyiWMxp6f8AJzp+jpLX7cma0x3RER74l78EfC4vNtvJrs6tueQAAAAAAABiCYBIAAAAAAPMucOl6GqzV7JvN47r/rfFzc0avLu8S3Viq0a1aHsZxRA6AMJoKwnGIptSFVTesKjXuqSq7WUMLPYeamDoaLT19OOL/jmbfF0scarD5/PbqyWl9Zm1AAAAAAAAMATAJBIAAAAAOJ584Ns2PJ9vH0fbWfytHg8XKr4xLq+z7fDMdnOQ8bpwziUVEyCJBhbYFGSYUa95Ua2SVSWOKs2tFY4zMRHfPVDOsbnTVedRMvccGKKUrSOFK1rHdEbOpD52Z3O1ggAAAAAAADCATAJgEgAAAAA57ntp+lp63jjjyRv923V7+i8/Irum3s4NtZNd3DOdLt1TFkZG67XTGZQYWlYFNga2SVGveVhjLf5sYPnNZp6f4tbT3V/Wnyq3Yo3aHk5VunHZ7M6LhAAAAAAAAAMIBIJgEgAAAAA0+WNP85gy4+2cduj96OuvnEMbxusw2YrdN4l5jXrcqYfRVkmrBmgUkFdwUXkhWtkVGteWUMZdL8nmKJ1nTnhixXt7Z2r7rS9XGj4nO59v6eu8vUIyw9rkMotAJAAAAAAABWCYkE7gkEgAAAAiQeW8paf5rPlx8IrktEfd36vLZy8katMPoePbqpWVUW9TU3onZGSJQV2lRr5BWtkWEa2RnDCXU8xabRmyemaUj2bzPvh7ONHzlyudPjEOux5p9L1Q5+mzj1NlRtY9QiNmmUFkWBIAAAAKwSCQSCQAAAAAcFz20/R1EZNurLSJ/mr+rPlFXh5Nfi33dj2ffdNdnway8bpMoBEgrsCi6jVyKNXIyhhLtuaeLo6as/bte8+O0eVYdDDGquLyrbyT+H3sdW6HlbGOqo2cdRi2KQguqC2ASAAADAE7AAkEgAAAAA5vnxpelgrljjiv1/dt1T5xVo5Fd132e3g36cmu7hYlzZd2GcSiomQYWkGvlso1MsqjTy2/ozrG2q9tRt6byLp9sWOkcK0rXwjrdSldQ4OS27TL7OLTMmqZbNNObYr64kGcUBnFQZAAAAAxAgEgkAAAAAAFWow1vW1LxFq2ia2ie2JNbWJ1O4cbr+Zt95nT5KzHZTNvWY/nrE7+2Pa8t+LE/wAZ06eL2hMeF42+Vl5B1dOOC0x6aTTJ5RO/k888XJD1152GfPTUyaPPHHBqP8jNMf7WucN/tlujkYZ+uPVRbT5uzDqJ7sGaZ8qnub/bK+/xffHqxryVq7/R02o/mxXx+d4iGcYMk+TXPMwx9X/V+Hmdyhk+rx4fXmy14emIx9Lw6m2vFt5vPf2hjj5RMvucj/J5Sloy6nLOe9Z3ita/N4Yn09HeZn2z7IevHhpTxeDLy75PDydpg0laxtENm3l2vikIidgSAAAAAACN/wD3UBsCQAAAAAAAAQCJgU2U2bAAAbAmIREgAAAAAAAAAx6/V4gyAAAAAAAAAAAABGwGwGwJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB/9k="
class JPEGResponse(StreamingResponse):
    """Streaming HTTP response whose media type is fixed to JPEG."""

    # Class-level media type, so instances are served as ``image/jpeg``
    # without passing it per response.
    media_type = "image/jpeg"
class ImgPilot(Photon):
    """Photon that serves an image-to-image diffusion pipeline over HTTP.

    The pipeline is loaded once in :meth:`init` from the model named by the
    ``MODEL`` environment variable; :meth:`run` turns a prompt plus an input
    image into a JPEG response.

    Environment variables read (defaults are listed in
    ``deployment_template``):

    * ``MODEL`` - model id handed to ``from_pretrained``.
    * ``PRINT_PROMPT`` - when truthy, each incoming prompt is logged.
    * ``USE_TORCH_COMPILE`` - when truthy (and CUDA is available and
      ``handler_max_concurrency`` is 1), the UNet is wrapped in
      ``torch.compile``.
    * ``WIDTH`` / ``HEIGHT`` - the only image size accepted by :meth:`run`
      when torch compile is active.
    """

    requirement_dependency = [
        "torch",
        "diffusers",
        "invisible-watermark",
        "compel",
        "Pillow",
    ]

    # By default we use gpu.a10 as the computation resource shape. This should
    # be fast enough.
    deployment_template = {
        "resource_shape": "gpu.a10",
        "env": {
            "MODEL": "SimianLuo/LCM_Dreamshaper_v7",
            "USE_TORCH_COMPILE": "false",
            "WIDTH": "768",
            "HEIGHT": "768",
            "PRINT_PROMPT": "false",
        },
    }

    # Requests are handled one at a time. If this is raised above 1, init()
    # guards the pipeline with a real lock and refuses to use torch.compile.
    # This value is not tuned.
    handler_max_concurrency = 1

    def init(self):
        """Load the pipeline and configure the photon from the environment.

        Sets ``self.device``, ``self.base`` (the pipeline, with its safety
        checker removed), ``self.base_lock``, ``self.print_prompt``,
        ``self.use_torch_compile`` and ``self.compel_proc``. ``self.width``
        and ``self.height`` are set only when the UNet is actually compiled.

        Raises:
            KeyError: if a required environment variable is missing.
        """
        from diffusers import AutoPipelineForImage2Image  # type: ignore

        # Strings (compared lower-cased) that switch an env flag on.
        truthy = ("true", "t", "1", "yes", "y")

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        self.base = AutoPipelineForImage2Image.from_pretrained(
            os.environ["MODEL"],
            torch_dtype=torch.float16 if cuda_available else torch.float32,
        )
        self.base.safety_checker = None
        self.base.requires_safety_checker = False

        # Only pay for a real lock when handlers may run concurrently.
        if self.handler_max_concurrency > 1:
            self.base_lock = threading.Lock()
        else:
            self.base_lock = nullcontext()

        self.print_prompt = os.environ["PRINT_PROMPT"].lower() in truthy
        logger.info(f"print_prompt: {self.print_prompt}")

        # Stays False unless the UNet is really compiled below. run() reads
        # self.width / self.height whenever this flag is True, so the flag
        # must never be True on a path that did not set them.
        self.use_torch_compile = False
        if cuda_available:
            self.base.to("cuda")
            if os.environ["USE_TORCH_COMPILE"].lower() in truthy:
                if self.handler_max_concurrency > 1:
                    warnings.warn(
                        "torch compile does not support multithreading, so we will"
                        " disable torch compile since handler_max_concurrency > 1."
                    )
                else:
                    self.width = int(os.environ["WIDTH"])
                    self.height = int(os.environ["HEIGHT"])
                    logger.info(
                        "Compiling model with torch.compile. Note that with torch"
                        " compile, your first invocation will be slow, but subsequent"
                        " invocations will be faster."
                    )
                    self.base.unet = torch.compile(
                        self.base.unet, mode="reduce-overhead", fullgraph=True
                    )
                    self.use_torch_compile = True

        self.compel_proc = Compel(
            tokenizer=self.base.tokenizer,
            text_encoder=self.base.text_encoder,
            truncate_long_prompts=False,
        )  # type: ignore

        logger.info(f"Initialized model {os.environ['MODEL']}. cuda: {cuda_available}.")

    @Photon.handler(
        "run",
        example={
            "prompt": (
                "Portrait of The Terminator, glare pose, detailed, intricate, full of"
                " colour, cinematic lighting, trending on artstation, 8k,"
                " hyperrealistic, focused, extreme details, unreal engine 5, cinematic,"
                " masterpiece"
            ),
            "seed": 2159232,
            "strength": 0.5,
            "steps": 4,
            "guidance_scale": 8.0,
            "width": 512,
            "height": 512,
            "lcm_steps": 50,
            "input_image": EXAMPLE_IMAGE_BASE64,
        },
    )
    def run(
        self,
        prompt: str,
        seed: int,
        strength: float,
        steps: int,
        guidance_scale: float,
        width: int,
        height: int,
        lcm_steps: int,
        input_image: Optional[Union[str, FileParam]],
    ) -> JPEGResponse:
        """Generate an image from ``prompt`` and ``input_image``.

        Args:
            prompt: text prompt. Prompts longer than 77 tokens are turned into
                embeddings with compel instead of being passed as text.
            seed: value passed to ``torch.manual_seed`` before generation.
            strength: forwarded to the pipeline as ``strength``.
            steps: forwarded to the pipeline as ``num_inference_steps``.
            guidance_scale: forwarded to the pipeline as ``guidance_scale``.
            width: forwarded to the pipeline as ``width``.
            height: forwarded to the pipeline as ``height``.
            lcm_steps: forwarded to the pipeline as ``lcm_origin_steps``.
            input_image: source image, read through ``get_file_content`` and
                decoded as JPEG, PNG, GIF or BMP. ``None`` is forwarded to the
                pipeline unchanged.

        Returns:
            A ``JPEGResponse`` streaming the first generated image as JPEG.

        Raises:
            HTTPException: status 400 when the input image cannot be decoded,
                when torch compile is active and ``width``/``height`` differ
                from the compiled size, or when the pipeline reports NSFW
                content.
        """
        from diffusers.utils import load_image  # type: ignore
        import time

        start = time.time()
        if self.print_prompt:
            logger.info(f"Prompt: {prompt}")

        # diffusers truncates prompt to 77 tokens, in case prompt is too long, we
        # will use compel to process the prompt (but compel is slower)
        tokens = self.base.tokenizer(prompt, return_tensors="pt")
        if tokens.input_ids.shape[1] > 77:
            prompt_text = None
            prompt_embeds = self.compel_proc(prompt)
        else:
            prompt_text = prompt
            prompt_embeds = None

        source_image = None
        if input_image is not None:
            image_file = get_file_content(input_image, return_file=True)
            try:
                pil_image = Image.open(
                    image_file, formats=["JPEG", "PNG", "GIF", "BMP"]
                )
            except OSError as err:
                # PIL's UnidentifiedImageError is an OSError: the upload is not
                # a readable image in one of the accepted formats. That is a
                # client error, not a server fault.
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "input_image could not be decoded as JPEG, PNG, GIF or BMP."
                    ),
                ) from err

            if self.use_torch_compile:
                # The requested size must equal the size configured for the
                # compiled model; reject the request otherwise.
                if width != self.width or height != self.height:
                    raise HTTPException(
                        status_code=400,
                        detail=(
                            f"width and height must be {self.width} and"
                            f" {self.height} when use_torch_compile is true."
                        ),
                    )
                # Bring the input image to the configured size if it differs.
                if pil_image.height != self.height or pil_image.width != self.width:
                    pil_image = pil_image.resize(
                        (self.width, self.height), Image.BILINEAR
                    )
            source_image = load_image(pil_image).convert("RGB")

        with self.base_lock:
            generator = torch.manual_seed(seed)
            output_image = self.base(
                prompt=prompt_text,
                prompt_embeds=prompt_embeds,
                generator=generator,
                image=source_image,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                lcm_origin_steps=lcm_steps,
                output_type="pil",
            )  # type: ignore

        # Only trust the flag when the pipeline output actually carries it.
        nsfw_content_detected = (
            output_image.nsfw_content_detected[0]
            if "nsfw_content_detected" in output_image
            else False
        )  # type: ignore
        if nsfw_content_detected:
            raise HTTPException(status_code=400, detail="nsfw content detected")

        img_io = BytesIO()
        output_image.images[0].save(img_io, format="JPEG")  # type: ignore
        img_io.seek(0)  # rewind so the response streams from the first byte
        logger.info(f"Produced output in {time.time() - start} seconds.")
        return JPEGResponse(img_io)
if __name__ == "__main__":
    # Script entry point: build the photon and start serving it.
    ImgPilot().launch()
|