# gpu-utils / app.py
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import (
AutoModelForImageSegmentation,
pipeline,
MBartForConditionalGeneration,
MBart50TokenizerFast,
)
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
# from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
from simple_lama_inpainting import SimpleLama
from contextlib import contextmanager
# import whisperx
import gc
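# Context manager that temporarily lowers float32 matmul precision to "high"
# (allowing TF32 on supported GPUs) and restores "highest" afterwards.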
@contextmanager
def float32_high_matmul_precision():
torch.set_float32_matmul_precision("high")
try:
yield
finally:
torch.set_float32_matmul_precision("highest")
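# FLUX.1 Fill pipeline (bfloat16, on GPU), shared by the inpainting and outpainting tabs.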
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
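# BiRefNet preprocessing: resize to 1024x1024, convert to tensor, normalize with ImageNet stats.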
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
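# Pad the image with a white border and build a matching mask
# (black = keep the original pixels, white = area for the model to fill).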
def prepare_image_and_mask(
image,
padding_top=0,
padding_bottom=0,
padding_left=0,
padding_right=0,
):
image = load_img(image).convert("RGB")
# expand image (left,top,right,bottom)
background = ImageOps.expand(
image,
border=(padding_left, padding_top, padding_right, padding_bottom),
fill="white",
)
mask = Image.new("RGB", image.size, "black")
mask = ImageOps.expand(
mask,
border=(padding_left, padding_top, padding_right, padding_bottom),
fill="white",
)
return background, mask
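# Outpainting: extend the image outward by the requested padding using FLUX Fill.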
def outpaint(
image,
padding_top=0,
padding_bottom=0,
padding_left=0,
padding_right=0,
prompt="",
num_inference_steps=28,
guidance_scale=50,
):
background, mask = prepare_image_and_mask(
image, padding_top, padding_bottom, padding_left, padding_right
)
result = pipe(
prompt=prompt,
height=background.height,
width=background.width,
image=background,
mask_image=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
result = result.convert("RGBA")
return result
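# Inpainting: fill the white region of the mask with FLUX Fill, optionally guided by a prompt.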
def inpaint(
image,
mask,
prompt="",
num_inference_steps=28,
guidance_scale=50,
):
background = image.convert("RGB")
mask = mask.convert("L")
result = pipe(
prompt=prompt,
height=background.height,
width=background.width,
image=background,
mask_image=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
result = result.convert("RGBA")
return result
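# Background removal: predict a foreground mask with BiRefNet and attach it as the alpha channel.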
def rmbg(image=None, url=None):
if image is None:
image = url
image = load_img(image).convert("RGB")
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to("cuda")
with float32_high_matmul_precision():
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
# def mask_generation(image=None, d=None):
# # use bfloat16 for the entire notebook
# # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
# # if torch.cuda.get_device_properties(0).major >= 8:
# # torch.backends.cuda.matmul.allow_tf32 = True
# # torch.backends.cudnn.allow_tf32 = True
# d = eval(d) # convert this to dictionary
# with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
# predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
# predictor.set_image(image)
# input_point = np.array(d["input_points"])
# input_label = np.array(d["input_labels"])
# masks, scores, logits = predictor.predict(
# point_coords=input_point,
# point_labels=input_label,
# multimask_output=True,
# )
# sorted_ind = np.argsort(scores)[::-1]
# masks = masks[sorted_ind]
# scores = scores[sorted_ind]
# logits = logits[sorted_ind]
# out = []
# for i in range(len(masks)):
# m = Image.fromarray(masks[i] * 255).convert("L")
# comp = Image.composite(image, m, m)
# out.append((comp, f"image {i}"))
# return out
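# Object removal: erase the masked region with LaMa (simple-lama-inpainting).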
def erase(image=None, mask=None):
simple_lama = SimpleLama()
image = load_img(image)
mask = load_img(mask).convert("L")
return simple_lama(image, mask)
# def transcribe(audio):
# if audio is None:
# raise gr.Error("No audio file submitted!")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# compute_type = "float16"
# batch_size = 8 # reduced batch size to be conservative with memory
# try:
# # 1. Load model and transcribe
# model = whisperx.load_model("large-v2", device, compute_type=compute_type)
# audio_input = whisperx.load_audio(audio)
# result = model.transcribe(audio_input, batch_size=batch_size)
# # Clear GPU memory
# del model
# gc.collect()
# torch.cuda.empty_cache()
# # 2. Align whisper output
# model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
# result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
# # Clear GPU memory
# del model_a
# gc.collect()
# torch.cuda.empty_cache()
# # 3. Assign speaker labels
# diarize_model = whisperx.DiarizationPipeline(device=device)
# diarize_segments = diarize_model(audio_input)
# # Combine transcription with speaker diarization
# result = whisperx.assign_word_speakers(diarize_segments, result)
# # Format output with speaker labels and timestamps
# formatted_text = []
# for segment in result["segments"]:
# if not isinstance(segment, dict):
# continue
# speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
# start_time = f"{float(segment.get('start', 0)):.2f}"
# end_time = f"{float(segment.get('end', 0)):.2f}"
# text = segment.get('text', '').strip()
# formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")
# return "\n".join(formatted_text)
# except Exception as e:
# raise gr.Error(f"Transcription failed: {str(e)}")
# finally:
# # Ensure GPU memory is cleared even if an error occurs
# gc.collect()
# torch.cuda.empty_cache()
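# Translation with mBART-50 many-to-many; the model is loaded per call and freed afterwards.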
def translate_text(text, source_lang, target_lang):
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
# Set source language
tokenizer.src_lang = source_lang
# Encode the input text
encoded_text = tokenizer(text, return_tensors="pt")
# Generate translation
generated_tokens = model.generate(
**encoded_text,
forced_bos_token_id=tokenizer.lang_code_to_id[target_lang]
)
# Decode the generated tokens
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Free the translation model and release any cached GPU memory
del model
gc.collect()
torch.cuda.empty_cache()
return translation
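# Single GPU entry point (ZeroGPU allocates a GPU for up to 120 s per call);
# the first argument selects which utility to run.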
@spaces.GPU(duration=120)
def main(*args):
api_num = args[0]
args = args[1:]
if api_num == 1:
return rmbg(*args)
elif api_num == 2:
return outpaint(*args)
elif api_num == 3:
return inpaint(*args)
# elif api_num == 4:
# return mask_generation(*args)
elif api_num == 5:
return erase(*args)
# elif api_num == 6:
# return transcribe(*args)
elif api_num == 7:
return translate_text(*args)
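# Each tab below wraps main() and passes its api number as a fixed, non-interactive first input.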
rmbg_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(1, interactive=False),
"image",
gr.Text("", label="url"),
],
outputs=["image"],
api_name="rmbg",
examples=[[1, "./assets/Inpainting mask.png", ""]],
cache_examples=False,
description="pass an image or a url of an image",
)
outpaint_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(2, interactive=False),
gr.Image(label="image", type="pil"),
gr.Number(label="padding top"),
gr.Number(label="padding bottom"),
gr.Number(label="padding left"),
gr.Number(label="padding right"),
gr.Text(label="prompt"),
gr.Number(value=50, label="num_inference_steps"),
gr.Number(value=28, label="guidance_scale"),
],
outputs=["image"],
api_name="outpainting",
examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
cache_examples=False,
)
inpaint_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(3, interactive=False),
gr.Image(label="image", type="pil"),
gr.Image(label="mask", type="pil"),
gr.Text(label="prompt"),
gr.Number(value=50, label="num_inference_steps"),
gr.Number(value=28, label="guidance_scale"),
],
outputs=["image"],
api_name="inpaint",
examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png", "", 50, 28]],
cache_examples=False,
description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
)
# sam2_tab = gr.Interface(
# main,
# inputs=[
# gr.Number(4, interactive=False),
# gr.Image(type="pil"),
# gr.Text(),
# ],
# outputs=gr.Gallery(),
# examples=[
# [
# 4,
# "./assets/truck.jpg",
# '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
# ]
# ],
# api_name="sam2",
# cache_examples=False,
# )
erase_tab = gr.Interface(
main,
inputs=[
gr.Number(5, interactive=False),
gr.Image(type="pil"),
gr.Image(type="pil"),
],
outputs=gr.Image(),
examples=[
[
5,
"./assets/rocket.png",
"./assets/Inpainting mask.png",
]
],
api_name="erase",
cache_examples=False,
)
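# NOTE: transcribe() and its api number 6 branch in main() are commented out above,
# so this tab currently returns no output.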
transcribe_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(value=6, interactive=False), # API number
gr.Audio(type="filepath", label="Audio File"),
],
outputs=gr.Textbox(label="Transcription"),
title="Audio Transcription",
description="Upload an audio file to extract text using WhisperX with speaker diarization",
api_name="transcribe",
examples=[],
)
translate_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(value=7, interactive=False),
gr.Textbox(label="Text to translate"),
gr.Dropdown(
choices=[
"ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
"gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
"my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
"zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
"ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
"pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
"ur_PK", "xh_ZA", "gl_ES", "sl_SI"
],
label="Source Language",
value="en_XX"
),
gr.Dropdown(
choices=[
"ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
"gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
"my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
"zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
"ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
"pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
"ur_PK", "xh_ZA", "gl_ES", "sl_SI"
],
label="Target Language",
value="fr_XX"
)
],
outputs=gr.Textbox(label="Translated Text"),
title="Text Translation",
description="Translate text between multiple languages using mBART-50",
api_name="translate",
examples=[
[7, "Hello, how are you?", "en_XX", "fr_XX"],
[7, "Bonjour, comment allez-vous?", "fr_XX", "en_XX"]
]
)
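# Assemble all tabs into a single app.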
demo = gr.TabbedInterface(
[
rmbg_tab,
outpaint_tab,
inpaint_tab,
erase_tab,
transcribe_tab,
translate_tab
],
[
"remove background",
"outpainting",
"inpainting",
"erase",
"transcribe",
"translate"
],
title="Utilities that require GPU",
)
demo.launch()