from fastapi import FastAPI, Request, UploadFile, Form, File
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from PIL import Image
from io import BytesIO
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from transformers import CLIPFeatureExtractor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the safety checker and the SD-Turbo pipelines once at startup.
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    text2img = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sd-turbo",
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    ).to("cpu")
    # from_pipe reuses the already-loaded weights instead of downloading
    # and loading the model a second and third time.
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
    inpaint = AutoPipelineForInpainting.from_pipe(img2img).to("cpu")
    yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
    # Release the pipelines on shutdown.
    del inpaint
    del img2img
    del text2img
    del safety_checker
    del feature_extractor

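# The dict yielded by lifespan is exposed by Starlette as per-request state,
# which is why the handlers below can read request.state.text2img, .img2img,
# and .inpaint without any globals.
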
app = FastAPI(lifespan=lifespan)

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

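# Note: allow_origins=["*"] together with allow_credentials=True is as open as
# CORS gets. Convenient for a public demo Space; restrict origins in production.
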
@app.get("/")  # simple liveness route
async def root():
    return {"Hello": "World"}

@app.post("/text-to-image")  # route path assumed from the handler name
async def text_to_image(
    request: Request,
    prompt: str = Form(...),
    num_inference_steps: int = Form(1),
):
    # SD-Turbo is distilled to run without classifier-free guidance,
    # so guidance_scale stays at 0.0 and a single step is enough.
    results = request.state.text2img(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,
    )
    if not results.nsfw_content_detected[0]:
        image = results.images[0]
    else:
        # The safety checker flagged the output; return a blank image instead.
        image = Image.new("RGB", (512, 512), "black")
    image_bytes = BytesIO()
    image.save(image_bytes, "PNG")
    image_bytes.seek(0)
    return StreamingResponse(image_bytes, media_type="image/png")

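# A minimal client sketch for the endpoint above. The path matches the route
# decorator, and the port is an assumption (Spaces conventionally use 7860):
#
#   curl -X POST http://localhost:7860/text-to-image \
#        -F "prompt=a watercolor fox" -F "num_inference_steps=1" \
#        --output out.png
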
@app.post("/image-to-image")  # route path assumed from the handler name
async def image_to_image(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    num_inference_steps: int = Form(2),
    strength: float = Form(1.0),
):
    init_bytes = await init_image.read()
    init_image = Image.open(BytesIO(init_bytes))
    # Remember the original size so the result can be scaled back afterwards.
    init_width, init_height = init_image.size
    init_image = init_image.convert("RGB").resize((512, 512))
    # The img2img pipeline runs roughly num_inference_steps * strength steps,
    # so keep that product at 1 or above.
    results = request.state.img2img(
        prompt,
        image=init_image,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=0.0,
    )
    if not results.nsfw_content_detected[0]:
        image = results.images[0].resize((init_width, init_height))
    else:
        image = Image.new("RGB", (512, 512), "black")
    image_bytes = BytesIO()
    image.save(image_bytes, "PNG")
    image_bytes.seek(0)
    return StreamingResponse(image_bytes, media_type="image/png")

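# Client sketch for image-to-image (same assumed host/port; photo.png is any
# local image):
#
#   curl -X POST http://localhost:7860/image-to-image \
#        -F "prompt=turn it into an oil painting" \
#        -F "init_image=@photo.png" --output out.png
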
@app.post("/inpainting")  # route path assumed from the handler name
async def inpainting(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
    num_inference_steps: int = Form(2),
    strength: float = Form(1.0),
):
    init_bytes = await init_image.read()
    init_image = Image.open(BytesIO(init_bytes))
    init_width, init_height = init_image.size
    init_image = init_image.convert("RGB").resize((512, 512))
    # White pixels in the mask are repainted; black pixels are preserved.
    mask_bytes = await mask_image.read()
    mask_image = Image.open(BytesIO(mask_bytes))
    mask_image = mask_image.convert("RGB").resize((512, 512))
    results = request.state.inpaint(
        prompt,
        image=init_image,
        mask_image=mask_image,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=0.0,
    )
    if not results.nsfw_content_detected[0]:
        image = results.images[0].resize((init_width, init_height))
    else:
        image = Image.new("RGB", (512, 512), "black")
    image_bytes = BytesIO()
    image.save(image_bytes, "PNG")
    image_bytes.seek(0)
    return StreamingResponse(image_bytes, media_type="image/png")
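
# To run locally (the module name app.py and port 7860 are assumptions):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860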