Sanshruth committed on
Commit 1c337e7 · verified · 1 Parent(s): 20f285c

Create app.py

Files changed (1)
app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
import gradio as gr
import torch
from PIL import Image
import numpy as np
from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid
import matplotlib.pyplot as plt
import PIL
import requests
import matplotlib
matplotlib.use('Agg')  # Use the non-interactive Agg backend

# Check for CUDA availability
if not torch.cuda.is_available():
    # If CUDA isn't available, create a simple Gradio interface to notify users
    with gr.Blocks() as demo:
        gr.HTML("""
        <style>
            body {
                position: relative;
                height: 100vh;
                width: 100%;
                display: flex;
                justify-content: center;
                align-items: center;
                background: rgba(0, 0, 0, 0.1);
                filter: blur(10px);
            }
            .overlay {
                position: absolute;
                z-index: 10;
                color: white;
                font-size: 20px;
                text-align: center;
                padding: 20px;
                background-color: rgba(0, 0, 0, 0.7);
                border-radius: 10px;
                box-shadow: 0px 0px 20px rgba(0, 0, 0, 0.5);
            }
            .message {
                font-size: 22px;
                margin-top: 20px;
            }
        </style>
        <div class="overlay">
            <h1>CUDA is not available</h1>
            <p>Please clone the repository or run it in Colab:</p>
            <a href="https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM" style="color: #1e90ff; text-decoration: underline;">GitHub Repository</a>
            <div class="message">
                <p>We are currently unable to run on this machine because CUDA is missing.</p>
            </div>
        </div>
        """)
    demo.launch(share=True, debug=True)
    exit()  # Exit the program if CUDA is not available

# Download the SAM ViT-H checkpoint
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
response = requests.get(url)

with open("sam_vit_h_4b8939.pth", "wb") as file:
    file.write(response.content)

# Initialize models
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"  # Default device
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)

model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)

# Global variables to share the masks and image between the two steps
current_masks = None
current_image = None

def segment_image(image):
    global current_masks, current_image
    current_image = image

    # Convert to a numpy array
    image_array = np.array(image)

    # Generate masks
    current_masks = sam_model.generate_masks(image_array)

    # Create a visualization of the masks
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    # Display the original image first
    ax.imshow(sam_model.preprocess_image(image))

    # Overlay the masks
    show_anns(current_masks, ax)

    ax.axis('off')
    plt.tight_layout()

    return fig

def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
    global current_masks, current_image

    if current_masks is None or current_image is None:
        return None

    # Get the selected mask (slider values can arrive as floats)
    mask_index = int(mask_index)
    segmentation_mask = current_masks[mask_index]['segmentation']
    stable_diffusion_mask = PIL.Image.fromarray((segmentation_mask * 255).astype(np.uint8))

    # Generate one inpainted image per non-empty prompt
    prompts = [p for p in [prompt1, prompt2, prompt3, prompt4] if p.strip()]
    generator = torch.Generator(device="cuda").manual_seed(42)  # Fixed seed for consistency

    encoded_images = []
    for prompt in prompts:
        img = sd_pipeline.inpaint(
            prompt=prompt,
            image=Image.fromarray(np.array(current_image)),
            mask_image=stable_diffusion_mask,
            guidance_scale=7.5,       # Default guidance scale
            num_inference_steps=50,   # Good balance between quality and speed
            generator=generator
        )
        encoded_images.append(img)

    # Create the result grid
    result_grid = create_image_grid(Image.fromarray(np.array(current_image)),
                                    encoded_images,
                                    prompts,
                                    2, 3)

    return result_grid

# Create the Gradio interface with two tabs
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")

    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")
        segment_btn.click(fn=segment_image, inputs=[input_image], outputs=[mask_output])

    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(minimum=0, maximum=20, step=1,
                                       label="Mask Index (select based on mask numbers from Step 1)")
                prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
                prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
                prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
                prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")
        inpaint_btn.click(fn=inpaint_image,
                          inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
                          outputs=[inpaint_output])

if __name__ == "__main__":
    demo.launch(share=True, debug=True)
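
Note that app.py imports SegmentAnythingModel and StableDiffusionInpaintingPipeline from a local engine module that is not part of this commit. The sketch below is a minimal, hypothetical version of that module, assuming the official segment_anything and diffusers packages and a fixed 512x512 working resolution; only the class names, constructor arguments, and method calls are taken from how app.py uses them, everything else is an assumption.

# engine.py (hypothetical sketch, not part of this commit)
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from diffusers import StableDiffusionInpaintPipeline


class SegmentAnythingModel:
    """Thin wrapper around SAM's automatic mask generator."""

    def __init__(self, checkpoint, model_type, device):
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device=device)
        self.mask_generator = SamAutomaticMaskGenerator(sam)

    def preprocess_image(self, image):
        # Assumed behavior: resize to 512x512 so the masks line up with the
        # resolution used by the inpainting pipeline below
        return np.array(Image.fromarray(np.array(image)).convert("RGB").resize((512, 512)))

    def generate_masks(self, image_array):
        # Returns a list of dicts, each with a boolean 'segmentation' mask
        return self.mask_generator.generate(self.preprocess_image(image_array))


class StableDiffusionInpaintingPipeline:
    """Thin wrapper around diffusers' Stable Diffusion inpainting pipeline."""

    def __init__(self, model_dir):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")

    def inpaint(self, prompt, image, mask_image, guidance_scale, num_inference_steps, generator):
        result = self.pipe(
            prompt=prompt,
            image=image.convert("RGB").resize((512, 512)),
            mask_image=mask_image.resize((512, 512)),
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        return result.images[0]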
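
Likewise, show_anns and create_image_grid come from a local utils module that is not included here. A plausible sketch, assuming SAM-style annotation dicts with boolean 'segmentation' masks and a matplotlib figure as the return type expected by gr.Plot; the numeric labels are an assumption meant to match the mask index slider in Step 2.

# utils.py (hypothetical sketch, not part of this commit)
import matplotlib.pyplot as plt
import numpy as np


def show_anns(anns, ax):
    """Overlay each SAM mask on the given axis with a random color and its index."""
    if len(anns) == 0:
        return
    ax.set_autoscale_on(False)
    for i, ann in enumerate(anns):
        m = ann['segmentation']  # boolean HxW mask
        overlay = np.zeros((m.shape[0], m.shape[1], 4))
        overlay[m] = np.concatenate([np.random.random(3), [0.55]])  # random color, 55% alpha
        ax.imshow(overlay)
        # Label the mask with its list index so it can be selected in Step 2
        ys, xs = np.nonzero(m)
        ax.text(xs.mean(), ys.mean(), str(i), color='white', fontsize=12,
                ha='center', va='center')


def create_image_grid(original_image, images, prompts, rows, cols):
    """Return a matplotlib figure showing the original image plus one panel per prompt."""
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    axes = axes.flatten()
    axes[0].imshow(original_image)
    axes[0].set_title("Original")
    for ax, img, prompt in zip(axes[1:], images, prompts):
        ax.imshow(img)
        ax.set_title(prompt, fontsize=9)
    for ax in axes:
        ax.axis('off')
    fig.tight_layout()
    return fig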