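# English2Shami / app.py
# File-viewer metadata captured along with the source (not part of the program):
# guymorlan's picture · commit 84d9d62 "Update app.py" · 2.71 kB
# (viewer links: raw, history, blame)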
import base64
import html
import os
import tempfile

import azure.cognitiveservices.speech as speechsdk
import pandas as pd
import streamlit as st
import torch
from transformers import pipeline
# Display name of each supported dialect -> the one-letter code that
# get_translation() prefixes to the English text before calling the model.
dialects = {
    "Palestinian/Jordanian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
}

# NOTE: this rebinds `pipeline` from the imported transformers factory to the
# constructed translation pipeline; get_translation() calls it by this name.
pipeline = pipeline(model="guymorlan/English2Dialect", task="translation")

st.title("English to Levantine Arabic")

# Sidebar control: how many candidate translations to request per dialect
# (choices 1 through 10, first choice preselected).
num_translations = st.sidebar.selectbox(
    "Number of Translations Per Dialect:",
    list(range(1, 11)),
    index=0,
)

input_text = st.text_input("Enter English text:")
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda parameter: parameter.data.numpy()})
def get_translation(input_text, num_translations):
    """Run the translation pipeline on ``input_text`` once per dialect.

    One model input is built for every value in the module-level ``dialects``
    mapping by prefixing the text with that dialect's code and a space.

    Args:
        input_text: English text to translate.
        num_translations: Number of sequences requested per input. The beam
            width is this value, but never less than 5.

    Returns:
        The output of the module-level ``pipeline`` for the whole batch,
        returned unchanged.
    """
    beam_width = max(num_translations, 5)

    prefixed_inputs = []
    for dialect_code in dialects.values():
        prefixed_inputs.append(f"{dialect_code} {input_text}")

    return pipeline(
        prefixed_inputs,
        max_length=1024,
        num_return_sequences=num_translations,
        num_beams=beam_width,
    )
def _synthesize_speech(text):
    """Synthesize ``text`` with Azure Speech and return the audio file's bytes.

    Credentials are read from the ``SPEECH_KEY`` and ``SPEECH_REGION``
    environment variables. The audio is written to a uniquely named temporary
    ``.wav`` file instead of a path derived from the user's input, so the
    input text can never produce an invalid or colliding file name, and the
    file is removed afterwards.

    Args:
        text: The text to speak.

    Returns:
        The bytes of the synthesized audio file, or ``None`` when the
        synthesis did not complete.
    """
    speech_config = speechsdk.SpeechConfig(
        subscription=os.environ.get('SPEECH_KEY'),
        region=os.environ.get('SPEECH_REGION'),
    )
    speech_config.speech_synthesis_voice_name = 'ar-SY-AmanyNeural'

    # mkstemp only reserves a unique path; the SDK opens the file itself.
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        audio_config = speechsdk.audio.AudioOutputConfig(filename=wav_path)
        speech_synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=speech_config, audio_config=audio_config
        )
        synthesis_result = speech_synthesizer.speak_text_async(text).get()
        if synthesis_result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
            return None
        with open(wav_path, "rb") as audio_file:
            return audio_file.read()
    finally:
        try:
            os.remove(wav_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not break the page.
            pass


if input_text:
    result = get_translation(input_text, num_translations)

    for position, (dialect_name, entry) in enumerate(zip(dialects, result)):
        # Each result entry is a single dict when one translation was
        # requested and a list of dicts otherwise; normalize to a list.
        candidates = entry if isinstance(entry, list) else [entry]
        translations = [candidate["translation_text"] for candidate in candidates]

        st.markdown(f"<div style='font-size:24px'><b>{dialect_name}:</b></div>", unsafe_allow_html=True)

        # Audio is generated only for the top translation of the first dialect.
        if position == 0 and translations:
            audio_bytes = _synthesize_speech(translations[0])
            if audio_bytes is None:
                st.warning("Audio could not be generated for this translation.")
            else:
                # The synthesizer writes a .wav file, so advertise it as WAV.
                st.audio(audio_bytes, format="audio/wav", start_time=0)

        for translation in translations:
            # Model output is derived from user input; escape it before
            # embedding it in raw HTML.
            st.markdown(f"<div style='font-size:24px; text-align:right; direction:rtl;'>{html.escape(translation)}</div>", unsafe_allow_html=True)

        st.markdown("<br>", unsafe_allow_html=True)