"""Gradio Space combining FLUX.1-Fill image tasks (outpaint/inpaint), BiRefNet
background removal, LaMa object erasing, and mBART-50 text translation.

NOTE(review): the original source had been collapsed onto a handful of physical
lines; it is restored here to conventional PEP 8 formatting with behavior
preserved.
"""

import gc
from contextlib import contextmanager

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from loadimg import load_img  # Assuming loadimg.py exists with load_img function
from PIL import Image, ImageOps
from simple_lama_inpainting import SimpleLama
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    pipeline,
)

# --- Utility Functions ---


@contextmanager
def float32_high_matmul_precision():
    """Temporarily set float32 matmul precision to "high" (faster GPU matmuls).

    Restores the default "highest" precision on exit, even if the body raises.
    """
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision("highest")


# --- Model Loading ---
# Use context manager for precision during model loading if needed.
with float32_high_matmul_precision():
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    ).to("cuda")
    simple_lama = SimpleLama()  # Initialize Lama globally if used often

# --- Translation Model and Tokenizer Loading ---
translation_model_name = "facebook/mbart-large-50-many-to-many-mmt"
try:
    translation_model = MBartForConditionalGeneration.from_pretrained(
        translation_model_name
    ).to("cuda")  # Move to GPU
    translation_tokenizer = MBart50TokenizerFast.from_pretrained(translation_model_name)
except Exception as e:
    print(f"Error loading translation model/tokenizer: {e}")
    # Consider exiting or disabling the translation tab if loading fails.
    translation_model = None
    translation_tokenizer = None

# --- Image Processing Functions ---

# BiRefNet preprocessing: fixed 1024x1024 input with ImageNet normalization.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad *image* with white borders and build the matching outpaint mask.

    Returns:
        (background, mask): the padded RGB image, and an RGB mask that is
        black over the original image area and white over the new padding
        (white marks the region the fill model should generate).
    """
    image = load_img(image).convert("RGB")
    background = ImageOps.expand(
        image,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(
        mask,
        border=(padding_left, padding_top, padding_right, padding_bottom),
        fill="white",
    )
    return background, mask


def outpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Extend an image outward by the given paddings using FLUX.1-Fill.

    Returns the generated image converted to RGBA.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    # Apply precision context if needed for inference.
    with float32_high_matmul_precision():
        result = pipe(
            prompt=prompt,
            height=background.height,
            width=background.width,
            image=background,
            mask_image=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
    return result.convert("RGBA")


def inpaint(
    image,
    mask,
    prompt="",
    num_inference_steps=28,
    guidance_scale=50,
):
    """Fill the white areas of *mask* within *image* using FLUX.1-Fill.

    Returns the generated image converted to RGBA.
    """
    background = image.convert("RGB")
    mask = mask.convert("L")
    # Apply precision context if needed for inference.
    with float32_high_matmul_precision():
        result = pipe(
            prompt=prompt,
            height=background.height,
            width=background.width,
            image=background,
            mask_image=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
    return result.convert("RGBA")


def rmbg(image=None, url=None):
    """Remove the background of an image (or image URL) with BiRefNet.

    Returns an RGBA PIL image whose alpha channel is the predicted foreground
    mask, or an error string on invalid input.
    """
    if image is None and url:
        # Basic check for URL format, improve as needed.
        if not url.startswith(("http://", "https://")):
            return "Invalid URL provided."
        image = url  # load_img should handle URLs if configured correctly
    elif image is None:
        return "Please provide an image or a URL."
    try:
        image_pil = load_img(image).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"
    image_size = image_pil.size
    input_images = transform_image(image_pil).unsqueeze(0).to("cuda")
    with float32_high_matmul_precision():
        with torch.no_grad():
            # BiRefNet returns multiple outputs; the last one is the final mask.
            preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Resize the 1024x1024 prediction back to the original image size.
    mask = pred_pil.resize(image_size)
    image_pil.putalpha(mask)
    # Clean up GPU memory if needed.
    del input_images, preds, pred
    torch.cuda.empty_cache()
    gc.collect()
    return image_pil


def erase(image=None, mask=None):
    """Erase the white-masked region of *image* using the LaMa model.

    Returns the inpainted image, or an error string on failure.
    """
    if image is None or mask is None:
        return "Please provide both an image and a mask."
    try:
        image_pil = load_img(image)
        mask_pil = load_img(mask).convert("L")
        result = simple_lama(image_pil, mask_pil)
        # Clean up.
        gc.collect()
        return result
    except Exception as e:
        return f"Error during erase operation: {e}"


# --- Translation Functionality ---

# Language Mapping: display name -> mBART-50 language code.
lang_data = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
    "Afrikaans": "af_ZA",
    "Azerbaijani": "az_AZ",
    "Bengali": "bn_IN",
    "Persian": "fa_IR",
    "Hebrew": "he_IL",
    "Croatian": "hr_HR",
    "Indonesian": "id_ID",
    "Georgian": "ka_GE",
    "Khmer": "km_KH",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Mongolian": "mn_MN",
    "Marathi": "mr_IN",
    "Polish": "pl_PL",
    "Pashto": "ps_AF",
    "Portuguese": "pt_XX",
    "Swedish": "sv_SE",
    "Swahili": "sw_KE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Thai": "th_TH",
    "Tagalog": "tl_XX",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Xhosa": "xh_ZA",
    "Galician": "gl_ES",
    "Slovene": "sl_SI",
}
language_names = sorted(lang_data)


def translate_text(text_to_translate, source_language_name, target_language_name):
    """
    Translates text using the loaded mBART model.

    Returns the translated string, or a human-readable error string when the
    model is unavailable, inputs are missing, or translation fails.
    """
    if translation_model is None or translation_tokenizer is None:
        return "Translation model not loaded. Cannot perform translation."
    if not text_to_translate:
        return "Please enter text to translate."
    if not source_language_name:
        return "Please select a source language."
    if not target_language_name:
        return "Please select a target language."
    try:
        source_lang_code = lang_data[source_language_name]
        target_lang_code = lang_data[target_language_name]
        translation_tokenizer.src_lang = source_lang_code
        encoded_text = translation_tokenizer(text_to_translate, return_tensors="pt").to(
            "cuda"
        )  # Move input to GPU
        # NOTE(review): lang_code_to_id is deprecated in newer transformers
        # releases — confirm against the pinned version.
        target_lang_id = translation_tokenizer.lang_code_to_id[target_lang_code]
        # Generate translation on GPU.
        with torch.no_grad():  # Use no_grad for inference
            generated_tokens = translation_model.generate(
                **encoded_text, forced_bos_token_id=target_lang_id, max_length=200
            )
        translated_text = translation_tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        # Clean up GPU memory.
        del encoded_text, generated_tokens
        torch.cuda.empty_cache()
        gc.collect()
        return translated_text[0]
    except KeyError as e:
        return f"Error: Language code not found for {e}. Check language mappings."
    except Exception as e:
        print(f"Translation error: {e}")
        # Clean up GPU memory on error too.
        torch.cuda.empty_cache()
        gc.collect()
        return f"An error occurred during translation: {e}"


# --- Main Function Router (for image tasks) ---
# Note: Translation uses its own function directly.


@spaces.GPU(duration=120)  # Keep GPU decorator if needed for image tasks
def main(*args):
    """Route a Gradio call to an image task by its leading API number.

    args[0] selects the task (1=rmbg, 2=outpaint, 3=inpaint, 5=erase); the
    remaining args are forwarded to that task. Returns the task result, or an
    error string on failure.
    """
    api_num = args[0]
    args = args[1:]
    gc.collect()  # Try to collect garbage before starting task
    torch.cuda.empty_cache()  # Clear cache before starting task
    result = None
    try:
        if api_num == 1:
            result = rmbg(*args)
        elif api_num == 2:
            result = outpaint(*args)
        elif api_num == 3:
            result = inpaint(*args)
        # elif api_num == 4:  # Keep commented out as in original
        #     return mask_generation(*args)
        elif api_num == 5:
            result = erase(*args)
        else:
            result = "Invalid API number."
    except Exception as e:
        print(f"Error in main task routing (api_num={api_num}): {e}")
        result = f"An error occurred: {e}"
    finally:
        # Ensure memory cleanup happens even if there's an error.
        gc.collect()
        torch.cuda.empty_cache()
    return result


# --- Define Gradio Interfaces for Each Tab ---

# Image Task Tabs
rmbg_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(1, interactive=False, visible=False),  # Hide API number
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        gr.Text(label="Or Image URL (optional)"),
    ],
    outputs=gr.Image(label="Output Image", type="pil"),
    title="Remove Background",
    description="Upload an image or provide a URL to remove its background.",
    api_name="rmbg",
    # examples=[[1, "./assets/sample_rmbg.png", ""]],  # Update example path if needed
    cache_examples=False,
)

outpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(2, interactive=False, visible=False),
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        gr.Number(value=0, label="Padding Top (pixels)"),
        gr.Number(value=0, label="Padding Bottom (pixels)"),
        gr.Number(value=0, label="Padding Left (pixels)"),
        gr.Number(value=0, label="Padding Right (pixels)"),
        gr.Text(
            label="Prompt (optional)",
            info="Describe what to fill the extended area with",
        ),
        gr.Slider(
            minimum=10, maximum=100, step=1, value=28, label="Inference Steps"
        ),  # Use slider for steps
        gr.Slider(
            minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"
        ),  # Use slider for guidance
    ],
    outputs=gr.Image(label="Outpainted Image", type="pil"),
    title="Outpainting",
    description="Extend an image by adding padding and filling the new area using a diffusion model.",
    api_name="outpainting",
    # examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 28, 50]],  # Update example path
    cache_examples=False,
)

inpaint_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(3, interactive=False, visible=False),
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        gr.Image(
            label="Mask Image (White=Inpaint Area)",
            type="pil",
            sources=["upload", "clipboard"],
        ),
        gr.Text(
            label="Prompt (optional)", info="Describe what to fill the masked area with"
        ),
        gr.Slider(minimum=10, maximum=100, step=1, value=28, label="Inference Steps"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"),
    ],
    outputs=gr.Image(label="Inpainted Image", type="pil"),
    title="Inpainting",
    description="Fill in the white areas of a mask applied to an image using a diffusion model.",
    api_name="inpaint",
    # examples=[[3, "./assets/rocket.png", "./assets/Inpainting_mask.png", "", 28, 50]],  # Update example paths
    cache_examples=False,
)

erase_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(5, interactive=False, visible=False),
        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
        gr.Image(
            label="Mask Image (White=Erase Area)",
            type="pil",
            sources=["upload", "clipboard"],
        ),
    ],
    outputs=gr.Image(label="Result Image", type="pil"),
    title="Erase Object (LAMA)",
    description="Erase objects from an image based on a mask using the LaMa inpainting model.",
    api_name="erase",
    # examples=[[5, "./assets/rocket.png", "./assets/Inpainting_mask.png"]],  # Update example paths
    cache_examples=False,
)

# --- Define Translation Tab using gr.Blocks ---
with gr.Blocks() as translation_tab:
    gr.Markdown(
        """
    ## Multilingual Translation (mBART-50)
    Translate text between 50 different languages. Select the source and target languages, enter your text, and click Translate.
    """
    )
    with gr.Row():
        with gr.Column(scale=1):
            source_lang_dropdown = gr.Dropdown(
                label="Source Language",
                choices=language_names,
                info="Select the language of your input text.",
            )
            target_lang_dropdown = gr.Dropdown(
                label="Target Language",
                choices=language_names,
                info="Select the language you want to translate to.",
            )
        with gr.Column(scale=2):
            input_textbox = gr.Textbox(
                label="Text to Translate",
                lines=6,  # Increased lines
                placeholder="Enter text here...",
            )
            translate_button = gr.Button("Translate", variant="primary")  # Added variant
            output_textbox = gr.Textbox(
                label="Translated Text",
                lines=6,  # Increased lines
                interactive=False,  # Make output read-only
            )

    # Connect Components to the translation function directly.
    translate_button.click(
        fn=translate_text,
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        api_name="translate",  # Add API name for the translation endpoint
    )

    # Add Translation Examples.
    gr.Examples(
        examples=[
            [
                "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है",
                "Hindi",
                "French",
            ],
            [
                "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.",
                "Arabic",
                "English",
            ],
            [
                "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
                "French",
                "German",
            ],
            ["Hello world! How are you today?", "English", "Spanish"],
            ["Guten Tag!", "German", "Japanese"],
            ["これはテストです", "Japanese", "English"],
        ],
        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
        outputs=output_textbox,
        fn=translate_text,
        cache_examples=False,
    )

# --- Combine all tabs ---
demo = gr.TabbedInterface(
    [
        rmbg_tab,
        outpaint_tab,
        inpaint_tab,
        erase_tab,
        translation_tab,  # Add the translation tab
        # sam2_tab,  # Keep commented out
    ],
    [
        "Remove Background",  # Tab title
        "Outpainting",  # Tab title
        "Inpainting",  # Tab title
        "Erase (LAMA)",  # Tab title
        "Translate",  # Tab title for translation
        # "sam2",
    ],
    title="Image & Text Utilities (GPU)",  # Updated title
)

demo.launch()