Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64 | |
| from diffusers import StableDiffusionInpaintPipeline | |
| import gc | |
| from fastapi.responses import JSONResponse | |
| import logging | |
app = FastAPI()

# Permissive CORS so browser front-ends served from any origin can call the API.
# NOTE(review): FastAPI's CORS docs say allow_credentials=True should not be
# paired with wildcard origins/methods/headers. Nothing in this view reads
# cookies or auth headers, so allow_credentials=False may be the safer
# setting — confirm with the client code before changing it.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Shared inpainting pipeline; populated lazily by load_model().
pipe = None

# Target resolution handed to resize_for_condition_image() before inpainting.
MAX_SIZE = 512
def load_model():
    """Load the inpainting pipeline once and cache it in the module-level ``pipe``.

    The pipeline is placed on CUDA when available, otherwise on CPU. Half
    precision is used only on CUDA; CPU runs use float32.

    Returns:
        The cached ``StableDiffusionInpaintPipeline``.

    Raises:
        Exception: anything raised while downloading, constructing or moving
            the pipeline. The error is printed and re-raised, and ``pipe`` is
            left as None so a later call can retry.
    """
    global pipe
    if pipe is not None:
        return pipe

    model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # float16 is a GPU optimisation; many CPU kernels have no Half
    # implementation (or are extremely slow), so fall back to float32 on CPU.
    dtype = torch.float16 if use_cuda else torch.float32
    try:
        model = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)
        if use_cuda:
            # Trades a little speed for lower peak GPU memory.
            model.enable_attention_slicing()
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

    # Publish only a fully initialised pipeline, so a failure part-way through
    # setup leaves ``pipe`` as None instead of half-configured.
    pipe = model
    print(f"Model loaded on {device} with optimizations")
    return pipe
async def startup_event():
    """Warm the model cache when the application starts.

    Loading is best-effort: any failure is printed rather than raised, so the
    server still comes up.
    """
    # NOTE(review): no registration such as @app.on_event("startup") is
    # visible in this view; confirm it is wired up, otherwise this never runs.
    try:
        load_model()
    except Exception as exc:
        reason = str(exc)
        print(f"Startup error: {reason}")
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as PNG and return the bytes as base64 ASCII text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
def resize_for_condition_image(input_image: Image.Image, resolution: int):
    """Scale ``input_image`` so its shorter side equals ``resolution``.

    The aspect ratio is preserved. The longer side is scaled proportionally and
    truncated to an int, so it may exceed ``resolution``. Images smaller than
    ``resolution`` are upscaled. Square images come out as
    ``resolution`` x ``resolution``.

    Returns:
        The resized image, from ``input_image.resize``.
    """
    w, h = input_image.size
    ratio = h / w
    if h <= w:
        # Landscape or square: height is the short side.
        new_size = (int(resolution / ratio), resolution)
    else:
        # Portrait: width is the short side.
        new_size = (resolution, int(resolution * ratio))
    return input_image.resize(new_size)
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted",
):
    """Inpaint the masked region of an uploaded image.

    Args:
        image: Uploaded source image. Converted to RGB and resized via
            ``resize_for_condition_image(…, MAX_SIZE)``.
        mask: Uploaded mask. Converted to grayscale ("L") and resized to
            exactly the resized image's dimensions.
        prompt: Text prompt for the inpainting pipeline.
        negative_prompt: Negative prompt for the inpainting pipeline.

    Returns:
        ``{"status": "success", "image": <base64 PNG>}`` on success, or a 400
        ``JSONResponse`` with an ``"error"`` key when either upload exceeds 10MB.

    Raises:
        HTTPException: 400 if an upload cannot be decoded as an image, and 500
            for any other failure (including model loading).
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view; presumably it is registered elsewhere (e.g. @app.post(...)) — confirm.
    # NOTE(review): the pipeline call below is synchronous inside an async
    # handler, so it blocks the event loop (health checks included) for the
    # whole run. Moving it to a worker thread would also require serialising
    # access to the shared pipeline.
    max_size = 10 * 1024 * 1024  # 10MB per upload
    try:
        # Read each upload exactly once; the bytes are reused for decoding.
        image_data = await image.read()
        mask_data = await mask.read()
        if len(image_data) > max_size or len(mask_data) > max_size:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"},
            )

        try:
            # convert() forces a full decode, so truncated/corrupt files fail
            # here. It also normalises RGBA/palette uploads to the 3-channel
            # input the pipeline expects.
            original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
            mask_image = Image.open(io.BytesIO(mask_data)).convert("L")
        except OSError as e:
            # PIL's UnidentifiedImageError is an OSError subclass.
            raise HTTPException(
                status_code=400, detail=f"Invalid image upload: {e}"
            ) from e

        original_image = resize_for_condition_image(original_image, MAX_SIZE)
        # Match the mask to the image's final size so the two stay aligned,
        # even when the uploads had different dimensions.
        mask_image = mask_image.resize(original_image.size)

        # Lazily load the pipeline in case startup loading failed or never ran.
        model = load_model()

        use_cuda = torch.cuda.is_available()
        # Far fewer denoising steps on CPU to keep latency tolerable.
        num_steps = 20 if use_cuda else 5
        # Mixed precision only when CUDA is present; disabled on CPU to avoid
        # autocast's "CUDA is not available" warning.
        with torch.autocast("cuda", enabled=use_cuda):
            output_image = model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=num_steps,
                guidance_scale=7.0,  # Slightly reduced for speed
            ).images[0]

        return {"status": "success", "image": image_to_base64(output_image)}
    except HTTPException:
        # Already carries the intended status code; don't rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release memory on every path, including failures. Collect garbage
        # first so freed tensors' cached CUDA blocks can actually be returned.
        gc.collect()
        torch.cuda.empty_cache()
async def health_check():
    """Report liveness and whether torch can see a CUDA device."""
    # NOTE(review): no route decorator is visible for this handler in this
    # view; presumably it is registered elsewhere (e.g. @app.get(...)) — confirm.
    cuda_ok = torch.cuda.is_available()
    return {"status": "healthy", "cuda_available": cuda_ok}