# Spaces: Sleeping
# (Page-header text captured together with the file; not part of the program.)
import base64
import gc
import io
import logging
from contextlib import nullcontext

import torch
from diffusers import StableDiffusionInpaintPipeline
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
# FastAPI application object for this service.
app = FastAPI()

# Fully permissive CORS policy: any origin, method and header is accepted and
# credentials are enabled.
# NOTE(review): the FastAPI CORS documentation says wildcard origins should
# not be combined with allow_credentials=True -- confirm whether credentials
# are actually needed here.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Shared inpainting pipeline. Stays None until load_model() has finished
# successfully, so callers never observe a half-configured pipeline.
pipe = None


def load_model():
    """Return the shared inpainting pipeline, loading it on first use.

    On a CUDA host the pipeline is loaded in float16 with attention slicing
    and sequential CPU offload enabled; otherwise it is loaded in float32 and
    placed on the CPU.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline`` (also stored in the
        module-level ``pipe``).

    Raises:
        Exception: whatever the underlying load/configuration step raised.
            The error is printed and re-raised; ``pipe`` is left as ``None``
            so a later call can retry.
    """
    global pipe
    if pipe is not None:
        return pipe

    model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision on GPU, full precision on the CPU fallback.
    dtype = torch.float16 if use_cuda else torch.float32

    try:
        loaded = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
        )
        if use_cuda:
            loaded.enable_attention_slicing(slice_size="max")
            # Sequential offload moves sub-modules to the GPU on demand. The
            # diffusers docs advise against moving the pipeline to CUDA
            # beforehand, so no explicit .to("cuda") is done on this path.
            loaded.enable_sequential_cpu_offload()
        else:
            loaded = loaded.to(device)
        print(f"Model loaded on {device}")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

    # Publish only after every configuration step succeeded.
    pipe = loaded
    return pipe
async def startup_event():
    """Try to load the model ahead of the first request.

    Loading is best-effort: a failure is printed and suppressed, so this
    coroutine itself never raises.
    """
    try:
        load_model()
    except Exception as exc:
        message = str(exc)
        print(f"Startup error: {message}")
def image_to_base64(image: Image.Image) -> str:
    """Serialise *image* as PNG and return the bytes as base64 text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted",
):
    """Inpaint the masked region of an uploaded image.

    Args:
        image: Uploaded source image.
        mask: Uploaded mask; it is resized to the image and converted to
            single-channel ("L") before being passed to the pipeline.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.

    Returns:
        ``{"status": "success", "image": <base64 PNG>}`` on success, or a
        ``JSONResponse`` with status 400 when an upload exceeds 10MB or the
        image is smaller than 8x8 pixels.

    Raises:
        HTTPException: status 500 carrying the error text for any other
            failure (unreadable image, model load or inference error).
    """
    max_size = 10 * 1024 * 1024  # 10MB
    try:
        # Read each upload exactly once, capped one byte past the limit, so an
        # oversized file is detected without buffering all of it.
        image_data = await image.read(max_size + 1)
        mask_data = await mask.read(max_size + 1)
        if len(image_data) > max_size or len(mask_data) > max_size:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"},
            )

        # Normalise to RGB so palette/alpha/greyscale uploads reach the
        # pipeline with three channels.
        original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_data))

        # Resize to multiple of 8
        width, height = (dim - dim % 8 for dim in original_image.size)
        if width == 0 or height == 0:
            # A side shorter than 8px rounds down to 0 and cannot be resized.
            return JSONResponse(
                status_code=400,
                content={"error": "Image is too small. Minimum size is 8x8 pixels"},
            )
        original_image = original_image.resize((width, height))
        mask_image = mask_image.resize((width, height))
        mask_image = mask_image.convert("L")

        # Go through load_model() rather than the bare global so a failed
        # startup load is retried here instead of calling None.
        model = load_model()

        # Autocast only applies to CUDA; skip it entirely on CPU-only hosts.
        if torch.cuda.is_available():
            autocast_ctx = torch.autocast("cuda")
        else:
            autocast_ctx = nullcontext()

        # Perform inpainting
        with autocast_ctx:
            output_image = model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=20,
                guidance_scale=7.5,
            ).images[0]

        return {"status": "success", "image": image_to_base64(output_image)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release memory on failure as well as success. Collect Python garbage
        # first so freed tensors can be returned by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
async def health_check():
    """Report service liveness and whether torch can see a CUDA device."""
    cuda_ok = torch.cuda.is_available()
    return {"status": "healthy", "cuda_available": cuda_ok}