from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
from PIL import Image
import io
import base64
from diffusers import StableDiffusionInpaintPipeline
import gc
from fastapi.responses import JSONResponse

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global pipeline instance, loaded once at startup
pipe = None

# Input images are resized so their short side becomes MAX_SIZE pixels
MAX_SIZE = 512

def load_model():
    global pipe
    if pipe is None:
        model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # float16 halves memory on GPU; CPU inference needs float32
            dtype = torch.float16 if device == "cuda" else torch.float32
            pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                safety_checker=None
            ).to(device)
            if device == "cuda":
                # Trade a little speed for a lower peak VRAM footprint
                pipe.enable_attention_slicing()
            print(f"Model loaded on {device} with optimizations")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return pipe

@app.on_event("startup")
async def startup_event():
    try:
        load_model()
    except Exception as e:
        # Keep the app alive so /health still responds; /inpaint returns 503
        print(f"Startup error: {str(e)}")

def image_to_base64(image: Image.Image) -> str:
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

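# A hypothetical client can recover the PIL image from the JSON response with:
#   Image.open(io.BytesIO(base64.b64decode(response["image"])))
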
def resize_for_condition_image(input_image: Image.Image, resolution: int):
    input_width, input_height = input_image.size
    aspect_ratio = input_height / input_width
    if input_height > input_width:
        # Vertical image: the width is the short side
        width = resolution
        height = int(resolution * aspect_ratio)
    else:
        # Horizontal image: the height is the short side
        height = resolution
        width = int(resolution / aspect_ratio)
    # Stable Diffusion expects dimensions divisible by 8
    width = width - width % 8
    height = height - height % 8
    return input_image.resize((width, height))

@app.post("/inpaint")
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted"
):
    # Fail fast if the model never loaded at startup
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    try:
        # Read each upload once and enforce a 10MB limit per file
        max_size = 10 * 1024 * 1024  # 10MB
        image_data = await image.read()
        mask_data = await mask.read()
        if len(image_data) > max_size or len(mask_data) > max_size:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"}
            )
        original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # The mask must be single-channel: white = inpaint, black = keep
        mask_image = Image.open(io.BytesIO(mask_data)).convert("L")
        # Resize both images so their short side is MAX_SIZE
        original_image = resize_for_condition_image(original_image, MAX_SIZE)
        mask_image = resize_for_condition_image(mask_image, MAX_SIZE)
        # Far fewer denoising steps on CPU, where each step is expensive
        num_steps = 5 if not torch.cuda.is_available() else 20
        # Mixed precision only helps on CUDA; run at full precision on CPU
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.autocast(device_type, enabled=device_type == "cuda"):
            output_image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=num_steps,
                guidance_scale=7.0,  # slightly reduced for speed
            ).images[0]
        # Convert output image to base64
        output_base64 = image_to_base64(output_image)
        # Clean up between requests
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return {"status": "success", "image": output_base64}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

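# Example request, assuming the server runs on localhost:7860 and the files
# exist (hypothetical names). Note that prompt travels as a query parameter,
# since it is declared as a plain str rather than a Form field:
#   curl -X POST "http://localhost:7860/inpaint?prompt=add%20a%20fountain" \
#        -F "image=@room.png" -F "mask=@mask.png"
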
@app.get("/health")
async def health_check():
    return {"status": "healthy", "cuda_available": torch.cuda.is_available()}