# Soft_Computing_Project / summarize.py
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Load the pretrained T5-base tokenizer and model once at import time.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

def summarize_text(text, max_chunk_length=512):
    """Summarize `text` chunk by chunk with T5 and join the partial summaries."""
    # Collapse newlines so chunk boundaries do not fall on line breaks.
    text = text.replace("\n", " ")
    # Split the input into fixed-size character chunks.
    chunks = [text[i:i + max_chunk_length] for i in range(0, len(text), max_chunk_length)]
    summarized_chunks = []
    for chunk in chunks:
        # T5 expects a task prefix; "summarize:" selects the summarization task.
        input_text = "summarize: " + chunk
        inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
        # Beam search with length_penalty > 1 favors longer, fluent summaries.
        with torch.no_grad():
            summary_ids = model.generate(inputs, max_length=150, min_length=40,
                                         num_beams=4, length_penalty=2.0, early_stopping=True)
        output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        summarized_chunks.append(output)
    return " ".join(summarized_chunks)