|
import gradio as gr |
|
import spaces |
|
import torch |
|
from loadimg import load_img |
|
from torchvision import transforms |
|
from transformers import AutoModelForImageSegmentation, pipeline |
|
from diffusers import FluxFillPipeline |
|
from PIL import Image, ImageOps |
|
import numpy as np |
|
from simple_lama_inpainting import SimpleLama |
|
from contextlib import contextmanager |
|
import gc |
|
|
|
|
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast |
|
|
|
|
|
|
|
@contextmanager
def float32_high_matmul_precision():
    """Temporarily set torch's float32 matmul precision to ``"high"``.

    The precision that was active when the context was entered is restored
    on exit (even if the body raises), so nesting this context manager or
    using it alongside code that picked a different global precision does
    not silently reset that choice to ``"highest"``.

    Yields:
        None. The context manager is used purely for its side effect.
    """
    previous_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        # Restore whatever was configured before, not a hard-coded value.
        torch.set_float32_matmul_precision(previous_precision)
|
|
|
|
|
|
|
|
|
# Load the heavyweight image models once at import time and move the torch
# models to the GPU so every request reuses the same instances.
with float32_high_matmul_precision():
    # FLUX fill pipeline, shared by the inpainting and outpainting tasks.
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        torch_dtype=torch.bfloat16,
    )
    pipe = pipe.to("cuda")

    # BiRefNet segmentation model used for background removal.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet",
        trust_remote_code=True,
    )
    birefnet = birefnet.to("cuda")

    # LaMa inpainting model used by the erase task.
    simple_lama = SimpleLama()
|
|
|
|
|
# mBART-50 many-to-many checkpoint used by the translation tab.
translation_model_name = "facebook/mbart-large-50-many-to-many-mmt"

# Loading is best-effort: if either the model or the tokenizer cannot be
# loaded, both globals are set to None and the failure is only printed.
try:
    translation_model, translation_tokenizer = (
        MBartForConditionalGeneration.from_pretrained(translation_model_name).to(
            "cuda"
        ),
        MBart50TokenizerFast.from_pretrained(translation_model_name),
    )
except Exception as e:
    print(f"Error loading translation model/tokenizer: {e}")
    translation_model = translation_tokenizer = None
|
|
|
|
|
|
|
# Preprocessing applied to an image before it is passed to `birefnet`:
# resize to 1024x1024, convert to a tensor, then normalize each channel
# with the mean/std values below.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def _to_padding(value):
    """Coerce a padding value from the UI to a non-negative whole pixel count."""
    if value is None:
        return 0
    return max(0, int(round(float(value))))


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Build the padded canvas and the matching fill mask for outpainting.

    Args:
        image: Anything accepted by ``load_img``.
        padding_top: Pixels to add above the image.
        padding_bottom: Pixels to add below the image.
        padding_left: Pixels to add to the left of the image.
        padding_right: Pixels to add to the right of the image.

    Padding values may be ``None``, floats or negative numbers (number
    widgets can produce all of these); they are rounded to whole pixels and
    clamped at zero so that the expanded size is always a valid integer size.

    Returns:
        A ``(background, mask)`` tuple of RGB images with identical size.
        ``background`` is the input surrounded by a white border; ``mask`` is
        black over the original image area and white over the added border.
    """
    image = load_img(image).convert("RGB")

    # ImageOps.expand takes a 4-tuple border as (left, top, right, bottom).
    border = (
        _to_padding(padding_left),
        _to_padding(padding_top),
        _to_padding(padding_right),
        _to_padding(padding_bottom),
    )

    background = ImageOps.expand(image, border=border, fill="white")

    # Black = keep the original pixels, white = area to be generated.
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask
|
|
|
|
|
def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Extend an image by the given paddings and fill the new border area.

    The padded canvas and its mask come from ``prepare_image_and_mask``; the
    module-level ``pipe`` is then called at the canvas size with the prompt,
    step count and guidance scale supplied by the caller.

    Returns:
        The first image produced by the pipeline, converted to RGBA.
    """
    canvas, fill_mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    pipe_kwargs = {
        "prompt": prompt,
        "height": canvas.height,
        "width": canvas.width,
        "image": canvas,
        "mask_image": fill_mask,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }

    with float32_high_matmul_precision():
        generated = pipe(**pipe_kwargs).images[0]

    return generated.convert("RGBA")
|
|
|
|
|
def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Regenerate the masked region of an image with the fill pipeline.

    Args:
        image: Source image object exposing ``convert`` (a PIL image).
        mask: Mask image object exposing ``convert``; it is converted to
            single-channel ``"L"`` mode before being passed to the pipeline.
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first generated image converted to RGBA, or an explanatory
        message string when the image or the mask is missing (the same
        convention ``erase`` uses for missing inputs).
    """
    # Without this guard a missing upload fails with an opaque
    # AttributeError on ``None.convert``.
    if image is None or mask is None:
        return "Please provide both an image and a mask."

    background = image.convert("RGB")
    mask = mask.convert("L")

    with float32_high_matmul_precision():
        result = pipe(
            prompt=prompt,
            height=background.height,
            width=background.width,
            image=background,
            mask_image=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]

    return result.convert("RGBA")
|
|
|
|
|
def rmbg(image=None, url=None):
    """Remove the background of an image given directly or by URL.

    ``image`` takes precedence; ``url`` is only consulted when no image was
    supplied and must start with ``http://`` or ``https://``.

    Returns:
        The RGB input with the predicted mask attached as its alpha channel,
        or a message string when the input is missing, the URL is rejected,
        or the image cannot be loaded.
    """
    if image is None:
        if not url:
            return "Please provide an image or a URL."
        if not url.startswith(("http://", "https://")):
            return "Invalid URL provided."
        image = url

    try:
        source = load_img(image).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    original_size = source.size
    batch = transform_image(source).unsqueeze(0).to("cuda")

    with float32_high_matmul_precision(), torch.no_grad():
        # Take the last model output, squash it with a sigmoid and move it
        # to the CPU.
        outputs = birefnet(batch)[-1].sigmoid().cpu()

    alpha = outputs[0].squeeze()
    alpha_image = transforms.ToPILImage()(alpha)

    # Scale the mask back to the input resolution and use it as alpha.
    source.putalpha(alpha_image.resize(original_size))

    # Drop tensor references before asking torch to release cached memory.
    del batch, outputs, alpha
    torch.cuda.empty_cache()
    gc.collect()
    return source
|
|
|
|
|
def erase(image=None, mask=None):
    """Erase the masked region of an image with the ``simple_lama`` model.

    Returns:
        The value returned by ``simple_lama`` for the loaded image and its
        single-channel mask, or a message string when an input is missing or
        any step raises.
    """
    if image is None or mask is None:
        return "Please provide both an image and a mask."

    try:
        source = load_img(image)
        erase_mask = load_img(mask).convert("L")
        cleaned = simple_lama(source, erase_mask)
    except Exception as e:
        return f"Error during erase operation: {e}"

    gc.collect()
    return cleaned
|
|
|
|
|
|
|
|
|
|
|
# Display name -> language code passed to the translation tokenizer
# (as ``src_lang`` and as the key into ``lang_code_to_id``).
lang_data = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
    "Afrikaans": "af_ZA",
    "Azerbaijani": "az_AZ",
    "Bengali": "bn_IN",
    "Persian": "fa_IR",
    "Hebrew": "he_IL",
    "Croatian": "hr_HR",
    "Indonesian": "id_ID",
    "Georgian": "ka_GE",
    "Khmer": "km_KH",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Mongolian": "mn_MN",
    "Marathi": "mr_IN",
    "Polish": "pl_PL",
    "Pashto": "ps_AF",
    "Portuguese": "pt_XX",
    "Swedish": "sv_SE",
    "Swahili": "sw_KE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Thai": "th_TH",
    "Tagalog": "tl_XX",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Xhosa": "xh_ZA",
    "Galician": "gl_ES",
    "Slovene": "sl_SI",
}

# Alphabetical list of display names, used as the dropdown choices.
# Iterating a dict yields its keys, so no list()/keys() wrapper is needed.
language_names = sorted(lang_data)
|
|
|
|
|
# This handler is wired directly to the Translate button (it does not go
# through `main`) and moves tensors to CUDA, so it carries its own GPU
# decorator, matching how `main` is decorated.
@spaces.GPU
def translate_text(text_to_translate, source_language_name, target_language_name):
    """Translate text between two languages with the loaded mBART model.

    Args:
        text_to_translate: The input text.
        source_language_name: Display name of the input language; must be a
            key of ``lang_data``.
        target_language_name: Display name of the output language; must be a
            key of ``lang_data``.

    Returns:
        The translated text, or a message string when the model is not
        loaded, an argument is missing, a language is unknown, or the
        translation itself fails. This function does not raise for those
        cases.
    """
    if translation_model is None or translation_tokenizer is None:
        return "Translation model not loaded. Cannot perform translation."
    if not text_to_translate:
        return "Please enter text to translate."
    if not source_language_name:
        return "Please select a source language."
    if not target_language_name:
        return "Please select a target language."

    # Pre-bind so the cleanup in `finally` can always drop the references.
    encoded_text = None
    generated_tokens = None
    try:
        source_lang_code = lang_data[source_language_name]
        target_lang_code = lang_data[target_language_name]

        translation_tokenizer.src_lang = source_lang_code
        encoded_text = translation_tokenizer(
            text_to_translate, return_tensors="pt"
        ).to("cuda")
        target_lang_id = translation_tokenizer.lang_code_to_id[target_lang_code]

        with torch.no_grad():
            generated_tokens = translation_model.generate(
                **encoded_text, forced_bos_token_id=target_lang_id, max_length=200
            )

        translated_text = translation_tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        return translated_text[0]

    except KeyError as e:
        return f"Error: Language code not found for {e}. Check language mappings."
    except Exception as e:
        print(f"Translation error: {e}")
        return f"An error occurred during translation: {e}"
    finally:
        # Single cleanup path for success and every failure: release the
        # tensor references first, then collect and free cached GPU memory.
        encoded_text = None
        generated_tokens = None
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def main(*args):
    """Route a request from one of the image tabs to its task function.

    Args:
        *args: The first value is the task id (1 = ``rmbg``, 2 = ``outpaint``,
            3 = ``inpaint``, 5 = ``erase``); the remaining values are passed
            to the selected task function unchanged.

    Returns:
        The image produced by the selected task.

    Raises:
        gr.Error: If no task id is given, the id is unknown, the task raises,
            or the task reports a problem by returning a message string.
            Every tab that calls this function renders the result in an image
            component, so a message string must be surfaced as an error
            rather than returned as if it were an image.
    """
    if not args:
        raise gr.Error("Invalid API number.")

    api_num, *task_args = args
    handlers = {1: rmbg, 2: outpaint, 3: inpaint, 5: erase}

    gc.collect()
    torch.cuda.empty_cache()

    try:
        handler = handlers.get(api_num)
        if handler is None:
            result = "Invalid API number."
        else:
            result = handler(*task_args)
    except Exception as e:
        print(f"Error in main task routing (api_num={api_num}): {e}")
        raise gr.Error(f"An error occurred: {e}") from e
    finally:
        # Always release memory, whether the task succeeded or failed.
        gc.collect()
        torch.cuda.empty_cache()

    # Task functions signal user-facing problems by returning a string.
    if isinstance(result, str):
        raise gr.Error(result)
    return result
|
|
|
|
|
|
|
|
|
|
|
# Background-removal tab. The hidden Number supplies the task id (1) that
# `main` reads as its first argument.
rmbg_tab = gr.Interface(
    title="Remove Background",
    description="Upload an image or provide a URL to remove its background.",
    fn=main,
    api_name="rmbg",
    inputs=[
        gr.Number(value=1, visible=False, interactive=False),
        gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image"),
        gr.Text(label="Or Image URL (optional)"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    cache_examples=False,
)
|
|
|
# Outpainting tab. The hidden Number supplies the task id (2) that `main`
# reads as its first argument. Padding fields are restricted to whole,
# non-negative pixel counts (precision=0, minimum=0) because they are used
# directly as border sizes when the canvas is expanded.
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False, visible=False),
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        gr.Number(value=0, label="Padding Top (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Bottom (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Left (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Right (pixels)", precision=0, minimum=0),
        gr.Text(
            label="Prompt (optional)",
            info="Describe what to fill the extended area with",
        ),
        gr.Slider(
            minimum=10, maximum=100, step=1, value=28, label="Inference Steps"
        ),
        gr.Slider(
            minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"
        ),
    ],
    outputs=gr.Image(label="Outpainted Image", type="pil"),
    title="Outpainting",
    description="Extend an image by adding padding and filling the new area using a diffusion model.",
    api_name="outpainting",
    cache_examples=False,
)
|
|
|
# Inpainting tab. The hidden Number supplies the task id (3) that `main`
# reads as its first argument; the remaining inputs follow the parameter
# order of `inpaint`.
inpaint_tab = gr.Interface(
    title="Inpainting",
    description="Fill in the white areas of a mask applied to an image using a diffusion model.",
    fn=main,
    api_name="inpaint",
    inputs=[
        gr.Number(value=3, visible=False, interactive=False),
        gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image"),
        gr.Image(
            type="pil",
            sources=["upload", "clipboard"],
            label="Mask Image (White=Inpaint Area)",
        ),
        gr.Text(
            label="Prompt (optional)",
            info="Describe what to fill the masked area with",
        ),
        gr.Slider(label="Inference Steps", value=28, minimum=10, maximum=100, step=1),
        gr.Slider(label="Guidance Scale", value=50, minimum=1, maximum=100, step=1),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    cache_examples=False,
)
|
|
|
# Object-erase tab. The hidden Number supplies the task id (5) that `main`
# reads as its first argument.
erase_tab = gr.Interface(
    title="Erase Object (LAMA)",
    description="Erase objects from an image based on a mask using the LaMa inpainting model.",
    fn=main,
    api_name="erase",
    inputs=[
        gr.Number(value=5, visible=False, interactive=False),
        gr.Image(type="pil", sources=["upload", "clipboard"], label="Input Image"),
        gr.Image(
            type="pil",
            sources=["upload", "clipboard"],
            label="Mask Image (White=Erase Area)",
        ),
    ],
    outputs=gr.Image(type="pil", label="Result Image"),
    cache_examples=False,
)
|
|
|
|
|
|
|
# Translation tab: two language dropdowns on the left, input/output text
# boxes on the right, plus clickable examples. Both the button and the
# examples call `translate_text` with the same inputs and output.
with gr.Blocks() as translation_tab:
    gr.Markdown(
        """
        ## Multilingual Translation (mBART-50)
        Translate text between 50 different languages.
        Select the source and target languages, enter your text, and click Translate.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            source_lang_dropdown = gr.Dropdown(
                choices=language_names,
                label="Source Language",
                info="Select the language of your input text.",
            )
            target_lang_dropdown = gr.Dropdown(
                choices=language_names,
                label="Target Language",
                info="Select the language you want to translate to.",
            )
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                lines=6,
                label="Text to Translate",
                placeholder="Enter text here...",
            )
            translate_button = gr.Button("Translate", variant="primary")
            output_textbox = gr.Textbox(
                lines=6,
                label="Translated Text",
                interactive=False,
            )

    # Clicking the button runs the translation; also exposed as the
    # "translate" API endpoint.
    translate_button.click(
        fn=translate_text,
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        api_name="translate",
    )

    # Each example row is [text, source language, target language].
    gr.Examples(
        examples=[
            [
                "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है",
                "Hindi",
                "French",
            ],
            [
                "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.",
                "Arabic",
                "English",
            ],
            [
                "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
                "French",
                "German",
            ],
            ["Hello world! How are you today?", "English", "Spanish"],
            ["Guten Tag!", "German", "Japanese"],
            ["これはテストです", "Japanese", "English"],
        ],
        fn=translate_text,
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        cache_examples=False,
    )
|
|
|
|
|
# Combine the individual tabs into one app; the two lists are matched by
# position (interface -> tab label).
demo = gr.TabbedInterface(
    interface_list=[
        rmbg_tab,
        outpaint_tab,
        inpaint_tab,
        erase_tab,
        translation_tab,
    ],
    tab_names=[
        "Remove Background",
        "Outpainting",
        "Inpainting",
        "Erase (LAMA)",
        "Translate",
    ],
    title="Image & Text Utilities (GPU)",
)

# Start the Gradio server.
demo.launch()
|
|