# -*- coding: utf-8 -*-
import os
import streamlit as st
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(None)
# Fallback prompts used when the text box is left empty (e.g. "Once upon a
# time, many years ago", "Hello, my name is Doron and I")
suggested_text_list = ['פעם אחת, לפני שנים רבות', 'שלום, קוראים לי דורון ואני', 'בוקר טוב לכולם', 'ואז הפרתי את כל כללי הטקס כש']
@st.cache_resource
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
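
# Note: @st.cache_resource keeps a single model/tokenizer pair per server
# process, so Streamlit reruns reuse it instead of reloading the weights on
# every user interaction.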
def extend(input_text, max_size=20, top_k=50, top_p=0.95):
    if len(input_text) == 0:
        input_text = ""

    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + len(encoded_prompt[0]),
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=2.5,
        num_return_sequences=1)

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token (str.find returns -1 when the
        # token is absent, which would otherwise drop the last character)
        stop_index = text.find(stop_token) if stop_token else -1
        if stop_index >= 0:
            text = text[:stop_index]

        # Remove all text after a run of 3 newlines
        newline_index = text.find(new_lines) if new_lines else -1
        if newline_index >= 0:
            text = text[:newline_index]

        # Add the prompt at the beginning of the sequence. Remove the excess
        # text that was used for pre-processing
        total_sequence = (
            input_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )
        generated_sequences.append(total_sequence)

    parsed_text = total_sequence.replace("<|startoftext|>", "").replace("\r", "").replace("\n\n", "\n")
    if len(parsed_text) == 0:
        parsed_text = "שגיאה"  # Hebrew for "error"
    return parsed_text
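
# Illustrative call (hypothetical prompt; note that extend() reads model,
# tokenizer, device, stop_token and new_lines from module scope, which the
# __main__ block below sets up before the Streamlit UI invokes it):
#   result = extend("שלום", max_size=96, top_k=40, top_p=0.92)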
if __name__ == "__main__":
    st.title("Hebrew GPT Neo (Small)")

    pre_model_path = "Norod78/hebrew-gpt_neo-small"
    model, tokenizer = load_model(pre_model_path)

    stop_token = "<|endoftext|>"
    new_lines = "\n\n\n"

    np.random.seed(None)
    random_seed = int(np.random.randint(10000))  # plain int: torch.manual_seed expects a scalar seed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)

    model.to(device)
    text_area = st.text_area("Enter the first few words (or leave blank), tap on \"Generate Text\" below. Tapping again will produce a different result.", 'האיש האחרון בעולם ישב לבד בחדרו כשלפתע נשמעה נקישה')
    st.sidebar.subheader("Configurable parameters")

    max_len = st.sidebar.slider("Max-Length", 0, 192, 96, help="The maximum length of the sequence to be generated.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
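
    # Rough guide to the sampling knobs in transformers' generate(): top_k
    # keeps only the k most likely next tokens and top_p keeps the smallest
    # set of tokens whose probabilities sum to at least top_p; top_k=0 and
    # top_p=1.0 effectively disable the respective filter.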
    if st.button("Generate Text"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
            if len(text_area.strip()) == 0:
                text_area = random.choice(suggested_text_list)
            result = extend(input_text=text_area,
                            max_size=int(max_len),
                            top_k=int(top_k),
                            top_p=float(top_p))
            print("Done, length: " + str(len(result)) + " characters")
            # Render right-to-left so the Hebrew output reads correctly
            st.markdown(f"<p dir=\"rtl\" style=\"text-align:right;\"> {result} </p>", unsafe_allow_html=True)
            st.write("\n\nResult length: " + str(len(result)) + " characters")
            print(f"\"{result}\"")
    st.markdown(
        """Hebrew text generation model (125M parameters) based on EleutherAI's gpt-neo architecture. Originally trained on a TPUv3-8 which was made available to me via the [TPU Research Cloud Program](https://sites.research.google/trc/)."""
    )
    st.markdown("<footer><hr><p style=\"font-size:14px\">Enjoy</p><p style=\"font-size:12px\">Created by <a href=\"https://linktr.ee/Norod78\">Doron Adler</a></p></footer>", unsafe_allow_html=True)