File size: 4,137 Bytes
e3d5df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43c5517
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import base64
import contextlib
import gc
import io
import logging

import torch
from diffusers import StableDiffusionInpaintPipeline
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

app = FastAPI()

# Add CORS middleware
# Wide-open CORS: every origin, method, and header is accepted.
# Acceptable for local development; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variable for the model
# Lazily populated by load_model(); stays None until the pipeline loads.
pipe = None

# Add max size limit
# Target resolution handed to resize_for_condition_image() before inference.
MAX_SIZE = 512

def load_model():
    """Load the Stable Diffusion inpainting pipeline once and cache it.

    Uses the module-level ``pipe`` as a singleton, so repeated calls after a
    successful load are cheap.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline``.

    Raises:
        Exception: re-raises whatever ``from_pretrained`` / ``.to`` raised
        after logging it, so callers see the real failure.
    """
    global pipe
    if pipe is None:
        model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # fp16 weights are only valid on CUDA; half-precision inference
            # on CPU is unsupported for these pipelines, so fall back to fp32.
            dtype = torch.float16 if device == "cuda" else torch.float32
            pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                safety_checker=None
            ).to(device)

            if device == "cuda":
                # Trades a little speed for a lower peak-VRAM footprint.
                pipe.enable_attention_slicing()

            print(f"Model loaded on {device} with optimizations")

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
    return pipe

@app.on_event("startup")
async def startup_event():
    """Warm the model cache at application startup.

    A failed load is logged rather than propagated, so the server still
    starts; load_model() will retry lazily on the first request.
    """
    try:
        load_model()
    except Exception as exc:
        print(f"Startup error: {exc}")

def image_to_base64(image: Image.Image) -> str:
    """Serialize *image* to PNG and return it as a base64-encoded string."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode()

def resize_for_condition_image(input_image: Image.Image, resolution: int):
    """Resize so the shorter side equals *resolution*, keeping aspect ratio.

    Both output dimensions are snapped down to multiples of 8, which the
    Stable Diffusion UNet/VAE require; non-multiple-of-8 sizes make the
    pipeline error out or silently mis-handle the image.

    Args:
        input_image: source PIL image.
        resolution: target size in pixels for the shorter edge.

    Returns:
        The resized PIL image.
    """
    input_width, input_height = input_image.size
    aspect_ratio = input_height / input_width

    if input_height > input_width:
        # vertical image: width is the shorter side
        width = resolution
        height = int(resolution * aspect_ratio)
    else:
        # horizontal (or square) image: height is the shorter side
        height = resolution
        width = int(resolution / aspect_ratio)

    # Snap both dimensions down to the nearest multiple of 8 for the VAE.
    width = max(8, (width // 8) * 8)
    height = max(8, (height // 8) * 8)

    return input_image.resize((width, height))

@app.post("/inpaint")
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted"
):
    """Inpaint the masked region of an uploaded image.

    Expects two multipart files: the source image and a mask whose light
    areas mark the region to repaint. Returns the result as a base64 PNG,
    or a 400 JSON error for oversized uploads; unexpected failures become
    HTTP 500 with the exception text as detail.
    """
    try:
        # Read each upload exactly once and size-check the bytes we already
        # hold, instead of reading twice with a seek(0) in between.
        max_size = 10 * 1024 * 1024  # 10MB
        image_data = await image.read()
        mask_data = await mask.read()
        if len(image_data) > max_size or len(mask_data) > max_size:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"}
            )

        # Ensure the pipeline exists even if the startup hook failed;
        # the module global `pipe` is None in that case.
        model = load_model()

        # The pipeline expects an RGB image and a single-channel ("L") mask;
        # uploaded PNGs may be RGBA or palette mode, so convert explicitly.
        original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_data))

        # Downscale to keep inference time and memory bounded.
        original_image = resize_for_condition_image(original_image, MAX_SIZE)
        mask_image = resize_for_condition_image(mask_image, MAX_SIZE)
        mask_image = mask_image.convert("L")

        # Reduce steps even more for CPU
        num_steps = 5 if not torch.cuda.is_available() else 20

        # Mixed precision is only meaningful on CUDA; on CPU run plainly
        # instead of entering torch.cuda.amp.autocast without a GPU.
        autocast_ctx = (
            torch.cuda.amp.autocast()
            if torch.cuda.is_available()
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            output_image = model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=num_steps,
                guidance_scale=7.0,  # Slightly reduced for speed
            ).images[0]

        # Convert output image to base64
        output_base64 = image_to_base64(output_image)

        # Clean up between requests; empty_cache only applies to CUDA.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return {"status": "success", "image": output_base64}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Liveness probe: report service status and CUDA availability."""
    cuda_ok = torch.cuda.is_available()
    return {"status": "healthy", "cuda_available": cuda_ok}