# Spaces: Sleeping
| from fastapi import FastAPI, Request, Form, File, UploadFile | |
| from fastapi.responses import StreamingResponse | |
| from contextlib import asynccontextmanager | |
| from starlette.middleware.cors import CORSMiddleware | |
| import torch | |
| from PIL import Image | |
| from io import BytesIO | |
| from diffusers import ( | |
| AutoPipelineForText2Image, | |
| AutoPipelineForImage2Image, | |
| AutoPipelineForInpainting, | |
| ) | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SDXL-Turbo pipelines at startup and publish them as lifespan state.

    The text-to-image pipeline is loaded from the "stabilityai/sdxl-turbo"
    checkpoint; the image-to-image and inpainting pipelines are derived from
    it with ``from_pipe``. All three are moved to the CPU.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            in the body).

    Yields:
        A dict with the keys "text2img", "img2img" and "inpaint". The request
        handlers in this file read these entries from ``request.state``.
    """
    text2img = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo"
    ).to("cpu")
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
    inpaint = AutoPipelineForInpainting.from_pipe(text2img).to("cpu")

    # Everything before the yield runs at startup, everything after at shutdown.
    yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}

    # Drop this function's references to the pipelines on shutdown.
    del text2img
    del img2img
    del inpaint
# Application instance; the pipelines are created by the `lifespan` handler.
app = FastAPI(lifespan=lifespan)

# CORS policy: accept any origin, any method and any header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive setup -- confirm this is intended for this service.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def root():
    """Return a static greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def text_to_image(request: Request, prompt: str = Form(...)):
    """Generate an image from a text prompt and return it as a PNG.

    Args:
        request: Incoming request; the text-to-image pipeline is read from
            ``request.state.text2img``.
        prompt: Text prompt taken from the submitted form data.

    Returns:
        A ``StreamingResponse`` with media type "image/png" carrying the
        first image produced by the pipeline.
    """
    pipeline = request.state.text2img
    result = pipeline(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)

    # Encode into an in-memory buffer and rewind it so the response streams
    # from the first byte.
    png_buffer = BytesIO()
    result.images[0].save(png_buffer, "PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def image_to_image(
    request: Request, prompt: str = Form(...), init_image: UploadFile = File(...)
):
    """Transform an uploaded image according to a text prompt; return a PNG.

    Args:
        request: Incoming request; the image-to-image pipeline is read from
            ``request.state.img2img``.
        prompt: Text prompt taken from the submitted form data.
        init_image: Uploaded starting image. It is converted to RGB and
            resized to 512x512 before being passed to the pipeline.

    Returns:
        A ``StreamingResponse`` with media type "image/png" carrying the
        first image produced by the pipeline.
    """
    upload_data = await init_image.read()
    source_image = Image.open(BytesIO(upload_data))
    source_image = source_image.convert("RGB").resize((512, 512))

    # The state entry is the pipeline object itself (see `lifespan`), so it is
    # called directly, exactly as `text_to_image` calls `text2img`; going
    # through a `.pipe` attribute fails with AttributeError.
    image = request.state.img2img(
        prompt,
        image=source_image,
        num_inference_steps=2,
        strength=0.5,
        guidance_scale=0.0,
    ).images[0]

    # Encode into an in-memory buffer and rewind it so the response streams
    # from the first byte.
    png_buffer = BytesIO()
    image.save(png_buffer, "PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")
# NOTE(review): no route decorator is visible on this handler here; confirm
# it is registered on `app`.
async def inpainting(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
):
    """Inpaint an uploaded image using a mask and a text prompt; return a PNG.

    Args:
        request: Incoming request; the inpainting pipeline is read from
            ``request.state.inpaint``.
        prompt: Text prompt taken from the submitted form data.
        init_image: Uploaded image to edit. It is converted to RGB and
            resized to 512x512 before being passed to the pipeline.
        mask_image: Uploaded mask. It is converted to RGB and resized to
            512x512 before being passed to the pipeline.

    Returns:
        A ``StreamingResponse`` with media type "image/png" carrying the
        first image produced by the pipeline.
    """
    init_data = await init_image.read()
    source_image = Image.open(BytesIO(init_data))
    source_image = source_image.convert("RGB").resize((512, 512))

    mask_data = await mask_image.read()
    mask = Image.open(BytesIO(mask_data))
    mask = mask.convert("RGB").resize((512, 512))

    # The state entry is the pipeline object itself (see `lifespan`), so it is
    # called directly, exactly as `text_to_image` calls `text2img`; going
    # through a `.pipe` attribute fails with AttributeError.
    image = request.state.inpaint(
        prompt,
        image=source_image,
        mask_image=mask,
        num_inference_steps=3,
        strength=0.5,
        guidance_scale=0.0,
    ).images[0]

    # Encode into an in-memory buffer and rewind it so the response streams
    # from the first byte.
    png_buffer = BytesIO()
    image.save(png_buffer, "PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")