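"""Cloth-region inpainting demo.

Segments clothing in a photo with a SegFormer model fine-tuned for clothes
parsing, turns the chosen class into a binary mask, and repaints that region
with a Stable Diffusion inpainting pipeline. A small Gradio UI wraps the flow.
"""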
import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from diffusers import AutoPipelineForInpainting
from diffusers.utils import make_image_grid
import gradio as gr



def modify_image(image_path, prompt, mask_id=4):
  # Clothes-parsing SegFormer model used to locate the region to repaint.
  processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes")
  model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes")

  image = Image.open(image_path).convert("RGB")
  inputs = processor(images=image, return_tensors="pt")

  with torch.no_grad():
    outputs = model(**inputs)
  logits = outputs.logits.cpu()

  # The logits come out at a reduced resolution; upsample them to the original
  # image size (PIL reports (width, height), interpolate expects (height, width)).
  upsampled_logits = nn.functional.interpolate(
      logits,
      size=image.size[::-1],
      mode="bilinear",
      align_corners=False,
  )

  pred_seg = upsampled_logits.argmax(dim=1)[0]

  # Binary mask of the requested class: white pixels get repainted, black are kept.
  mask = (pred_seg == mask_id).numpy()
  mask_image = Image.fromarray((mask * 255).astype("uint8"))

  # Inpainting pipeline; fp16 only works on GPU, so fall back to fp32 on CPU.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  pipeline = AutoPipelineForInpainting.from_pretrained(
      "redstonehero/ReV_Animated_Inpainting",
      torch_dtype=torch.float16 if device == "cuda" else torch.float32,
  ).to(device)
  # pipeline.enable_model_cpu_offload()  # optional: lower VRAM use at the cost of speed

  image1 = pipeline(prompt=prompt,
                    num_inference_steps=24,
                    image=image,
                    mask_image=mask_image,
                    guidance_scale=3,
                    strength=1.0).images[0]

  return make_image_grid([image1], rows=1, cols=1)
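
# Example usage without the UI (hypothetical file names; assumes a CUDA GPU for
# reasonable speed):
#   result = modify_image("person.jpg", "a red leather jacket", mask_id=4)
#   result.save("edited.png")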




def gradio_wrapper(image, prompt, choice):
  return modify_image(image, prompt, int(choice))

demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        gr.Image(type="filepath", label="Input photo"),
        gr.Textbox(label="Prompt"),
        # In this segmentation model's label map, 4 is upper clothes and 6 is pants.
        gr.Radio(["4", "6"], value="4", label="Mask ID"),
    ],
    outputs=gr.Image(label="Edited photo"),
)

demo.launch(inline=False)