m451h committed · verified
Commit 936e569 · 1 Parent(s): 1a44ef7

Create app.py

Files changed (1)
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
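+ # Gradio app: segment a chosen garment in an uploaded photo with SegFormer,
+ # then repaint that region with a diffusers inpainting pipeline.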
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+ from diffusers import AutoPipelineForInpainting
+ from diffusers.utils import make_image_grid
+ import gradio as gr
+
+
+ def modify_image(image_path, prompt, mask_id=4):
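+     # Load a SegFormer checkpoint fine-tuned for clothes segmentation; it predicts a per-pixel class map.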
+     processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes")
+     model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes")
+
+     image = Image.open(image_path)
+     inputs = processor(images=image, return_tensors="pt")
+
+     outputs = model(**inputs)
+     logits = outputs.logits.cpu()
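+
+     # Upsample the low-resolution logits to the input resolution; PIL's Image.size is (W, H), hence the reversal.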
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     pred_seg = upsampled_logits.argmax(dim=1)[0]
+
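+     # Build a binary mask for the selected class and convert it to a white-on-black PIL image.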
+     mask = (pred_seg == mask_id).numpy()
+     mask_image = Image.fromarray((mask * 255).astype("uint8"))
+
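+     # Inpaint only the masked garment region, guided by the text prompt.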
+     pipeline = AutoPipelineForInpainting.from_pretrained(
+         "redstonehero/ReV_Animated_Inpainting",
+         torch_dtype=torch.float16,
+     )
+     pipeline.enable_model_cpu_offload()
+
+     image1 = pipeline(
+         prompt=prompt,
+         num_inference_steps=24,
+         image=image,
+         mask_image=mask_image,
+         guidance_scale=3,
+         strength=1.0,
+     ).images[0]
+
+     return make_image_grid([image1], rows=1, cols=1)
+
+
+ def gradio_wrapper(image, prompt, choice):
+     return modify_image(image, prompt, int(choice))
+
+
+ demo = gr.Interface(
+     fn=gradio_wrapper,
+     inputs=[
+         gr.Image(type="filepath"),  # handed to modify_image as a local file path
+         gr.Textbox(label="Prompt"),
+         gr.Radio(["4", "5", "6"], label="Mask ID"),
+     ],
+     outputs=gr.Image(),
+ )
+
+ demo.launch()
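
The Radio values are class indices in the segmentation model's label map, not arbitrary numbers; in common clothes-parsing label sets IDs 4-6 typically correspond to upper clothes, skirt, and pants, but the checkpoint's own config is the source of truth. A minimal sketch for verifying that mapping and for exercising modify_image outside the Gradio UI; the path person.jpg and the prompt are placeholders, not part of the app:

from transformers import AutoConfig

# Print the full id -> garment-class mapping so the Radio values can be checked or extended.
config = AutoConfig.from_pretrained("sayeed99/segformer_b3_clothes")
for idx, name in sorted(config.id2label.items()):
    print(idx, name)

# Hypothetical local run, bypassing the UI ("person.jpg" is a placeholder path):
# result = modify_image("person.jpg", "a red leather jacket", mask_id=4)
# result.save("edited.png")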