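# Hebrew poetry text-generation demo: a small Streamlit app wrapped around
# the Norod78/hebrew-gpt_neo-small causal language model.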
import os

import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@st.cache(allow_output_mutation=True)
def load_model(model_name):
    # Cache the model and tokenizer across Streamlit reruns so they are
    # downloaded and initialised only once.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer


def extend(input_text, max_size=20, top_k=50, top_p=0.95):
    # Relies on the globals (model, tokenizer, device, stop_token, new_lines)
    # set up in the __main__ block before this function is called.
    if len(input_text) == 0:
        input_text = "שם היצירה: "  # "Title of the piece: "

    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt")

    encoded_prompt = encoded_prompt.to(device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + len(encoded_prompt[0]),
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1)

    # Remove the batch dimension when returning multiple sequences.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()

        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Truncate at the stop token. Guarding with `in` avoids the
        # str.find() == -1 case, which would silently drop the last character.
        if stop_token and stop_token in text:
            text = text[: text.find(stop_token)]

        # Likewise, truncate at a run of blank lines only when one is present.
        if new_lines and new_lines in text:
            text = text[: text.find(new_lines)]

        # Re-attach the original prompt after stripping its decoded form from
        # the model output.
        total_sequence = (
            input_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)

    # Only one sequence is requested, so the last total_sequence is the result.
    parsed_text = total_sequence.replace("<|startoftext|>", "").replace("\r", "").replace("\n\n", "\n")
    if len(parsed_text) == 0:
        parsed_text = "שגיאה"  # "Error"
    return parsed_text


if __name__ == "__main__":
    st.title("Hebrew Poetry - GPT Neo (Small)")

    model, tokenizer = load_model("Norod78/hebrew-gpt_neo-small")

    stop_token = "<|endoftext|>"
    new_lines = "\n\n\n"

    np.random.seed(None)
    # torch.manual_seed expects a plain int, so draw a scalar seed rather
    # than a size-1 array.
    random_seed = np.random.randint(10000)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    model.to(device)

    st.sidebar.subheader("Configurable parameters")

    max_len = st.sidebar.slider("Max-Length", 0, 512, 256, help="The maximum length of the sequence to be generated.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 50, help="The number of highest-probability vocabulary tokens to keep for top-k filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.95, help="If set to a float < 1, only the most probable tokens whose probabilities add up to top_p or higher are kept for generation.")

    st.markdown(
        """Hebrew poetry text generation model based on EleutherAI's gpt-neo. It was trained on a TPUv3-8, which was made available to me via the [TPU Research Cloud Program](https://sites.research.google/trc/)."""
    )

    # "The last man on Earth sat alone in his room, when suddenly there was a knock."
    prompt = "האיש האחרון בעולם ישב לבד בחדרו כשלפתע נשמעה נקישה"
    text = st.text_area("Enter text", prompt)

    if st.button("Run"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
            result = extend(input_text=text,
                            max_size=int(max_len),
                            top_k=int(top_k),
                            top_p=float(top_p))

            print("Done length: " + str(len(result)) + " bytes")

            # Render right-to-left so the generated Hebrew displays correctly.
            st.markdown(f"<p dir=\"rtl\" style=\"text-align:right;\"> {result} </p>", unsafe_allow_html=True)
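# A minimal way to launch this app locally (assuming the file is saved as
# app.py and streamlit, transformers and torch are installed):
#   streamlit run app.py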