Krishnan Palanisami committed
Commit 00896df · verified · 1 Parent(s): 7dc93a5

Update app.py

Files changed (1):
  1. app.py +475 -0
app.py CHANGED
@@ -5,6 +5,481 @@ from haystack.utils import clean_wiki_text, convert_files_to_docs
 from haystack.nodes import TfidfRetriever, FARMReader
 from haystack.pipelines import ExtractiveQAPipeline
 from main import print_qa, QuestionGenerator
+ import en_core_web_sm
+ import json
+ import numpy as np
+ import random
+ import re
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     AutoModelForSequenceClassification,
+ )
+ from typing import Any, List, Mapping, Tuple
+
+
+ class QuestionGenerator:
+     """A transformer-based NLP system for generating reading comprehension-style questions from
+     texts. It can generate full-sentence questions, multiple-choice questions, or a mix of the
+     two styles.
+
+     To filter out low-quality questions, questions are assigned a score and ranked once they have
+     been generated. Only the top k questions will be returned. This behaviour can be turned off
+     by setting use_evaluator=False.
+     """
+
+     def __init__(self) -> None:
+         QG_PRETRAINED = "iarfmoose/t5-base-question-generator"
+         self.ANSWER_TOKEN = "<answer>"
+         self.CONTEXT_TOKEN = "<context>"
+         self.SEQ_LENGTH = 512
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         self.qg_tokenizer = AutoTokenizer.from_pretrained(QG_PRETRAINED, use_fast=False)
+         self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
+         self.qg_model.to(self.device)
+         self.qg_model.eval()
+
+         self.qa_evaluator = QAEvaluator()
+
+     def generate(
+         self,
+         article: str,
+         use_evaluator: bool = True,
+         num_questions: int = None,
+         answer_style: str = "all",
+     ) -> List:
+         """Takes an article and generates a set of question and answer pairs. If use_evaluator
+         is True, the QA pairs will be ranked and filtered based on their quality. answer_style
+         should be selected from ["all", "sentences", "multiple_choice"].
+         """
+         print("Generating questions...\n")
+
+         qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
+         generated_questions = self.generate_questions_from_inputs(qg_inputs)
+
+         message = "{} questions don't match {} answers".format(
+             len(generated_questions), len(qg_answers)
+         )
+         assert len(generated_questions) == len(qg_answers), message
+
+         if use_evaluator:
+             print("Evaluating QA pairs...\n")
+             encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(
+                 generated_questions, qg_answers
+             )
+             scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
+
+             if num_questions:
+                 qa_list = self._get_ranked_qa_pairs(
+                     generated_questions, qg_answers, scores, num_questions
+                 )
+             else:
+                 qa_list = self._get_ranked_qa_pairs(
+                     generated_questions, qg_answers, scores
+                 )
+         else:
+             print("Skipping evaluation step.\n")
+             qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)
+
+         return qa_list
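For orientation, a minimal usage sketch of this class; the article string is a made-up placeholder, and the first call will download both pretrained checkpoints from the Hugging Face Hub:

    qg = QuestionGenerator()
    article = (
        "The Eiffel Tower is a wrought-iron lattice tower in Paris. "
        "It was designed by the company of Gustave Eiffel and completed in 1889."
    )
    # Returns at most 3 ranked QA dicts, each with "question" and "answer" keys.
    qa_list = qg.generate(article, num_questions=3, answer_style="multiple_choice")
    print_qa(qa_list)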
+
+     def generate_qg_inputs(self, text: str, answer_style: str) -> Tuple[List[str], List[str]]:
+         """Given a text, returns a list of model inputs and a list of corresponding answers.
+         Model inputs take the form "answer_token <answer text> context_token <context text>",
+         where the answer is a string extracted from the text, and the context is the wider
+         text surrounding the answer.
+         """
+         VALID_ANSWER_STYLES = ["all", "sentences", "multiple_choice"]
+
+         if answer_style not in VALID_ANSWER_STYLES:
+             raise ValueError(
+                 "Invalid answer style {}. Please choose from {}".format(
+                     answer_style, VALID_ANSWER_STYLES
+                 )
+             )
+
+         inputs = []
+         answers = []
+
+         if answer_style == "sentences" or answer_style == "all":
+             segments = self._split_into_segments(text)
+
+             for segment in segments:
+                 sentences = self._split_text(segment)
+                 prepped_inputs, prepped_answers = self._prepare_qg_inputs(
+                     sentences, segment
+                 )
+                 inputs.extend(prepped_inputs)
+                 answers.extend(prepped_answers)
+
+         if answer_style == "multiple_choice" or answer_style == "all":
+             sentences = self._split_text(text)
+             prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(sentences)
+             inputs.extend(prepped_inputs)
+             answers.extend(prepped_answers)
+
+         return inputs, answers
+
+     def generate_questions_from_inputs(self, qg_inputs: List) -> List[str]:
+         """Given a list of concatenated answers and contexts, with the form
+         "answer_token <answer text> context_token <context text>", generates a list of
+         questions.
+         """
+         generated_questions = []
+
+         for qg_input in qg_inputs:
+             question = self._generate_question(qg_input)
+             generated_questions.append(question)
+
+         return generated_questions
+
+     def _split_text(self, text: str) -> List[str]:
+         """Splits the text into sentences, and attempts to split or truncate long sentences."""
+         MAX_SENTENCE_LEN = 128
+         sentences = re.findall(r".*?[.!?]", text)
+         cut_sentences = []
+
+         for sentence in sentences:
+             if len(sentence) > MAX_SENTENCE_LEN:
+                 cut_sentences.extend(re.split("[,;:)]", sentence))
+
+         # remove useless post-quote sentence fragments
+         cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
+         sentences = sentences + cut_sentences
+
+         return list(set([s.strip(" ") for s in sentences]))
+
+     def _split_into_segments(self, text: str) -> List[str]:
+         """Splits a long text into segments short enough to be input into the transformer network.
+         Segments are used as context for question generation.
+         """
+         MAX_TOKENS = 490
+         paragraphs = text.split("\n")
+         tokenized_paragraphs = [
+             self.qg_tokenizer(p)["input_ids"] for p in paragraphs if len(p) > 0
+         ]
+         segments = []
+
+         while len(tokenized_paragraphs) > 0:
+             segment = []
+
+             while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
+                 paragraph = tokenized_paragraphs.pop(0)
+                 segment.extend(paragraph)
+             segments.append(segment)
+
+         return [self.qg_tokenizer.decode(s, skip_special_tokens=True) for s in segments]
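The inner loop packs paragraphs greedily, so a segment can overshoot MAX_TOKENS by up to one paragraph; the later tokenizer call truncates to SEQ_LENGTH. A standalone sketch of the same packing logic, where pack is a hypothetical stand-in and lists of ints stand in for tokenized paragraphs:

    def pack(paragraphs, max_tokens=490):
        # Greedily merge consecutive paragraphs into segments of roughly max_tokens,
        # consuming the input list as it goes.
        segments = []
        while paragraphs:
            segment = []
            while len(segment) < max_tokens and paragraphs:
                segment.extend(paragraphs.pop(0))
            segments.append(segment)
        return segments

    # Three 300-token "paragraphs" yield segments of 600 and 300 tokens:
    # a segment only stops growing after it crosses the threshold.
    print([len(s) for s in pack([[0] * 300, [1] * 300, [2] * 300])])  # [600, 300]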
+
+     def _prepare_qg_inputs(
+         self,
+         sentences: List[str],
+         text: str,
+     ) -> Tuple[List[str], List[str]]:
+         """Uses sentences as answers and the text as context. Returns a tuple of (model inputs, answers).
+         Model inputs are "answer_token <answer text> context_token <context text>".
+         """
+         inputs = []
+         answers = []
+
+         for sentence in sentences:
+             qg_input = f"{self.ANSWER_TOKEN} {sentence} {self.CONTEXT_TOKEN} {text}"
+             inputs.append(qg_input)
+             answers.append(sentence)
+
+         return inputs, answers
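Concretely, each model input is just the answer and its context joined by the two special tokens; with illustrative placeholder strings:

    ANSWER_TOKEN, CONTEXT_TOKEN = "<answer>", "<context>"
    sentence = "Paris is the capital of France."
    segment = "Paris is the capital of France. It hosts the Louvre."
    print(f"{ANSWER_TOKEN} {sentence} {CONTEXT_TOKEN} {segment}")
    # <answer> Paris is the capital of France. <context> Paris is the capital of France. It hosts the Louvre.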
+
+     def _prepare_qg_inputs_MC(self, sentences: List[str]) -> Tuple[List[str], List[str]]:
+         """Performs NER on the text, and uses the extracted entities as candidate answers for
+         multiple-choice questions. Sentences are used as context, and entities as answers.
+         Returns a tuple of (model inputs, answers).
+         Model inputs are "answer_token <answer text> context_token <context text>".
+         """
+         spacy_nlp = en_core_web_sm.load()
+         docs = list(spacy_nlp.pipe(sentences, disable=["parser"]))
+         inputs_from_text = []
+         answers_from_text = []
+
+         for doc, sentence in zip(docs, sentences):
+             entities = doc.ents
+             if entities:
+                 for entity in entities:
+                     qg_input = f"{self.ANSWER_TOKEN} {entity} {self.CONTEXT_TOKEN} {sentence}"
+                     answers = self._get_MC_answers(entity, docs)
+                     inputs_from_text.append(qg_input)
+                     answers_from_text.append(answers)
+
+         return inputs_from_text, answers_from_text
+
+     def _get_MC_answers(self, correct_answer: Any, docs: Any) -> List[Mapping[str, Any]]:
+         """Finds a set of alternative answers for a multiple-choice question. Will attempt to find
+         alternatives of the same entity type as correct_answer if possible.
+         """
+         entities = []
+
+         for doc in docs:
+             entities.extend([{"text": e.text, "label_": e.label_} for e in doc.ents])
+
+         # remove duplicate elements; sort for a deterministic pool
+         entities_json = [json.dumps(kv) for kv in entities]
+         pool = sorted(set(entities_json))
+         num_choices = min(4, len(pool)) - 1  # -1 because the correct answer is already included
+
+         # add the correct answer
+         final_choices = []
+         correct_label = correct_answer.label_
+         final_choices.append({"answer": correct_answer.text, "correct": True})
+
+         # remove the correct answer from the pool
+         correct_json = json.dumps({"text": correct_answer.text, "label_": correct_answer.label_})
+         pool = [e for e in pool if e != correct_json]
+
+         # find answers with the same NER label
+         matches = [e for e in pool if correct_label in e]
+
+         # if there aren't enough matches, add other random answers,
+         # sampling only from entities not already chosen
+         if len(matches) < num_choices:
+             choices = matches
+             remaining = sorted(set(pool) - set(choices))
+             choices.extend(random.sample(remaining, num_choices - len(choices)))
+         else:
+             choices = random.sample(matches, num_choices)
+
+         choices = [json.loads(s) for s in choices]
+
+         for choice in choices:
+             final_choices.append({"answer": choice["text"], "correct": False})
+
+         random.shuffle(final_choices)
+         return final_choices
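The result is a shuffled list of answer dicts with exactly one correct entry; for a correct answer "Paris" it might look like this (illustrative values only, order is random):

    [{"answer": "Berlin", "correct": False},
     {"answer": "Paris", "correct": True},
     {"answer": "Gustave Eiffel", "correct": False},
     {"answer": "1889", "correct": False}]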
+
+     @torch.no_grad()
+     def _generate_question(self, qg_input: str) -> str:
+         """Takes qg_input, which is the concatenated answer and context, and uses it to generate
+         a question sentence. The generated question is decoded and then returned.
+         """
+         encoded_input = self._encode_qg_input(qg_input)
+         output = self.qg_model.generate(input_ids=encoded_input["input_ids"])
+         question = self.qg_tokenizer.decode(
+             output[0],
+             skip_special_tokens=True,
+         )
+         return question
+
+     def _encode_qg_input(self, qg_input: str) -> torch.tensor:
+         """Tokenizes a string and returns a tensor of input ids corresponding to indices
+         of tokens in the vocab.
+         """
+         return self.qg_tokenizer(
+             qg_input,
+             padding="max_length",
+             max_length=self.SEQ_LENGTH,
+             truncation=True,
+             return_tensors="pt",
+         ).to(self.device)
+
+     def _get_ranked_qa_pairs(
+         self, generated_questions: List[str], qg_answers: List[str], scores, num_questions: int = 10
+     ) -> List[Mapping[str, str]]:
+         """Ranks generated questions according to scores, and returns the top num_questions examples.
+         """
+         if num_questions > len(scores):
+             num_questions = len(scores)
+             print(
+                 f"\nWas only able to generate {num_questions} questions. "
+                 "For more questions, please input a longer text."
+             )
+
+         qa_list = []
+
+         for i in range(num_questions):
+             index = scores[i]
+             qa = {
+                 "question": generated_questions[index].split("?")[0] + "?",
+                 "answer": qg_answers[index],
+             }
+             qa_list.append(qa)
+
+         return qa_list
+
+     def _get_all_qa_pairs(self, generated_questions: List[str], qg_answers: List[str]):
+         """Formats question and answer pairs without ranking or filtering."""
+         qa_list = []
+
+         for question, answer in zip(generated_questions, qg_answers):
+             qa = {
+                 "question": question.split("?")[0] + "?",
+                 "answer": answer,
+             }
+             qa_list.append(qa)
+
+         return qa_list
+
+
+ class QAEvaluator:
+     """Wrapper for a transformer model which evaluates the quality of question-answer pairs.
+     Given a QA pair, the model will generate a score. Scores can be used to rank and filter
+     QA pairs.
+     """
+
+     def __init__(self) -> None:
+         QAE_PRETRAINED = "iarfmoose/bert-base-cased-qa-evaluator"
+         self.SEQ_LENGTH = 512
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
+         self.qae_model = AutoModelForSequenceClassification.from_pretrained(
+             QAE_PRETRAINED
+         )
+         self.qae_model.to(self.device)
+         self.qae_model.eval()
+
+     def encode_qa_pairs(self, questions: List[str], answers: List[str]) -> List[torch.tensor]:
+         """Takes a list of questions and a list of answers and encodes them as a list of tensors."""
+         encoded_pairs = []
+
+         for question, answer in zip(questions, answers):
+             encoded_qa = self._encode_qa(question, answer)
+             encoded_pairs.append(encoded_qa.to(self.device))
+
+         return encoded_pairs
+
+     def get_scores(self, encoded_qa_pairs: List[torch.tensor]) -> List[int]:
+         """Scores a list of encoded QA pairs and returns their indices sorted from highest
+         to lowest score.
+         """
+         scores = {}
+
+         for i in range(len(encoded_qa_pairs)):
+             scores[i] = self._evaluate_qa(encoded_qa_pairs[i])
+
+         return [
+             k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)
+         ]
+
+     def _encode_qa(self, question: str, answer: str) -> torch.tensor:
+         """Concatenates a question and answer, and then tokenizes them. Returns a tensor of
+         input ids corresponding to indices in the vocab.
+         """
+         if type(answer) is list:
+             for a in answer:
+                 if a["correct"]:
+                     correct_answer = a["answer"]
+         else:
+             correct_answer = answer
+
+         return self.qae_tokenizer(
+             text=question,
+             text_pair=correct_answer,
+             padding="max_length",
+             max_length=self.SEQ_LENGTH,
+             truncation=True,
+             return_tensors="pt",
+         )
+
+     @torch.no_grad()
+     def _evaluate_qa(self, encoded_qa_pair: torch.tensor) -> float:
+         """Takes an encoded QA pair and returns a score."""
+         output = self.qae_model(**encoded_qa_pair)
+         return output[0][0][1]
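The evaluator can also be used on its own; note that get_scores returns indices sorted best-first, not raw scores. A minimal sketch, assuming the checkpoint downloads successfully:

    evaluator = QAEvaluator()
    questions = ["What is the capital of France?", "What colour is the sky?"]
    answers = ["Paris", "Paris"]
    encoded = evaluator.encode_qa_pairs(questions, answers)
    ranking = evaluator.get_scores(encoded)
    # e.g. [0, 1]: the first pair was judged the better question-answer match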
+
+
+ def print_qa(qa_list: List[Mapping[str, str]], show_answers: bool = True) -> None:
+     """Formats and prints a list of generated questions and answers."""
+
+     for i in range(len(qa_list)):
+         # wider space for 2-digit question numbers
+         space = " " * (3 if i < 9 else 4)
+
+         print(f"{i + 1}) Q: {qa_list[i]['question']}")
+
+         answer = qa_list[i]["answer"]
+
+         # print a list of multiple-choice answers
+         if type(answer) is list:
+             if show_answers:
+                 print(
+                     f"{space}A: 1. {answer[0]['answer']} "
+                     f"{'(correct)' if answer[0]['correct'] else ''}"
+                 )
+                 for j in range(1, len(answer)):
+                     print(
+                         f"{space + '   '}{j + 1}. {answer[j]['answer']} "
+                         f"{'(correct)' if answer[j]['correct'] else ''}"
+                     )
+             else:
+                 print(f"{space}A: 1. {answer[0]['answer']}")
+                 for j in range(1, len(answer)):
+                     print(f"{space + '   '}{j + 1}. {answer[j]['answer']}")
+
+             print("")
+
+         # print full-sentence answers
+         else:
+             if show_answers:
+                 print(f"{space}A: {answer}\n")
+
+
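For a multiple-choice pair, the printed output looks roughly like this (illustrative):

    1) Q: What is the capital of France?
       A: 1. Berlin
          2. Paris (correct)
          3. 1889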
 
 def main():
     # Set the Streamlit app title