import streamlit as st
from .services import TextGeneration
from tokenizers import Tokenizer
from functools import lru_cache


# st.cache cannot hash the Tokenizer held by TextGeneration, so the original
# decorator stays disabled; lru_cache keeps one loaded instance across reruns.
# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str})
@lru_cache(maxsize=1)
def load_text_generator():
    generator = TextGeneration()
    generator.load()
    return generator


generator = load_text_generator()
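
# `TextGeneration` lives in ./services and is not shown in this file. A minimal
# sketch of the interface this page relies on, assuming it wraps transformers
# text-generation pipelines for the aubmindlab AraGPT2 checkpoints (the class
# body below is an illustration, not the actual implementation):
#
#   from transformers import pipeline
#
#   class TextGeneration:
#       def load(self):
#           # Map the sidebar model names to Hugging Face Hub checkpoints.
#           self.pipelines = {
#               "AraGPT2-Base": pipeline("text-generation", model="aubmindlab/aragpt2-base"),
#               "AraGPT2-Mega": pipeline("text-generation", model="aubmindlab/aragpt2-mega"),
#           }
#
#       def generate(self, prompt, model_name, **generate_kwargs):
#           # generate_kwargs are forwarded to model.generate (max_new_tokens,
#           # temperature, top_k, top_p, do_sample, num_beams, ...).
#           return self.pipelines[model_name](prompt, **generate_kwargs)[0]["generated_text"]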
| qa_prompt = """ | |
| أجب عن السؤال التالي: | |
| """ | |
| qa_prompt_post = """ الجواب هو """ | |
| qa_prompt_post_year = """ في سنة: """ | |


def write():
    st.markdown(
        """
        <h1 style="text-align:left;">Arabic Language Generation</h1>
        """,
        unsafe_allow_html=True,
    )

    # Sidebar
    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
    st.sidebar.subheader("Configurable parameters")

    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "AraGPT2-Base",
            # "AraGPT2-Medium",
            # "AraGPT2-Large",
            "AraGPT2-Mega",
        ],
        index=0,
    )
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=0,
        max_value=1024,
        value=100,
        help="The maximum number of new tokens to generate.",
    )
    temp = st.sidebar.slider(
        "Temperature",
        value=1.0,
        min_value=0.1,
        max_value=100.0,
        help="The value used to modulate the next-token probabilities.",
    )
    top_k = st.sidebar.number_input(
        "Top k",
        value=10,
        help="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
    )
    top_p = st.sidebar.number_input(
        "Top p",
        value=0.95,
        help="If set to a float < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation.",
    )
    do_sample = st.sidebar.selectbox(
        "Sampling?",
        (True, False),
        help="Whether or not to use sampling; use greedy decoding otherwise.",
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        min_value=1,
        max_value=10,
        value=3,
        help="The number of beams to use for beam search.",
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition Penalty",
        min_value=0.0,
        value=3.0,
        step=0.1,
        help="The parameter for repetition penalty. 1.0 means no penalty.",
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to an int > 0, all n-grams of that size can only occur once.",
    )
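
    # How these settings interact in transformers-style generation (an
    # assumption about what TextGeneration forwards to model.generate): with
    # Sampling=True the next token is sampled from the temperature-scaled
    # distribution after top-k and top-p filtering; with Sampling=False and
    # Number of beams > 1, decoding is deterministic beam search and the
    # temperature/top-k/top-p values are ignored.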
| st.write("#") | |
| col = st.columns(2) | |
| col[0].image("images/AraGPT2.png", width=200) | |
| st.markdown( | |
| """ | |
| <h3 style="text-align:left;">AraGPT2 is GPT2 model trained from scratch on 77GB of Arabic text.</h3> | |
| <h4 style="text-align:left;"> More details in our <a href="https://github.com/aub-mind/arabert/tree/master/aragpt2">repo</a>.</h4> | |
| <p style="text-align:left;"><p> | |
| <p style="text-align:left;">Use the generation paramters on the sidebar to adjust generation quality.</p> | |
| <p style="text-align:right;"><p> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
    # col[0].write(
    #     "AraGPT2 is trained from scratch on 77GB of Arabic text. More details in our [repo](https://github.com/aub-mind/arabert/tree/master/aragpt2)."
    # )
    # st.write("## Generate Arabic Text")

    # Force right-to-left alignment on text widgets so Arabic input and output
    # render naturally.
    st.markdown(
        """
        <style>
        p, div, input, label, textarea{
            text-align: right;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Default prompt (Arabic): "It is said that a deceitful farmer sold the
    # water well on his land to his neighbor for a large sum of money."
    prompt = st.text_area(
        "Prompt",
        "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
    )
| if st.button("Generate"): | |
| with st.spinner("Generating..."): | |
| generated_text = generator.generate( | |
| prompt=prompt, | |
| model_name=model_name, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temp, | |
| top_k=top_k, | |
| top_p=top_p, | |
| repetition_penalty=repetition_penalty, | |
| do_sample=do_sample, | |
| num_beams=num_beams, | |
| no_repeat_ngram_size=no_repeat_ngram_size, | |
| ) | |
| st.write(generated_text) | |
| st.markdown("---") | |
| st.subheader("") | |
| st.markdown( | |
| """ | |
| <p style="text-align:left;"><p> | |
| <h2 style="text-align:left;">Zero-Shot Question Answering</h2> | |
| <p style="text-align:left;">Adjust the maximum length to closely match the expected output length. Setting the Sampling paramter to False is recommended</p> | |
| <p style="text-align:left;"><p> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |

    # Default question (Arabic): "Who was the leader of Nazi Germany in
    # World War II?"
    question = st.text_input(
        "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
    )
    is_date = st.checkbox("Help the model: Is the answer a date?")
    if st.button("Answer"):
        # Wrap the question in the QA template: "Answer the following
        # question: <question> The answer is", plus "in the year:" for dates.
        prompt2 = qa_prompt + question + qa_prompt_post
        if is_date:
            prompt2 += qa_prompt_post_year
        else:
            prompt2 += " : "
| with st.spinner("Thinking..."): | |
| answer = generator.generate( | |
| prompt=prompt2, | |
| model_name=model_name, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temp, | |
| top_k=top_k, | |
| top_p=top_p, | |
| repetition_penalty=repetition_penalty, | |
| do_sample=do_sample, | |
| num_beams=num_beams, | |
| no_repeat_ngram_size=no_repeat_ngram_size, | |
| ) | |
| st.write(answer) | |
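

# This page module is presumably imported by a host multipage app that calls
# write() to render it, roughly like (hypothetical host code):
#
#   from . import generation  # hypothetical module name for this file
#   generation.write()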