# Snapshot 936e569 (file size: 2,031 bytes).
import functools

import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid
from google.colab import drive
from IPython import get_ipython
from IPython.display import display
from PIL import Image
from torchvision import transforms
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from transformers import SamModel, SamProcessor
# Mount Google Drive into the Colab filesystem.
# NOTE(review): nothing else in this file appears to read from the mount;
# confirm it is still needed.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
SEGMENTATION_MODEL_ID = "sayeed99/segformer_b3_clothes"
INPAINTING_MODEL_ID = "redstonehero/ReV_Animated_Inpainting"


@functools.lru_cache(maxsize=1)
def _load_segmenter():
    """Load the segmentation processor and model once and reuse them across calls."""
    processor = SegformerImageProcessor.from_pretrained(SEGMENTATION_MODEL_ID)
    model = AutoModelForSemanticSegmentation.from_pretrained(SEGMENTATION_MODEL_ID)
    model.eval()  # inference only; make the mode explicit
    return processor, model


@functools.lru_cache(maxsize=1)
def _load_inpainter():
    """Load the fp16 inpainting pipeline once, with model CPU offload enabled."""
    pipeline = AutoPipelineForInpainting.from_pretrained(
        INPAINTING_MODEL_ID,
        torch_dtype=torch.float16,
    )
    pipeline.enable_model_cpu_offload()
    return pipeline


def _segment_mask(image, mask_id):
    """Return a single-channel mask image that is 255 where the predicted class is mask_id, else 0."""
    processor, model = _load_segmenter()
    inputs = processor(images=image, return_tensors="pt")
    # No gradients are needed for inference; skipping autograd saves memory.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
    # Resize the logits to the input's (height, width). PIL's size is
    # (width, height), hence the reversal. This makes the mask line up
    # pixel-for-pixel with the image passed to the inpainter.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    mask = (pred_seg == mask_id).numpy()
    return Image.fromarray((mask * 255).astype("uint8"))


def modify_image(image_url, prompt, mask_id=4, *, num_inference_steps=24,
                 guidance_scale=3, strength=1.0):
    """Segment one class in an image and repaint that region from a text prompt.

    The segmentation and inpainting models are loaded on first use and cached,
    so repeated calls (e.g. from the Gradio UI) do not reload them from disk.

    Args:
        image_url: Path or file object of the input image. Despite the name,
            it is passed straight to ``PIL.Image.open``, so remote URLs are
            not supported.
        prompt: Text prompt for the inpainting pipeline.
        mask_id: Segmentation class index whose pixels are repainted.
            NOTE(review): the meaning of each id comes from the
            segformer_b3_clothes checkpoint; check ``model.config.id2label``.
        num_inference_steps: Forwarded to the pipeline (previously hard-coded 24).
        guidance_scale: Forwarded to the pipeline (previously hard-coded 3).
        strength: Forwarded to the pipeline (previously hard-coded 1.0).

    Returns:
        The PIL image built by ``make_image_grid`` around the single
        inpainted result.
    """
    # Close the file handle promptly, and normalize to RGB so RGBA/grayscale
    # uploads work with both the segmenter and the inpainting pipeline.
    with Image.open(image_url) as source:
        image = source.convert("RGB")

    mask_image = _segment_mask(image, mask_id)

    result = _load_inpainter()(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        image=image,
        mask_image=mask_image,
        guidance_scale=guidance_scale,
        strength=strength,
    ).images[0]
    return make_image_grid([result], rows=1, cols=1)
import gradio as gr
def gradio_wrapper(image, prompt, choice):
    """Adapt the Gradio Interface inputs to ``modify_image``.

    Args:
        image: Path of the uploaded image (the Image input uses
            ``type="filepath"``), or None when nothing was uploaded.
        prompt: Text prompt for the inpainting pipeline.
        choice: Selected mask id as a string ("4", "5" or "6"). It may be
            None or empty if the radio was left unselected.

    Returns:
        The image produced by ``modify_image``.

    Raises:
        gr.Error: If no image was uploaded. Gradio then shows a readable
            message instead of a traceback from ``Image.open(None)``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if choice is None or choice == "":
        # No selection: let modify_image apply its own default mask id
        # instead of crashing on int(None).
        return modify_image(image, prompt)
    return modify_image(image, prompt, int(choice))
demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        # type="filepath" hands gradio_wrapper a path string, which
        # modify_image passes to PIL.Image.open.
        gr.Image(type="filepath"),
        gr.Textbox(label="Prompt"),
        # gradio_wrapper converts the selection with int(); a default value
        # prevents submitting with nothing selected.
        gr.Radio(["4", "5", "6"], value="4", label="Mask ID"),
    ],
    outputs=gr.Image(),
)
# inline=False: do not embed the UI inline in the notebook output.
demo.launch(inline=False)