ahmetyaylalioglu commited on
Commit
c636056
·
verified ·
1 Parent(s): 8007933

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -3,11 +3,11 @@ from PIL import Image
3
  import numpy as np
4
  from transformers import SamModel, SamProcessor
5
  from diffusers import AutoPipelineForInpainting
6
- from diffusers.models.autoencoders.vq_model import VQEncoderOutput, VQModel
7
  import torch
8
 
9
- # Force the model to use CPU
10
- device = "cpu"
 
11
 
12
  # Model and Processor setup
13
  model_name = "facebook/sam-vit-huge"
@@ -15,14 +15,14 @@ model = SamModel.from_pretrained(model_name).to(device)
15
  processor = SamProcessor.from_pretrained(model_name)
16
 
17
  def mask_to_rgb(mask):
 
18
  bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
19
- bg_transparent[mask == 1] = [0, 255, 0, 127]
20
  return bg_transparent
21
 
22
- def get_processed_inputs(image, points_str):
23
- points = [list(map(int, point.split(','))) for point in points_str.split()]
24
- input_points = [points]
25
- inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  masks = processor.image_processor.post_process_masks(
@@ -34,12 +34,15 @@ def get_processed_inputs(image, points_str):
34
  return ~best_mask.cpu().numpy()
35
 
36
  def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
 
37
  mask_image = Image.fromarray(input_mask)
38
  rand_gen = torch.manual_seed(seed)
39
  pipeline = AutoPipelineForInpainting.from_pretrained(
40
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16
41
- )
42
- pipeline.enable_model_cpu_offload()
 
 
43
  image = pipeline(
44
  prompt=prompt,
45
  image=raw_image,
@@ -50,7 +53,9 @@ def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536,
50
  ).images[0]
51
  return image
52
 
53
- def gradio_interface(image, points, positive_prompt, negative_prompt):
 
 
54
  raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
55
  mask = get_processed_inputs(raw_image, points)
56
  processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
@@ -60,7 +65,7 @@ iface = gr.Interface(
60
  fn=gradio_interface,
61
  inputs=[
62
  gr.Image(type="numpy", label="Input Image"),
63
- gr.Textbox(label="Points (format: x1,y1 x2,y2 ...)", placeholder="e.g., 100,100 200,200"),
64
  gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
65
  gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
66
  ],
@@ -69,7 +74,7 @@ iface = gr.Interface(
69
  gr.Image(label="Segmentation Mask")
70
  ],
71
  title="Interactive Image Inpainting",
72
- description="Enter points as 'x1,y1 x2,y2 ...' for segmentation, provide prompts, and see the inpainted result."
73
  )
74
 
75
  iface.launch(share=True)
 
3
  import numpy as np
4
  from transformers import SamModel, SamProcessor
5
  from diffusers import AutoPipelineForInpainting
 
6
  import torch
7
 
8
+ # Check if GPU is available, otherwise use CPU
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ print(f"Using device: {device}")
11
 
12
  # Model and Processor setup
13
  model_name = "facebook/sam-vit-huge"
 
15
  processor = SamProcessor.from_pretrained(model_name)
16
 
17
  def mask_to_rgb(mask):
18
+ """ Convert binary mask to RGB with transparency for the background. """
19
  bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
20
+ bg_transparent[mask == 1] = [0, 255, 0, 127] # Green mask with some transparency
21
  return bg_transparent
22
 
23
+ def get_processed_inputs(image, points):
24
+ """ Process the input image and points using SAM model and processor. """
25
+ inputs = processor(image, input_points=points, return_tensors="pt").to(device)
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  masks = processor.image_processor.post_process_masks(
 
34
  return ~best_mask.cpu().numpy()
35
 
36
  def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
37
+ """ Inpaint the masked area in the image using a text prompt and an inpainting pipeline. """
38
  mask_image = Image.fromarray(input_mask)
39
  rand_gen = torch.manual_seed(seed)
40
  pipeline = AutoPipelineForInpainting.from_pretrained(
41
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
42
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
43
+ ).to(device)
44
+ if device == "cpu":
45
+ pipeline.enable_model_cpu_offload()
46
  image = pipeline(
47
  prompt=prompt,
48
  image=raw_image,
 
53
  ).images[0]
54
  return image
55
 
56
+ def gradio_interface(image, points_json, positive_prompt, negative_prompt):
57
+ """ Gradio interface function to handle image, points for segmentation, and prompts. """
58
+ points = [[(point['x'], point['y']) for point in stroke['points']] for stroke in points_json]
59
  raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
60
  mask = get_processed_inputs(raw_image, points)
61
  processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
 
65
  fn=gradio_interface,
66
  inputs=[
67
  gr.Image(type="numpy", label="Input Image"),
68
+ gr.Image(type="json", label="Click to select points", tool="sketch", brush_radius=1, shape=(512, 512)),
69
  gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
70
  gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
71
  ],
 
74
  gr.Image(label="Segmentation Mask")
75
  ],
76
  title="Interactive Image Inpainting",
77
+ description="Click on the image to select points for segmentation, provide prompts, and see the inpainted result."
78
  )
79
 
80
  iface.launch(share=True)