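"""Gradio Space exposing GPU-backed image and text utilities.

Tabs: background removal (BiRefNet), outpainting and inpainting
(FLUX.1-Fill-dev), object erasing (SimpleLama), audio transcription,
and translation (mBART-50).
"""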
import gc
from contextlib import contextmanager

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from loadimg import load_img
from PIL import Image, ImageOps
from simple_lama_inpainting import SimpleLama
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    pipeline,
)

@contextmanager
def float32_high_matmul_precision():
    """Temporarily allow TF32 matmuls, then restore full float32 precision."""
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision("highest")

# FLUX.1 Fill pipeline used by both the outpainting and inpainting tabs.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

# BiRefNet segmentation model used for background removal.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# Preprocessing expected by BiRefNet: fixed 1024x1024 input, ImageNet normalization.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    image = load_img(image).convert("RGB")

    # gr.Number may deliver floats; Pillow expects integer border widths.
    padding_top, padding_bottom, padding_left, padding_right = (
        int(padding_top),
        int(padding_bottom),
        int(padding_left),
        int(padding_right),
    )

    # Pad the image with white; the mask is white only over the padded border,
    # so the model fills in just the new area.
    background = ImageOps.expand(
        image,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(
        mask,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    return background, mask

def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    return result.convert("RGBA")

def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    background = image.convert("RGB")
    mask = mask.convert("L")

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    return result.convert("RGBA")

def rmbg(image=None, url=None):
    if image is None:
        image = url
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    with float32_high_matmul_precision():
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Use the predicted foreground probability as an alpha channel.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image

def erase(image=None, mask=None):
    # SimpleLama weights are loaded on each call; the masked region is
    # filled in with LaMa inpainting.
    simple_lama = SimpleLama()
    image = load_img(image)
    mask = load_img(mask).convert("L")
    return simple_lama(image, mask)
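

# The "transcribe" tab below dispatches api_num 6, but this file defines no
# handler for it. The sketch below is an assumption, not the original
# implementation: it uses the already-imported transformers `pipeline` for
# automatic speech recognition (the tab description mentions WhisperX with
# speaker diarization, which is not implemented here).
def transcribe(audio_path):
    # Whisper ASR via the transformers pipeline; the model choice is illustrative.
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",
        device="cuda",
    )
    result = asr(audio_path, return_timestamps=True)
    text = result["text"]
    del asr
    gc.collect()
    torch.cuda.empty_cache()
    return text
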
def translate_text(text, source_lang, target_lang):
    # mBART-50 many-to-many model; loaded per call and freed afterwards.
    model = MBartForConditionalGeneration.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )

    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt")

    # Force the decoder to start with the target-language token.
    generated_tokens = model.generate(
        **encoded_text,
        forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
    )

    translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    del model
    gc.collect()
    torch.cuda.empty_cache()

    return translation

@spaces.GPU(duration=120)
def main(*args):
    # The first argument selects which utility to run; the rest are forwarded.
    api_num = args[0]
    args = args[1:]
    if api_num == 1:
        return rmbg(*args)
    elif api_num == 2:
        return outpaint(*args)
    elif api_num == 3:
        return inpaint(*args)
    elif api_num == 5:
        return erase(*args)
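    # Assumed wiring for the transcribe tab (api_num 6); see the hedged
    # transcribe() sketch defined above.
    elif api_num == 6:
        return transcribe(*args)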
    elif api_num == 7:
        return translate_text(*args)


rmbg_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(1, interactive=False),
        "image",
        gr.Text("", label="url"),
    ],
    outputs=["image"],
    api_name="rmbg",
    examples=[[1, "./assets/Inpainting mask.png", ""]],
    cache_examples=False,
    description="Pass an image or the URL of an image.",
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Number(label="padding top"),
        gr.Number(label="padding bottom"),
        gr.Number(label="padding left"),
        gr.Number(label="padding right"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="outpainting",
    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
    cache_examples=False,
)

inpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(3, interactive=False),
        gr.Image(label="image", type="pil"),
        gr.Image(label="mask", type="pil"),
        gr.Text(label="prompt"),
        gr.Number(value=50, label="num_inference_steps"),
        gr.Number(value=28, label="guidance_scale"),
    ],
    outputs=["image"],
    api_name="inpaint",
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png", "", 50, 28]],
    cache_examples=False,
    description=(
        "It is recommended to use https://github.com/la-voliere/react-mask-editor "
        "when creating an image mask in JS, and to invert it before sending it to "
        "this space."
    ),
)

erase_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(5, interactive=False),
        gr.Image(type="pil"),
        gr.Image(type="pil"),
    ],
    outputs=gr.Image(),
    examples=[
        [
            5,
            "./assets/rocket.png",
            "./assets/Inpainting mask.png",
        ]
    ],
    api_name="erase",
    cache_examples=False,
)

transcribe_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(value=6, interactive=False),
        gr.Audio(type="filepath", label="Audio File"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    title="Audio Transcription",
    description="Upload an audio file to extract text using WhisperX with speaker diarization",
    api_name="transcribe",
    examples=[],
)

# Language codes supported by mBART-50, shared by the source and target dropdowns.
MBART_LANG_CODES = [
    "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
    "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
    "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
    "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
    "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
    "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
    "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
]

translate_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(value=7, interactive=False),
        gr.Textbox(label="Text to translate"),
        gr.Dropdown(
            choices=MBART_LANG_CODES,
            label="Source Language",
            value="en_XX",
        ),
        gr.Dropdown(
            choices=MBART_LANG_CODES,
            label="Target Language",
            value="fr_XX",
        ),
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Text Translation",
    description="Translate text between multiple languages using mBART-50",
    api_name="translate",
    examples=[
        [7, "Hello, how are you?", "en_XX", "fr_XX"],
        [7, "Bonjour, comment allez-vous?", "fr_XX", "en_XX"],
    ],
)

demo = gr.TabbedInterface(
    [
        rmbg_tab,
        outpaint_tab,
        inpaint_tab,
        erase_tab,
        transcribe_tab,
        translate_tab,
    ],
    [
        "remove background",
        "outpainting",
        "inpainting",
        "erase",
        "transcribe",
        "translate",
    ],
    title="Utilities that require GPU",
)

demo.launch()