|
import gradio as gr |
|
import os |
|
import torch |
|
import gradio as gr |
|
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration |
|
|
|
# Run on the first CUDA GPU when one is present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the M2M100 1.2B many-to-many translation model and its tokenizer,
# move the model onto the chosen device, and put it in inference mode.
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
model.eval()
|
|
|
l1 = "Afrikaans"  # NOTE(review): appears unused anywhere in this file — confirm before removing


class Language:
    """A display name / M2M100 language-code pair for one supported language."""

    def __init__(self, name, code):
        self.name = name  # human-readable name, e.g. "Afrikaans"
        self.code = code  # M2M100 / ISO-639 style code, e.g. "af"

    def __repr__(self):
        # Debug-friendly representation, e.g. Language('Afrikaans', 'af').
        return f"{type(self).__name__}({self.name!r}, {self.code!r})"
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mapping of display name -> M2M100 language code.
#
# BUG FIX: the original was a *list* literal that mixed dict-style pairs,
# Language(...) objects, and empty Language("", "") placeholders — a hard
# SyntaxError — while every use site (lang_id["English"], lang_id[trg],
# list(lang_id.keys())) requires a plain dict. Rewritten as that dict;
# the empty placeholders are dropped so no blank dropdown choice appears.
lang_id = {
    "Afrikaans": "af",
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Chinese": "zh",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "English": "en",
    "Estonian": "et",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Gaelic": "gd",
    "Galician": "gl",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Irish": "ga",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Nepali": "ne",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Persian": "fa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Spanish": "es",
    "Somali": "so",
    "Serbian": "sr",
    "Serbian (cyrillic)": "sr",
    "Serbian (latin)": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
|
|
|
def trans_page(input, input1, trg):
    """Translate the page header into the selected language.

    Gradio callback for the "Translate Space" button.

    Args:
        input: Current markdown/HTML text of the page header (``md``).
        input1: Currently selected "From" language display name.
            NOTE(review): unused — the original's second branch tokenized the
            ``lang_id`` mapping itself, which is not valid tokenizer input,
            then called ``.keys()`` on the decoded *string*; its intent for
            this argument could not be recovered.
        trg: Display name of the language to translate the page into.

    Returns:
        A 2-tuple of (translated header text, dropdown update for the
        "From" selector). Dropdown choices stay the canonical language
        names so later lookups into ``lang_id`` still succeed.
    """
    src_lang = lang_id["English"]  # the page chrome is authored in English
    trg_lang = lang_id[trg]
    if trg_lang == src_lang:
        # Already in the source language — nothing to translate.
        translated_text = input
    else:
        tokenizer.src_lang = src_lang
        with torch.no_grad():  # inference only; no gradients needed
            encoded_input = tokenizer(input, return_tensors="pt").to(device)
            # forced_bos_token_id steers M2M100 to emit the target language.
            generated_tokens = model.generate(
                **encoded_input,
                forced_bos_token_id=tokenizer.get_lang_id(trg_lang),
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]
    return translated_text, gr.Dropdown.update(choices=list(lang_id.keys()))
|
|
|
|
|
def trans_to(input, src, trg):
    """Translate *input* from language *src* to *trg* (both display names).

    Returns the input unchanged when source and target resolve to the
    same M2M100 language code.
    """
    source_code = lang_id[src]
    target_code = lang_id[trg]
    # Guard clause: identical languages need no model call.
    if target_code == source_code:
        return input
    tokenizer.src_lang = source_code
    with torch.no_grad():  # pure inference — skip gradient bookkeeping
        batch = tokenizer(input, return_tensors="pt").to(device)
        # forced_bos_token_id makes M2M100 decode into the target language.
        output_ids = model.generate(
            **batch, forced_bos_token_id=tokenizer.get_lang_id(target_code)
        )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
|
|
# Page title text. NOTE(review): md1 appears unused below — the Markdown
# component defines its own header text — confirm before removing.
md1 = "Translate - 100 Languages"


# Gradio UI: a "translate the page itself" control on top, and a free-text
# translation form (From/To dropdowns + prompt/result boxes) underneath.
# Bare gr.Column() calls act as empty spacer columns to center the content.
with gr.Blocks() as transbot:

    # Top row: pick a language for the page chrome and re-render it.
    with gr.Row():
        gr.Column()  # spacer
        with gr.Column():
            with gr.Row():
                t_space = gr.Dropdown(label="Translate Space", choices=list(lang_id.keys()),value="English")
                t_submit = gr.Button("Translate Space")
        gr.Column()  # spacer

    # Main row: header, language selectors, and the translation form.
    with gr.Row():
        gr.Column()  # spacer
        with gr.Column():
            md = gr.Markdown("""<h1><center>Translate - 100 Languages</center></h1><h4><center>Translation may not be accurate</center></h4>""")
            with gr.Row():
                lang_from = gr.Dropdown(label="From:", choices=list(lang_id.keys()),value="English")
                lang_to = gr.Dropdown(label="To:", choices=list(lang_id.keys()),value="Chinese")
            submit = gr.Button("Go")
            with gr.Row():
                with gr.Column():
                    message = gr.Textbox(label="Prompt",placeholder="Enter Prompt",lines=4)
                    translated = gr.Textbox(label="Translated",lines=4,interactive=False)
        gr.Column()  # spacer
    # Re-translate the page header and update the "From" dropdown.
    # NOTE(review): md (a Markdown component) is passed as an *input* —
    # gradio will hand trans_page its current text; confirm this is intended.
    t_submit.click(trans_page,[md,lang_from,t_space],[md,lang_from])

    # Translate the prompt from lang_from to lang_to into the result box.
    submit.click(trans_to, inputs=[message,lang_from,lang_to], outputs=[translated])
# queue(concurrency_count=...) is gradio 3.x API (removed in gradio 4.x).
transbot.queue(concurrency_count=20)
transbot.launch()