File size: 1,906 Bytes
32d9382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import random
import numpy as np
from nltk.tokenize import sent_tokenize

class T5_Base:
    """Summarize text with a pretrained T5 conditional-generation model.

    The model and tokenizer are both loaded from ``path``. The model and the
    encoded inputs are placed on ``device`` before generation.
    """

    def __init__(self, path, device, model_max_length: int):
        """Load the model and tokenizer and move the model to ``device``.

        Args:
            path: Passed to ``from_pretrained`` for both model and tokenizer.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).
            model_max_length: Forwarded to the tokenizer as
                ``model_max_length``.
        """
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.tokenizer = T5Tokenizer.from_pretrained(path, model_max_length=model_max_length)
        self.device = torch.device(device)
        # getSummary moves the encoded inputs to self.device, so the model
        # has to live on the same device or generate() fails on a mismatch.
        self.model.to(self.device)

    @staticmethod
    def set_seed(seed):
        """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

        Declared static so it can be called either as ``T5_Base.set_seed(n)``
        or on an instance; it reads no instance state.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def preprocess(self, data: str) -> str:
        """Return ``data`` stripped, with every newline replaced by a space."""
        return data.strip().replace('\n', ' ')

    def post_process(self, data: str) -> str:
        """Capitalize each sentence of ``data`` and join them with one space.

        Sentences are split with ``nltk``'s ``sent_tokenize``. Each sentence
        appears exactly once in the result.
        """
        # NOTE: str.capitalize() upper-cases the first character and
        # lower-cases the rest of the sentence.
        return " ".join(sent.capitalize() for sent in sent_tokenize(data))

    def getSummary(self, data: str) -> str:
        """Generate a summary of ``data``.

        The text is cleaned with ``preprocess``, prefixed with
        ``"summarize: "``, truncated to 512 tokens, and decoded from a
        3-beam search producing between 75 and 300 tokens.

        Returns:
            The post-processed summary, or ``""`` when ``data`` is empty or
            whitespace-only (no generation is attempted in that case).
        """
        data = self.preprocess(data)
        if not data:
            # Nothing to summarize; skip the model call entirely.
            return ""

        t5_prepared_Data = "summarize: " + data
        # padding=False replaces the deprecated pad_to_max_length=False.
        tokenized_text = self.tokenizer(
            t5_prepared_Data,
            max_length=512,
            padding=False,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

        summary_ids = self.model.generate(
            input_ids=tokenized_text['input_ids'],
            attention_mask=tokenized_text['attention_mask'],
            early_stopping=True,
            num_beams=3,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            min_length=75,
            max_length=300,
        )

        # num_return_sequences=1, so only the first sequence is needed.
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return self.post_process(summary).strip()