File size: 6,037 Bytes
1c337e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
import torch
from PIL import Image

from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid
# Non-interactive Agg backend: the server never opens GUI windows; figures are
# only rendered to image data.
matplotlib.use('Agg')

# Notice page served instead of the app when no CUDA device is present.
# The body deliberately has no CSS `filter`: a filter on <body> applies to all
# of its descendants, so a blur there would also blur the overlay message the
# page exists to show.
_NO_CUDA_HTML = """
<style>
body {
position: relative;
height: 100vh;
width: 100%;
display: flex;
justify-content: center;
align-items: center;
background: rgba(0, 0, 0, 0.1);
}
.overlay {
position: absolute;
z-index: 10;
color: white;
font-size: 20px;
text-align: center;
padding: 20px;
background-color: rgba(0, 0, 0, 0.7);
border-radius: 10px;
box-shadow: 0px 0px 20px rgba(0, 0, 0, 0.5);
}
.message {
font-size: 22px;
margin-top: 20px;
}
</style>
<div class="overlay">
<h1>CUDA is not available</h1>
<p>Please clone the repository or run it in Colab:</p>
<a href="https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM" style="color: #1e90ff; text-decoration: underline;">GitHub Repository</a>
<div class="message">
<p>We are currently unable to run on this machine because CUDA is missing.</p>
</div>
</div>
"""

# Check for CUDA availability. Without it, serve only the notice page and stop
# before the (CUDA-only) models below are downloaded or constructed.
if not torch.cuda.is_available():
    with gr.Blocks() as demo:
        gr.HTML(_NO_CUDA_HTML)
    demo.launch(share=True, debug=True)
    # Same exit status as the previous `exit()`, but without relying on the
    # `exit` builtin that the `site` module may not have installed.
    raise SystemExit
# Download the SAM ViT-H checkpoint, unless an earlier run already saved it.
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
if not os.path.exists(_sam_checkpoint_path):
    # Stream into a ".part" file and rename only after the transfer succeeds.
    # Streaming keeps the multi-GB body out of memory. The rename means an
    # interrupted download never leaves a truncated file that the existence
    # check above would later mistake for a complete checkpoint.
    _partial_path = _sam_checkpoint_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly rather than saving an HTTP error page as model weights.
        response.raise_for_status()
        with open(_partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    os.replace(_partial_path, _sam_checkpoint_path)
# --- Segmentation model: Segment Anything with the "vit_h" backbone, loaded
# from the checkpoint file downloaded above, placed on the GPU.
sam_checkpoint, model_type, device = "sam_vit_h_4b8939.pth", "vit_h", "cuda"
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)

# --- Inpainting model: Stable Diffusion 2 inpainting weights, referenced by
# their model id.
model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)

# State handed from segment_image() (Step 1) to inpaint_image() (Step 2): the
# last uploaded image and the masks generated for it. Both stay None until
# Step 1 has run.
# NOTE(review): these are module globals, so presumably they are shared by all
# concurrent Gradio sessions served by this process. Confirm whether per-user
# state (gr.State) is wanted.
current_masks = current_image = None
def segment_image(image):
    """Run SAM on *image* and return a figure showing the image with its masks.

    The image and the generated masks are stored in the module globals
    ``current_image`` / ``current_masks`` so that ``inpaint_image`` (Step 2)
    can reuse them.

    Args:
        image: The upload delivered by the Step 1 ``gr.Image`` component
            (presumably a numpy array with Gradio's default settings; confirm),
            or ``None`` when nothing has been uploaded.

    Returns:
        A matplotlib figure with the masks overlaid, or ``None`` when no image
        was provided (the stored state is then left unchanged).
    """
    global current_masks, current_image
    if image is None:
        return None

    # Generate masks before updating the globals. If generation fails, the
    # previous image/mask pair then stays consistent instead of a new image
    # being paired with the old image's masks.
    masks = sam_model.generate_masks(np.array(image))
    current_image, current_masks = image, masks

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(sam_model.preprocess_image(image))  # original image underneath
    show_anns(masks, ax)  # mask overlay on top
    ax.axis('off')
    fig.tight_layout()
    # Release pyplot's reference so figures do not accumulate across requests.
    # The Figure object itself stays valid for Gradio to render.
    plt.close(fig)
    return fig
def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
    """Inpaint one SAM mask of the current image once per non-blank prompt.

    Uses the image and masks stored by ``segment_image`` (Step 1).

    Args:
        mask_index: Index into the masks from Step 1, as delivered by the
            Step 2 slider. Depending on the Gradio version this may be a
            float such as ``3.0``, so it is converted to ``int``.
        prompt1, prompt2, prompt3, prompt4: Inpainting prompts. Empty, blank
            or ``None`` values are skipped.

    Returns:
        The grid built by ``create_image_grid`` from the original image and
        the inpainted results, or ``None`` if Step 1 has not been run yet.

    Raises:
        gr.Error: If ``mask_index`` does not select an existing mask, or if
            every prompt is blank. Gradio shows these messages to the user.
    """
    if current_masks is None or current_image is None:
        return None

    # List indexing requires an int; slider values may be floats.
    index = int(mask_index)
    if not 0 <= index < len(current_masks):
        # The slider range is fixed, so it can exceed the number of masks.
        raise gr.Error(
            f"Mask index {index} is out of range: Step 1 produced "
            f"{len(current_masks)} mask(s)."
        )

    prompts = [
        p for p in (prompt1, prompt2, prompt3, prompt4) if p and p.strip()
    ]
    if not prompts:
        raise gr.Error("Enter at least one prompt to inpaint.")

    # Boolean segmentation -> 0/255 grayscale PIL mask for the pipeline.
    segmentation_mask = current_masks[index]['segmentation']
    stable_diffusion_mask = Image.fromarray(
        (segmentation_mask * 255).astype(np.uint8)
    )

    # One generator seeded once: the run as a whole is reproducible. Because it
    # is shared by all calls, later prompts presumably draw from an advanced RNG
    # state rather than restarting at seed 42.
    generator = torch.Generator(device="cuda").manual_seed(42)

    # Each pipeline call gets its own fresh PIL copy of the source image.
    encoded_images = [
        sd_pipeline.inpaint(
            prompt=prompt,
            image=Image.fromarray(np.array(current_image)),
            mask_image=stable_diffusion_mask,
            # NOTE(review): sampling settings kept from the original tuning;
            # their exact meaning depends on the engine wrapper. Confirm there.
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        )
        for prompt in prompts
    ]

    # Original image followed by one result per prompt. (2, 3) are presumably
    # the grid's rows/columns; confirm against utils.create_image_grid.
    return create_image_grid(
        Image.fromarray(np.array(current_image)),
        encoded_images,
        prompts,
        2, 3,
    )
# Two-tab Gradio UI.
#   Step 1: segment an uploaded image with SAM.
#   Step 2: inpaint one of the resulting masks with up to four prompts.
_PROMPT_ORDINALS = ("first", "second", "third", "fourth")

with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")

    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")
        segment_btn.click(
            fn=segment_image,
            inputs=[input_image],
            outputs=[mask_output],
        )

    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(
                    minimum=0,
                    maximum=20,
                    step=1,
                    label="Mask Index (select based on mask numbers from Step 1)",
                )
                # Four prompt boxes, created in order: "Prompt 1" ... "Prompt 4".
                prompt_boxes = [
                    gr.Textbox(
                        label=f"Prompt {number}",
                        placeholder=f"Enter {ordinal} inpainting prompt",
                    )
                    for number, ordinal in enumerate(_PROMPT_ORDINALS, start=1)
                ]
                prompt1, prompt2, prompt3, prompt4 = prompt_boxes
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")
        inpaint_btn.click(
            fn=inpaint_image,
            inputs=[mask_index, *prompt_boxes],
            outputs=[inpaint_output],
        )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)
|