# gpu-utils / app.py
# Last commit dc95e97 ("soft reset") by not-lain; file size 16.5 kB.
import gradio as gr
import spaces
import torch
from loadimg import load_img # Assuming loadimg.py exists with load_img function
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, pipeline
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
import numpy as np
from simple_lama_inpainting import SimpleLama
from contextlib import contextmanager
import gc
# --- Add Translation Imports ---
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# --- Utility Functions ---
@contextmanager
def float32_high_matmul_precision():
    """Temporarily set PyTorch's float32 matmul precision to ``"high"``.

    On exit (normal or via an exception) the precision that was active on
    entry is restored, instead of unconditionally forcing ``"highest"``.
    This lets the manager nest safely and not clobber a precision that was
    configured elsewhere in the process.
    """
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(previous)
# --- Model Loading ---
# The image models are loaded once at import time and moved to CUDA; every
# request then reuses these module-level objects.
# Use context manager for precision during model loading if needed
with float32_high_matmul_precision():
    # FLUX.1 Fill pipeline in bfloat16; shared by `outpaint` and `inpaint`.
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    # BiRefNet segmentation model used by `rmbg`. trust_remote_code lets
    # transformers run the model code that ships with the Hub repository.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    ).to("cuda")
# NOTE(review): this copy lost its indentation, so whether the next line was
# inside the `with` block above is unknown; presumably it makes no difference.
simple_lama = SimpleLama() # LaMa inpainter used by `erase`; created once and reused.
# --- Translation Model and Tokenizer Loading ---
translation_model_name = "facebook/mbart-large-50-many-to-many-mmt"


def _load_translation_components(model_name):
    """Load the mBART-50 model (on CUDA) and its fast tokenizer.

    Returns a ``(model, tokenizer)`` pair. If either load raises, the error is
    printed and ``(None, None)`` is returned so the translation tab can report
    that the model is unavailable instead of crashing the whole app.
    """
    try:
        model = MBartForConditionalGeneration.from_pretrained(model_name).to("cuda")
        tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    except Exception as e:
        print(f"Error loading translation model/tokenizer: {e}")
        return None, None
    return model, tokenizer


translation_model, translation_tokenizer = _load_translation_components(
    translation_model_name
)
# --- Image Processing Functions ---
# Preprocessing for the BiRefNet input in `rmbg`: a fixed 1024x1024 resize,
# conversion to a float tensor, then per-channel normalization with the
# standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad an image with white borders and build the matching outpaint mask.

    Args:
        image: Anything ``load_img`` accepts (PIL image, path, URL, ...).
        padding_top, padding_bottom, padding_left, padding_right: Border
            width in pixels for each side. Values are coerced to
            non-negative ints because ``gr.Number`` delivers floats such as
            ``100.0`` and ``ImageOps.expand`` needs integer sizes. ``None``
            counts as 0.

    Returns:
        ``(background, mask)``:
        - ``background`` is the RGB image padded with white.
        - ``mask`` is an RGB image of the same size: black over the original
          pixels and white over the added border, the region to generate.
    """
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = tuple(
        max(0, round(value or 0))
        for value in (padding_left, padding_top, padding_right, padding_bottom)
    )
    image = load_img(image).convert("RGB")
    background = ImageOps.expand(image, border=border, fill="white")
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"), border=border, fill="white"
    )
    return background, mask
def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Extend ``image`` outward using the FLUX Fill pipeline.

    The image is padded by the requested number of pixels on each side
    (see ``prepare_image_and_mask``), and the new border is generated.

    Args:
        image: Input accepted by ``prepare_image_and_mask``.
        padding_top, padding_bottom, padding_left, padding_right: Border
            width in pixels per side.
        prompt: Text describing what to generate in the border. ``None`` is
            treated as an empty prompt.
        num_inference_steps: Denoising steps. Coerced to ``int`` because UI
            components and API clients may send floats.
        guidance_scale: Guidance strength passed straight to ``pipe``.

    Returns:
        The generated image, converted to RGBA.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    with float32_high_matmul_precision():
        result = pipe(
            prompt=prompt or "",
            height=background.height,
            width=background.width,
            image=background,
            mask_image=mask,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
        ).images[0]
    return result.convert("RGBA")
def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Regenerate the white region of ``mask`` inside ``image`` with FLUX Fill.

    Args:
        image: Source PIL image; converted to RGB.
        mask: PIL mask image; converted to single-channel "L". White marks
            the area to repaint.
        prompt: Text describing what to paint; ``None`` is treated as "".
        num_inference_steps: Denoising steps. Coerced to ``int`` because UI
            components and API clients may send floats.
        guidance_scale: Guidance strength passed straight to ``pipe``.

    Returns:
        The generated image as RGBA. If either input is missing, returns an
        error-message string instead (the same convention as ``erase``).
    """
    if image is None or mask is None:
        return "Please provide both an image and a mask."
    background = image.convert("RGB")
    mask = mask.convert("L")
    with float32_high_matmul_precision():
        result = pipe(
            prompt=prompt or "",
            height=background.height,
            width=background.width,
            image=background,
            mask_image=mask,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
        ).images[0]
    return result.convert("RGBA")
def rmbg(image=None, url=None):
    """Remove the background of an image using BiRefNet.

    Either ``image`` or ``url`` must be given; ``image`` takes precedence when
    both are supplied.

    Returns:
        The input as a PIL image whose alpha channel is BiRefNet's predicted
        mask (resized back to the original size), or an error-message string.
    """
    if image is None:
        if not url:
            return "Please provide an image or a URL."
        # Only a basic scheme check; load_img performs the actual fetch.
        if not url.startswith(("http://", "https://")):
            return "Invalid URL provided."
        image = url

    try:
        picture = load_img(image).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    original_size = picture.size
    batch = transform_image(picture).unsqueeze(0).to("cuda")
    with float32_high_matmul_precision(), torch.no_grad():
        # The last model output holds the logits; sigmoid maps them to [0, 1].
        probabilities = birefnet(batch)[-1].sigmoid().cpu()
    alpha = transforms.ToPILImage()(probabilities[0].squeeze()).resize(original_size)
    picture.putalpha(alpha)

    # Drop tensor references eagerly, then release cached GPU memory.
    del batch, probabilities
    torch.cuda.empty_cache()
    gc.collect()
    return picture
def erase(image=None, mask=None):
    """Erase the masked region of ``image`` using the LaMa inpainter.

    Args:
        image: Anything ``load_img`` accepts. Converted to RGB so that RGBA
            uploads (common for PNGs) do not hand an extra alpha channel to
            LaMa, which presumably expects 3-channel input.
        mask: Anything ``load_img`` accepts; converted to "L". White marks
            the area to erase. If its size differs from the image's, it is
            resized with nearest-neighbour sampling, which keeps it binary.

    Returns:
        The result of ``simple_lama`` on success, otherwise an error-message
        string.
    """
    if image is None or mask is None:
        return "Please provide both an image and a mask."
    try:
        image_pil = load_img(image).convert("RGB")
        mask_pil = load_img(mask).convert("L")
        if mask_pil.size != image_pil.size:
            mask_pil = mask_pil.resize(image_pil.size, Image.NEAREST)
        return simple_lama(image_pil, mask_pil)
    except Exception as e:
        return f"Error during erase operation: {e}"
    finally:
        # Clean up on both the success and the failure path.
        gc.collect()
# --- Translation Functionality ---
# Maps each human-readable language name to its mBART-50 language code.
# `translate_text` assigns the source code to the tokenizer's `src_lang` and
# uses the target code to look up the forced first (BOS) token id.
lang_data = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
    "Afrikaans": "af_ZA",
    "Azerbaijani": "az_AZ",
    "Bengali": "bn_IN",
    "Persian": "fa_IR",
    "Hebrew": "he_IL",
    "Croatian": "hr_HR",
    "Indonesian": "id_ID",
    "Georgian": "ka_GE",
    "Khmer": "km_KH",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Mongolian": "mn_MN",
    "Marathi": "mr_IN",
    "Polish": "pl_PL",
    "Pashto": "ps_AF",
    "Portuguese": "pt_XX",
    "Swedish": "sv_SE",
    "Swahili": "sw_KE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Thai": "th_TH",
    "Tagalog": "tl_XX",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Xhosa": "xh_ZA",
    "Galician": "gl_ES",
    "Slovene": "sl_SI",
}
# Alphabetical list of language names shown in both translation dropdowns.
language_names = sorted(lang_data)
@spaces.GPU  # CUDA work must run inside a GPU session on ZeroGPU; no-op elsewhere.
def translate_text(text_to_translate, source_language_name, target_language_name):
    """Translate text between two languages using the loaded mBART-50 model.

    Args:
        text_to_translate: Text to translate. Empty or whitespace-only text
            is rejected.
        source_language_name: Key of ``lang_data`` for the input language.
        target_language_name: Key of ``lang_data`` for the output language.

    Returns:
        The translated text, or a human-readable error message.
    """
    if translation_model is None or translation_tokenizer is None:
        return "Translation model not loaded. Cannot perform translation."
    if not text_to_translate or not text_to_translate.strip():
        return "Please enter text to translate."
    if not source_language_name:
        return "Please select a source language."
    if not target_language_name:
        return "Please select a target language."

    # Pre-bind so the `finally` cleanup works whichever line raised.
    encoded_text = generated_tokens = None
    try:
        source_lang_code = lang_data[source_language_name]
        target_lang_code = lang_data[target_language_name]
        # Tell the (shared) tokenizer which language the input is in.
        translation_tokenizer.src_lang = source_lang_code
        encoded_text = translation_tokenizer(
            text_to_translate, return_tensors="pt"
        ).to("cuda")
        target_lang_id = translation_tokenizer.lang_code_to_id[target_lang_code]
        with torch.no_grad():
            # Forcing the first generated token to the target-language code
            # selects the output language.
            generated_tokens = translation_model.generate(
                **encoded_text, forced_bos_token_id=target_lang_id, max_length=200
            )
        translated_text = translation_tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        return translated_text[0]
    except KeyError as e:
        return f"Error: Language code not found for {e}. Check language mappings."
    except Exception as e:
        print(f"Translation error: {e}")
        return f"An error occurred during translation: {e}"
    finally:
        # Release GPU tensors on both the success and the error paths.
        del encoded_text, generated_tokens
        torch.cuda.empty_cache()
        gc.collect()
# --- Main Function Router (for image tasks) ---
# Note: Translation uses its own function directly
@spaces.GPU(duration=120)  # Each image task gets a GPU session of up to 120 s.
def main(*args):
    """Route an image task to its handler inside a GPU session.

    Each image tab wires ``main`` with a hidden ``gr.Number`` as its first
    input, and that number selects the task:
    - 1: rmbg
    - 2: outpaint
    - 3: inpaint
    - 5: erase

    Task 4 (mask generation) is currently disabled. The remaining arguments
    are forwarded positionally to the handler.

    Returns:
        The handler's image result.

    Raises:
        gr.Error: The handler returned an error-message string or raised.
            Every output wired to ``main`` is a ``gr.Image``, which cannot
            display a string, so the message is shown as a Gradio error.
    """
    api_num, *task_args = args
    # gr.Number delivers floats (e.g. 1.0); they hash and compare equal to
    # the int keys, so the lookup still matches.
    handlers = {1: rmbg, 2: outpaint, 3: inpaint, 5: erase}
    gc.collect()  # Try to collect garbage before starting task
    torch.cuda.empty_cache()  # Clear cache before starting task
    try:
        handler = handlers.get(api_num)
        result = handler(*task_args) if handler else "Invalid API number."
    except Exception as e:
        print(f"Error in main task routing (api_num={api_num}): {e}")
        raise gr.Error(f"An error occurred: {e}") from e
    finally:
        # Ensure memory cleanup happens even if there's an error
        gc.collect()
        torch.cuda.empty_cache()
    if isinstance(result, str):
        raise gr.Error(result)
    return result
# --- Define Gradio Interfaces for Each Tab ---
# Image Task Tabs
# Each image tab calls `main`. Its first, hidden gr.Number carries the task
# id that `main` dispatches on; here 1 selects background removal.
_rmbg_inputs = [
    gr.Number(1, interactive=False, visible=False),
    gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
    gr.Text(label="Or Image URL (optional)"),
]
rmbg_tab = gr.Interface(
    fn=main,
    inputs=_rmbg_inputs,
    outputs=gr.Image(label="Output Image", type="pil"),
    title="Remove Background",
    description="Upload an image or provide a URL to remove its background.",
    api_name="rmbg",
    # Example row (needs a valid asset path):
    # examples=[[1, "./assets/sample_rmbg.png", ""]],
    cache_examples=False,
)
# Outpainting tab: the hidden task id 2 routes to `outpaint` through `main`.
outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False, visible=False),
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        # precision=0 makes Gradio pass ints. Without it, gr.Number delivers
        # floats (e.g. 100.0), which ImageOps.expand in prepare_image_and_mask
        # cannot use as border sizes. minimum=0 rejects negative padding.
        gr.Number(value=0, label="Padding Top (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Bottom (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Left (pixels)", precision=0, minimum=0),
        gr.Number(value=0, label="Padding Right (pixels)", precision=0, minimum=0),
        gr.Text(
            label="Prompt (optional)",
            info="Describe what to fill the extended area with",
        ),
        gr.Slider(minimum=10, maximum=100, step=1, value=28, label="Inference Steps"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"),
    ],
    outputs=gr.Image(label="Outpainted Image", type="pil"),
    title="Outpainting",
    description="Extend an image by adding padding and filling the new area using a diffusion model.",
    api_name="outpainting",
    # examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 28, 50]], # Update example path
    cache_examples=False,
)
# Inpainting tab: the hidden task id 3 routes to `inpaint` through `main`.
_inpaint_inputs = [
    gr.Number(3, interactive=False, visible=False),
    gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
    gr.Image(
        label="Mask Image (White=Inpaint Area)",
        type="pil",
        sources=["upload", "clipboard"],
    ),
    gr.Text(
        label="Prompt (optional)", info="Describe what to fill the masked area with"
    ),
    gr.Slider(minimum=10, maximum=100, step=1, value=28, label="Inference Steps"),
    gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"),
]
inpaint_tab = gr.Interface(
    fn=main,
    inputs=_inpaint_inputs,
    outputs=gr.Image(label="Inpainted Image", type="pil"),
    title="Inpainting",
    description="Fill in the white areas of a mask applied to an image using a diffusion model.",
    api_name="inpaint",
    # Example row (needs valid asset paths):
    # examples=[[3, "./assets/rocket.png", "./assets/Inpainting_mask.png", "", 28, 50]],
    cache_examples=False,
)
# Erase tab: the hidden task id 5 routes to `erase` (LaMa) through `main`.
_erase_inputs = [
    gr.Number(5, interactive=False, visible=False),
    gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
    gr.Image(
        label="Mask Image (White=Erase Area)",
        type="pil",
        sources=["upload", "clipboard"],
    ),
]
erase_tab = gr.Interface(
    fn=main,
    inputs=_erase_inputs,
    outputs=gr.Image(label="Result Image", type="pil"),
    title="Erase Object (LAMA)",
    description="Erase objects from an image based on a mask using the LaMa inpainting model.",
    api_name="erase",
    # Example row (needs valid asset paths):
    # examples=[[5, "./assets/rocket.png", "./assets/Inpainting_mask.png"]],
    cache_examples=False,
)
# --- Define Translation Tab using gr.Blocks ---
# Text-translation UI. Unlike the image tabs, it is built with gr.Blocks and
# calls `translate_text` directly instead of routing through `main`.
# NOTE(review): this copy lost its indentation, so the nesting below (which
# components sit in which Row/Column) is a reconstruction; confirm the layout.
with gr.Blocks() as translation_tab:
    gr.Markdown(
        """
## Multilingual Translation (mBART-50)
Translate text between 50 different languages.
Select the source and target languages, enter your text, and click Translate.
"""
    )
    with gr.Row():
        # Narrow column: source/target pickers, both listing every lang_data name.
        with gr.Column(scale=1):
            source_lang_dropdown = gr.Dropdown(
                label="Source Language",
                choices=language_names,
                info="Select the language of your input text.",
            )
            target_lang_dropdown = gr.Dropdown(
                label="Target Language",
                choices=language_names,
                info="Select the language you want to translate to.",
            )
        # Wider column: input text, trigger button and read-only result.
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                label="Text to Translate",
                lines=6, # Six visible lines for longer passages
                placeholder="Enter text here...",
            )
            translate_button = gr.Button(
                "Translate", variant="primary"
            ) # Primary styling marks the main action
            output_textbox = gr.Textbox(
                label="Translated Text",
                lines=6, # Six visible lines for longer passages
                interactive=False, # Make output read-only
            )
    # Connect the button to the translation function directly. The handler is
    # also exposed as the "translate" API endpoint.
    translate_button.click(
        fn=translate_text,
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        api_name="translate", # Add API name for the translation endpoint
    )
    # Sample rows of (text, source language name, target language name).
    # Each name must be a key of lang_data. Examples are not cached.
    gr.Examples(
        examples=[
            [
                "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है",
                "Hindi",
                "French",
            ],
            [
                "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.",
                "Arabic",
                "English",
            ],
            [
                "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
                "French",
                "German",
            ],
            ["Hello world! How are you today?", "English", "Spanish"],
            ["Guten Tag!", "German", "Japanese"],
            ["これはテストです", "Japanese", "English"],
        ],
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        fn=translate_text,
        cache_examples=False,
    )
# --- Combine all tabs ---
# Each tab is paired with its title so the two lists cannot drift apart.
_tabs_with_titles = [
    (rmbg_tab, "Remove Background"),
    (outpaint_tab, "Outpainting"),
    (inpaint_tab, "Inpainting"),
    (erase_tab, "Erase (LAMA)"),
    (translation_tab, "Translate"),
    # (sam2_tab, "sam2"),  # disabled along with mask generation
]
demo = gr.TabbedInterface(
    [tab for tab, _ in _tabs_with_titles],
    [title for _, title in _tabs_with_titles],
    title="Image & Text Utilities (GPU)",
)
demo.launch()