ahmetyaylalioglu commited on
Commit
e2a5dbf
·
verified ·
1 Parent(s): 7558e02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -20,9 +20,11 @@ def mask_to_rgb(mask):
20
  bg_transparent[mask == 1] = [0, 255, 0, 127] # Green mask with some transparency
21
  return bg_transparent
22
 
23
- def get_processed_inputs(image, drawing):
24
- """ Process the input image and drawing using SAM model and processor. """
25
- inputs = processor(image, return_tensors="pt").to(device)
 
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  masks = processor.image_processor.post_process_masks(
@@ -53,10 +55,10 @@ def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536,
53
  ).images[0]
54
  return image
55
 
56
- def gradio_interface(image, drawing, positive_prompt, negative_prompt):
57
- """ Gradio interface function to handle image, drawing, and prompts. """
58
  raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
59
- mask = get_processed_inputs(raw_image, drawing)
60
  processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
61
  return processed_image, mask_to_rgb(mask)
62
 
@@ -64,7 +66,7 @@ iface = gr.Interface(
64
  fn=gradio_interface,
65
  inputs=[
66
  gr.Image(type="numpy", label="Input Image"),
67
- gr.Sketch(label="Draw on the image", shape=(512, 512)),
68
  gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
69
  gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
70
  ],
 
20
  bg_transparent[mask == 1] = [0, 255, 0, 127] # Green mask with some transparency
21
  return bg_transparent
22
 
23
+ def get_processed_inputs(image, annotation):
24
+ """ Process the input image and annotated drawing using SAM model and processor. """
25
+ mask = np.zeros(image.size, dtype=np.uint8)
26
+ mask[annotation[:,:,3] > 128] = 1 # Assume drawing is in alpha channel of RGBA
27
+ inputs = processor(images=image, return_tensors="pt").to(device)
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
  masks = processor.image_processor.post_process_masks(
 
55
  ).images[0]
56
  return image
57
 
58
+ def gradio_interface(image, annotation, positive_prompt, negative_prompt):
59
+ """ Gradio interface function to handle image, annotated drawing, and prompts. """
60
  raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
61
+ mask = get_processed_inputs(raw_image, annotation)
62
  processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
63
  return processed_image, mask_to_rgb(mask)
64
 
 
66
  fn=gradio_interface,
67
  inputs=[
68
  gr.Image(type="numpy", label="Input Image"),
69
+ gr.Image(tool="editor", label="Draw on the image", output="png", shape=(512, 512)),
70
  gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
71
  gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
72
  ],