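"""Gradio demo: segment an uploaded image with Segment Anything (SAM), then
inpaint a selected mask region with Stable Diffusion 2 inpainting prompts."""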
import gradio as gr
import torch
import numpy as np
import requests
import os
from PIL import Image

import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend so figures render without a display
import matplotlib.pyplot as plt

from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid

# Check for CUDA availability
if not torch.cuda.is_available():
    # If CUDA isn't available, create a simple Gradio interface to notify users
    with gr.Blocks() as demo:
        gr.HTML("""
            <style>
                body {
                    position: relative;
                    height: 100vh;
                    width: 100%;
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    background: rgba(0, 0, 0, 0.1);
                    filter: blur(10px);
                }
                .overlay {
                    position: absolute;
                    z-index: 10;
                    color: white;
                    font-size: 20px;
                    text-align: center;
                    padding: 20px;
                    background-color: rgba(0, 0, 0, 0.7);
                    border-radius: 10px;
                    box-shadow: 0px 0px 20px rgba(0, 0, 0, 0.5);
                }
                .message {
                    font-size: 22px;
                    margin-top: 20px;
                }
            </style>
            <div class="overlay">
                <h1>CUDA is not available</h1>
                <p>Please clone the repository or run it in Colab:</p>
                <a href="https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM" style="color: #1e90ff; text-decoration: underline;">GitHub Repository</a>
                <div class="message">
                    <p>We are currently unable to run on this machine because CUDA is missing.</p>
                </div>
            </div>
        """)
    demo.launch(share=True, debug=True)
    raise SystemExit("CUDA is not available; exiting.")  # Stop here, the rest of the app needs a GPU

# Download the SAM ViT-H checkpoint (about 2.4 GB) if it is not already present
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists("sam_vit_h_4b8939.pth"):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open("sam_vit_h_4b8939.pth", "wb") as file:
        for chunk in response.iter_content(chunk_size=1 << 20):  # write in 1 MiB chunks
            file.write(chunk)

# Initialize models
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"  # Default device
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)

model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)
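# Note: "stabilityai/stable-diffusion-2-inpainting" is a Hugging Face Hub model id;
# the wrapper in engine.py is assumed to download the weights on first use.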

# Module-level state shared between the segmentation and inpainting steps
current_masks = None
current_image = None

def segment_image(image):
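    """Run SAM on the uploaded image and return a matplotlib figure with all candidate masks overlaid."""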
    global current_masks, current_image
    current_image = image
    
    # Convert to numpy array
    image_array = np.array(image)
    
    # Generate masks
    current_masks = sam_model.generate_masks(image_array)
    
    # Create visualization of masks
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    
    # Display the original image first
    ax.imshow(sam_model.preprocess_image(image))
    
    # Overlay masks
    show_anns(current_masks, ax)
    
    ax.axis('off')
    plt.tight_layout()
    
    return fig

def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
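    """Inpaint the selected SAM mask with up to four prompts and return a comparison grid of the results."""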
    global current_masks, current_image
    
    if current_masks is None or current_image is None:
        return None

    # Get the selected mask; the slider value arrives as a number, so cast and clamp it
    mask_index = min(int(mask_index), len(current_masks) - 1)
    segmentation_mask = current_masks[mask_index]['segmentation']
    stable_diffusion_mask = Image.fromarray((segmentation_mask * 255).astype(np.uint8))

    # Generate inpainted images
    prompts = [p for p in [prompt1, prompt2, prompt3, prompt4] if p and p.strip()]
    generator = torch.Generator(device="cuda").manual_seed(42)  # Fixed seed for consistency
    
    encoded_images = []
    for prompt in prompts:
        img = sd_pipeline.inpaint(
            prompt=prompt,
            image=Image.fromarray(np.array(current_image)),
            mask_image=stable_diffusion_mask,
            guidance_scale=7.5,  # Standard guidance scale: how strongly the output follows the prompt
            num_inference_steps=50,  # Good balance between quality and speed
            generator=generator
        )
        encoded_images.append(img)

    # Create result grid
    result_grid = create_image_grid(Image.fromarray(np.array(current_image)),
                                  encoded_images,
                                  prompts,
                                  2, 3)
    
    return result_grid

# Create Gradio interface with two tabs
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
    
    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")
        segment_btn.click(fn=segment_image, inputs=[input_image], outputs=[mask_output])
    
    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(minimum=0, maximum=20, step=1, 
                                     label="Mask Index (select based on mask numbers from Step 1)")
                prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
                prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
                prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
                prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")
        inpaint_btn.click(fn=inpaint_image, 
                         inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
                         outputs=[inpaint_output])

if __name__ == "__main__":
    demo.launch(share=True, debug=True)