from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import torch
from PIL import Image
import io
import base64
from diffusers import StableDiffusionInpaintPipeline
import gc
import logging

app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variable for the model, populated lazily by load_model()
pipe = None

# Working resolution: the shorter image side is resized to this many pixels
MAX_SIZE = 512

def load_model():
    global pipe
    if pipe is None:
        model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # float16 only works reliably on GPU; fall back to float32 on CPU
            pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                safety_checker=None
            ).to(device)
            if device == "cuda":
                # Trade a little speed for lower peak VRAM usage
                pipe.enable_attention_slicing()
            print(f"Model loaded on {device} with optimizations")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return pipe

# Load the model once at startup instead of on the first request
@app.on_event("startup")
async def startup_event():
    try:
        load_model()
    except Exception as e:
        print(f"Startup error: {str(e)}")

def image_to_base64(image: Image.Image) -> str:
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def resize_for_condition_image(input_image: Image.Image, resolution: int):
    input_width, input_height = input_image.size
    aspect_ratio = input_height / input_width
    if input_height > input_width:
        # vertical image: the shorter side (width) becomes `resolution`
        width = resolution
        height = int(resolution * aspect_ratio)
    else:
        # horizontal image: the shorter side (height) becomes `resolution`
        height = resolution
        width = int(resolution / aspect_ratio)
    # Stable Diffusion pipelines require dimensions divisible by 8
    width -= width % 8
    height -= height % 8
    return input_image.resize((width, height))

@app.post("/inpaint")  # route path assumed
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted"
):
    try:
        # Read each upload once and enforce a 10MB per-file limit
        max_size = 10 * 1024 * 1024  # 10MB
        image_data = await image.read()
        mask_data = await mask.read()
        if len(image_data) > max_size or len(mask_data) > max_size:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"}
            )

        original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_data))

        # Resize both inputs to the working resolution
        original_image = resize_for_condition_image(original_image, MAX_SIZE)
        mask_image = resize_for_condition_image(mask_image, MAX_SIZE)
        mask_image = mask_image.convert("L")

        # Far fewer steps on CPU to keep latency tolerable
        num_steps = 5 if not torch.cuda.is_available() else 20

        model = load_model()  # guards against a failed startup load
        # autocast only helps CUDA ops; disable it on CPU so the call still runs
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.autocast(device_type, enabled=(device_type == "cuda")):
            output_image = model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=num_steps,
                guidance_scale=7.0,  # slightly reduced for speed
            ).images[0]

        # Convert output image to base64
        output_base64 = image_to_base64(output_image)

        # Clean up between requests
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return {"status": "success", "image": output_base64}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")  # route path assumed
async def health_check():
    return {"status": "healthy", "cuda_available": torch.cuda.is_available()}
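
# --- Usage sketch ----------------------------------------------------------
# A minimal smoke test for the endpoints above, assuming the app is served
# locally with `uvicorn app:app --port 7860` and that the /inpaint and
# /health paths match the (assumed) routes registered above. The file names
# are placeholders. Note that `prompt` and `negative_prompt` are plain str
# parameters declared alongside File(...) uploads, so FastAPI binds them as
# query parameters, not form fields.
if __name__ == "__main__":
    import requests

    base = "http://localhost:7860"
    print(requests.get(f"{base}/health").json())

    with open("input.png", "rb") as img, open("mask.png", "rb") as msk:
        resp = requests.post(
            f"{base}/inpaint",
            params={"prompt": "add some flowers and a fountain"},
            files={
                "image": ("input.png", img, "image/png"),
                "mask": ("mask.png", msk, "image/png"),
            },
            timeout=600,
        )
    resp.raise_for_status()

    # Decode the base64 PNG returned by the endpoint and save it
    payload = resp.json()
    Image.open(io.BytesIO(base64.b64decode(payload["image"]))).save("output.png")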