|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
|
|
|
|
@st.cache_resource
def load_re_punctuate_model():
    """Load the tokenizer and seq2seq model used for re-punctuation.

    Both objects come from the "SJ-Ray/Re-Punctuate" checkpoint. The
    function is decorated with ``st.cache_resource`` so Streamlit can reuse
    the loaded objects instead of loading them again on every call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "SJ-Ray/Re-Punctuate"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
|
|
|
|
|
@st.cache_resource
def load_headline_model():
    """Load the tokenizer and seq2seq model used for headline generation.

    Both objects come from the "Michau/t5-base-en-generate-headline"
    checkpoint. The function is decorated with ``st.cache_resource`` so
    Streamlit can reuse the loaded objects instead of loading them again on
    every call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "Michau/t5-base-en-generate-headline"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
|
|
|
|
|
def re_punctuate_text(tokenizer, model, text, max_length=512):
    """Restore punctuation in *text* with a seq2seq model.

    Args:
        tokenizer: Callable that encodes ``text`` into a mapping of model
            inputs and provides ``decode`` for the generated ids.
        model: Object exposing ``generate``.
        text: The text to re-punctuate.
        max_length: ``max_length`` passed to ``model.generate``. Defaults to
            512, the value that used to be hard-coded.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    # The input is truncated to at most 512 tokens.
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

    # Forward the whole encoding (including the attention mask) instead of
    # only ``input_ids``, matching generate_headline_text. Gradients are not
    # needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def generate_headline_text(tokenizer, model, text, max_length=50):
    """Generate a headline for *text* with a seq2seq model.

    The text is prefixed with ``"headline: "`` before tokenization, and
    generation uses beam search with 5 beams.

    Args:
        tokenizer: Callable that encodes the prompt into a mapping of model
            inputs and provides ``decode`` for the generated ids.
        model: Object exposing ``generate``.
        text: The text to summarize into a headline.
        max_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    prompt = f"headline: {text}"
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)

    generation_options = {
        "max_length": max_length,
        "num_beams": 5,
        "no_repeat_ngram_size": 2,
        "early_stopping": True,
    }

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        candidates = model.generate(**encoded, **generation_options)

    best = candidates[0]
    return tokenizer.decode(best, skip_special_tokens=True)
|
|
|
|
|
# --- Streamlit page -------------------------------------------------------
st.title("Model Selection: Re-Punctuate or Generate Headline")

# Let the user pick which of the two models to run.
model_options = ["Re-Punctuate Text", "Generate Headline"]
selected_model = st.selectbox("Choose a model to use:", model_options)

input_text = st.text_area("Enter text:", placeholder="Type your input here...")

if st.button("Process Text"):
    # Empty or whitespace-only input gives the model nothing to work on, so
    # tell the user instead of silently doing nothing or running the model.
    if not input_text or not input_text.strip():
        st.warning("Please enter some text to process.")
    else:
        with st.spinner("Processing..."):
            if selected_model == "Re-Punctuate Text":
                tokenizer, model = load_re_punctuate_model()
                result = re_punctuate_text(tokenizer, model, input_text)
            else:
                tokenizer, model = load_headline_model()
                result = generate_headline_text(tokenizer, model, input_text)

        st.subheader(f"Result from {selected_model}:")
        st.write(result)

# Footer.
st.write("---")
st.write("Powered by Hugging Face Models.")
|
|