|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
|
|
# Hugging Face checkpoint identifiers, one per translation direction.
# `respond` picks a tokenizer/model pair by variable name, so each name below
# must be bound to the checkpoint whose repository name matches its direction.
_ENG_TO_DARIJA_CHECKPOINT = "lachkarsalim/Helsinki-translation-English_Moroccan-Arabic"
_DARIJA_TO_MSA_CHECKPOINT = "Saidtaoussi/AraT5_Darija_to_MSA"

# English -> Moroccan Arabic (Darija).
tokenizer_eng_to_darija = AutoTokenizer.from_pretrained(_ENG_TO_DARIJA_CHECKPOINT)
model_eng_to_darija = AutoModelForSeq2SeqLM.from_pretrained(_ENG_TO_DARIJA_CHECKPOINT)

# Moroccan Arabic (Darija) -> Modern Standard Arabic (MSA).
tokenizer_darija_to_msa = AutoTokenizer.from_pretrained(_DARIJA_TO_MSA_CHECKPOINT)
model_darija_to_msa = AutoModelForSeq2SeqLM.from_pretrained(_DARIJA_TO_MSA_CHECKPOINT)
|
|
|
def _translate(text, tokenizer, model):
    """Translate ``text`` with one tokenizer/model pair and return the decoded string."""
    encoded = tokenizer(text, return_tensors="pt", padding=True)
    generated = model.generate(
        encoded["input_ids"],
        # Hand over the tokenizer's mask explicitly rather than leaving
        # generate() to work one out from the input ids.
        attention_mask=encoded["attention_mask"],
        num_beams=5,
        max_length=512,
        early_stopping=True,
    )
    # A single input string yields a single sequence; decode that one.
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    translation_choice: str,
):
    """Translate ``message`` in the direction selected by ``translation_choice``.

    The signature mirrors what the chat UI passes: the message, the chat
    history, and then one value per additional input control. Only
    ``message`` and ``translation_choice`` influence the result. Each message
    is translated on its own; earlier turns are not fed to the model.

    Args:
        message: Text typed by the user.
        history: Previous (user, assistant) turns. Unused.
        system_message: Value of the "System message" control. Unused.
        max_tokens: Value of the "Max new tokens" control. Unused; generation
            is capped at a fixed ``max_length`` of 512.
        temperature: Value of the "Temperature" control. Unused; decoding is
            beam search.
        top_p: Value of the "Top-p" control. Unused.
        translation_choice: Either "Moroccan Arabic to MSA" or
            "English to Moroccan Arabic".

    Returns:
        The translated text, or an empty string when ``message`` is blank or
        ``translation_choice`` is not one of the two supported directions.
    """
    # Nothing to translate: skip the model call instead of decoding whatever
    # the model would emit for an empty prompt.
    if not message or not message.strip():
        return ""

    if translation_choice == "Moroccan Arabic to MSA":
        return _translate(message, tokenizer_darija_to_msa, model_darija_to_msa)

    if translation_choice == "English to Moroccan Arabic":
        return _translate(message, tokenizer_eng_to_darija, model_eng_to_darija)

    # Unrecognised direction: keep the best-effort empty reply.
    return ""
|
|
|
|
|
|
|
# Extra controls rendered with the chat box. Their values are forwarded to
# `respond` after (message, history), so the list order further down must
# stay aligned with the order of `respond`'s trailing parameters.
_system_message_box = gr.Textbox(
    label="System message",
    value="You are a friendly Chatbot.",
)
_max_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=2048,
    step=1,
    value=512,
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0.1,
    maximum=4.0,
    step=0.1,
    value=0.7,
)
_top_p_slider = gr.Slider(
    label="Top-p (nucleus sampling)",
    minimum=0.1,
    maximum=1.0,
    step=0.05,
    value=0.95,
)
_direction_dropdown = gr.Dropdown(
    choices=["English to Moroccan Arabic", "Moroccan Arabic to MSA"],
    value="English to Moroccan Arabic",
    label="Choose Translation Direction",
)

# Chat-style front end whose reply function is `respond`.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
        _direction_dropdown,
    ],
)


# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|
|