d0r1h committed
Commit 38e9364 · Parent: 7507971

Update Summarizer/Extractive.py

Files changed (1)
  1. Summarizer/Extractive.py +46 -12
Summarizer/Extractive.py CHANGED
@@ -1,21 +1,55 @@
 import nltk
-from sumy.parsers.plaintext import PlaintextParser
-from sumy.summarizers.luhn import LuhnSummarizer
+from summarizer import Summarizer
 from sumy.nlp.tokenizers import Tokenizer
+from sumy.summarizers.lsa import LsaSummarizer
+from sumy.parsers.plaintext import PlaintextParser
+from sumy.summarizers.lex_rank import LexRankSummarizer
+from sumy.summarizers.sum_basic import SumBasicSummarizer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 nltk.download('punkt')
 
-def summarize(file, SENTENCES_COUNT):
+def extractive(method, file):
+    sumarizer = method
+    sentences_ = []
+    doc_ = PlaintextParser(file, Tokenizer("en")).document
+    for sentence in sumarizer(doc_, 5):
+        sentences_.append(str(sentence))
+    summm_ = " ".join(sentences_)
+    return summm_
+
+def summarize(file, model):
 
-    sumarizer = LuhnSummarizer()
-    with open(file.name) as f:
-        doc = f.read()
+    with open(file.name) as f:
+        doc = f.read()
 
-    sentences_ = []
-    doc_ = PlaintextParser(doc, Tokenizer("en")).document
-    for sentence in sumarizer(doc_, SENTENCES_COUNT):
-        sentences_.append(str(sentence))
+    if model == "Pegasus":
+        checkpoint = "google/pegasus-billsum"
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+        inputs = tokenizer(doc,
+                           max_length=1024,
+                           truncation=True,
+                           return_tensors="pt")
+
+        summary_ids = model.generate(inputs["input_ids"])
+        summary = tokenizer.batch_decode(summary_ids,
+                                         skip_special_tokens=True,
+                                         clean_up_tokenization_spaces=False)
+        summary = summary[0]
 
-    summm_ = " ".join(sentences_)
+    elif model == "TextRank":
+        summary = extractive(LexRankSummarizer(), doc)
+
+    elif model == "SumBasic":
+        summary = extractive(SumBasicSummarizer(), doc)
+
+    elif model == "Lsa":
+        summary = extractive(LsaSummarizer(), doc)
 
-    return summm_
+    elif model == "BERT":
+        modelbert = Summarizer('distilbert-base-uncased', hidden=[-1,-2], hidden_concat=True)
+        result = modelbert(doc)
+        summary = ''.join(result)
+
+    return summary
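
For reference, a minimal usage sketch of the updated entry point. The import path is an assumption based on the file location shown above, and the SimpleNamespace stand-in (mirroring a Gradio-style file object with a .name attribute, which summarize reads from) and the "sample.txt" path are hypothetical:

    # Usage sketch; assumes Summarizer/Extractive.py is importable as a package
    # module and that a plain-text file exists at the hypothetical path below.
    import types

    from Summarizer.Extractive import summarize

    doc_file = types.SimpleNamespace(name="sample.txt")

    # Extractive route: note the "TextRank" label dispatches to sumy's
    # LexRankSummarizer in the committed code; it returns the top 5
    # sentences joined into one string.
    print(summarize(doc_file, "TextRank"))

    # Abstractive route: "Pegasus" loads google/pegasus-billsum from the
    # Hugging Face Hub, so the first call downloads model weights.
    print(summarize(doc_file, "Pegasus"))

One detail worth noting in the committed code: the Pegasus branch rebinds the model parameter to the loaded AutoModelForSeq2SeqLM, which is safe only because the string comparison happens before the rebinding.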