File size: 3,372 Bytes
43c5517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import asyncio
import base64
import gc
import io
import logging
from concurrent.futures import ThreadPoolExecutor

import torch
from diffusers import StableDiffusionInpaintPipeline
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

app = FastAPI()

# Cross-origin access: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs say wildcard values are not meant to be
# combined with allow_credentials=True; confirm which origins really need
# credentialed access and list them explicitly if cookies/auth are used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Shared inpainting pipeline; stays None until load_model() populates it.
pipe = None

def load_model():
    """Load the inpainting pipeline once and cache it in the global ``pipe``.

    The pipeline is loaded in float16 with the safety checker disabled,
    attention slicing set to "max", and sequential CPU offload enabled.
    Later calls return the cached instance.

    Returns:
        The cached ``StableDiffusionInpaintPipeline``.
    """
    global pipe
    if pipe is None:
        # Use the pre-uploaded model from Hugging Face
        model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        model = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        model.enable_attention_slicing(slice_size="max")
        # Sequential CPU offload streams submodules to the GPU only while they
        # run. The pipeline must NOT be moved to CUDA beforehand: doing so
        # first materializes the whole model on the GPU, which is the peak
        # memory use the offload exists to avoid.
        model.enable_sequential_cpu_offload()
        # Publish only a fully configured pipeline, so a failure above does
        # not leave a half-initialized model cached for later requests.
        pipe = model
    return pipe

@app.on_event("startup")
async def startup_event():
    """Load the inpainting model at startup when a CUDA device is available.

    Without CUDA the model is not loaded and ``pipe`` stays ``None``. A
    warning is logged so that this is visible in the server logs rather
    than surfacing only when ``/inpaint`` is called.
    """
    if torch.cuda.is_available():
        load_model()
    else:
        logging.getLogger(__name__).warning(
            "CUDA is not available; inpainting model was not loaded"
        )

def image_to_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return those bytes base64-encoded as text."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()

# Single worker thread for inference. This keeps the blocking pipeline call
# off the asyncio event loop (so /health and other requests stay responsive)
# while still running at most one pipeline call at a time, as before.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="inpaint")

# Maximum accepted size of each uploaded file, in bytes.
MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # 10MB

_logger = logging.getLogger(__name__)


def _prepare_inputs(image_data: bytes, mask_data: bytes):
    """Decode the uploads into an RGB image and an "L" mask of equal size.

    Both are resized to the source image's dimensions rounded down to a
    multiple of 8.

    Raises:
        OSError: if either payload cannot be decoded as an image (this
            includes PIL's ``UnidentifiedImageError``).
        ValueError: if the image is smaller than 8 pixels in either dimension.
    """
    # Normalize to 3-channel RGB; RGBA, palette or greyscale uploads would
    # otherwise reach the pipeline with a different channel layout.
    original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
    mask_image = Image.open(io.BytesIO(mask_data))

    # Resize to multiple of 8
    width, height = (dim - dim % 8 for dim in original_image.size)
    if width == 0 or height == 0:
        raise ValueError("Image must be at least 8x8 pixels")
    original_image = original_image.resize((width, height))
    mask_image = mask_image.resize((width, height)).convert("L")
    return original_image, mask_image


def _run_pipeline(model, prompt, negative_prompt, original_image, mask_image):
    """Run one inpainting call on *model* and return the first output image.

    Executed on ``_INFERENCE_EXECUTOR``. Autocast is entered here because
    autocast state is thread-local. Memory is released even if the call fails.
    """
    try:
        with torch.autocast(device_type="cuda"):
            return model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=20,
                guidance_scale=7.5,
            ).images[0]
    finally:
        # Collect unreachable Python objects first so any CUDA memory they
        # held is back in the cache before empty_cache() releases it.
        gc.collect()
        torch.cuda.empty_cache()


@app.post("/inpaint")
async def inpaint(

    image: UploadFile = File(...),

    mask: UploadFile = File(...),

    prompt: str = "add some flowers and a fountain",

    negative_prompt: str = "blurry, low quality, distorted"

):
    """Inpaint the masked region of *image*, guided by *prompt*.

    Returns:
        ``{"status": "success", "image": <base64-encoded PNG>}`` on success.
        For an oversized or undecodable upload, a 400 JSON response of the
        form ``{"error": ...}``.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 on any other
            failure (details are logged).
    """
    model = pipe
    if model is None:
        # startup_event only loads the model when CUDA is available.
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        # Read at most one byte past the limit. That is enough to detect an
        # oversized upload without reading all of it into memory, and it
        # avoids the read / seek / re-read round trip.
        image_data = await image.read(MAX_UPLOAD_BYTES + 1)
        mask_data = await mask.read(MAX_UPLOAD_BYTES + 1)
        if len(image_data) > MAX_UPLOAD_BYTES or len(mask_data) > MAX_UPLOAD_BYTES:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"}
            )

        try:
            original_image, mask_image = _prepare_inputs(image_data, mask_data)
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Bad client input: report it as a 400 rather than a server error.
            return JSONResponse(
                status_code=400,
                content={"error": f"Invalid image or mask: {e}"}
            )

        loop = asyncio.get_running_loop()
        output_image = await loop.run_in_executor(
            _INFERENCE_EXECUTOR,
            _run_pipeline,
            model,
            prompt,
            negative_prompt,
            original_image,
            mask_image,
        )

        return {"status": "success", "image": image_to_base64(output_image)}

    except Exception as e:
        _logger.exception("Inpainting request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health")
async def health_check():
    """Report liveness, CUDA visibility, and whether the model is loaded.

    ``model_loaded`` was added alongside the existing keys, so existing
    clients are unaffected. It is False when startup skipped loading the
    pipeline.
    """
    return {
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        # /inpaint cannot serve requests until the pipeline is loaded.
        "model_loaded": pipe is not None,
    }