abuhanzala commited on
Commit
192180f
·
verified ·
1 Parent(s): 46926cf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+
5
+ # Load model
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ model_name = "Ateeqq/Text-Rewriter-Paraphraser"
8
+
9
+ @st.cache_resource
10
+ def load_model():
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
13
+ return tokenizer, model
14
+
15
+ tokenizer, model = load_model()
16
+
17
+ # Rewrite function
18
+ def rewrite(text):
19
+ input_ids = tokenizer(f"paraphraser: {text}", return_tensors="pt", truncation=True, max_length=1024).input_ids.to(device)
20
+ output = model.generate(
21
+ input_ids=input_ids,
22
+ num_beams=5,
23
+ no_repeat_ngram_size=3,
24
+ temperature=0.9,
25
+ max_length=1024,
26
+ early_stopping=True,
27
+ eos_token_id=tokenizer.eos_token_id
28
+ )
29
+ return tokenizer.decode(output[0], skip_special_tokens=True)
30
+
31
+ # UI
32
+ st.title("📝 Text Rewriter (Paraphraser)")
33
+ text_input = st.text_area("Enter text to rewrite:", height=300)
34
+ if st.button("Rewrite"):
35
+ with st.spinner("Rewriting..."):
36
+ result = rewrite(text_input)
37
+ st.success("Done!")
38
+ st.markdown("### 🔁 Rewritten Text")
39
+ st.write(result)