Krishnan Palanisami committed on
Commit 313428e · verified · 1 Parent(s): 4a2ef03

Delete app.py

Files changed (1)
  1. app.py +0 -575
app.py DELETED
@@ -1,575 +0,0 @@
-import streamlit as st
-import wikipedia
-from haystack.document_stores import InMemoryDocumentStore
-from haystack.utils import clean_wiki_text, convert_files_to_docs
-from haystack.nodes import TfidfRetriever, FARMReader
-from haystack.pipelines import ExtractiveQAPipeline
-from main import print_qa, QuestionGenerator
-import en_core_web_sm
-import json
-import numpy as np
-import random
-import re
-import torch
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-    AutoModelForSequenceClassification,
-)
-from typing import Any, List, Mapping, Tuple
-
-
-class QuestionGenerator:
-    """A transformer-based NLP system for generating reading comprehension-style questions from
-    texts. It can generate full-sentence questions, multiple-choice questions, or a mix of the
-    two styles.
-
-    To filter out low-quality questions, generated questions are assigned a score and
-    ranked. Only the top k questions will be returned. This behaviour can be turned off
-    by setting use_evaluator=False.
-    """
-
-    def __init__(self) -> None:
-        QG_PRETRAINED = "iarfmoose/t5-base-question-generator"
-        self.ANSWER_TOKEN = "<answer>"
-        self.CONTEXT_TOKEN = "<context>"
-        self.SEQ_LENGTH = 512
-
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
-
-        self.qg_tokenizer = AutoTokenizer.from_pretrained(
-            QG_PRETRAINED, use_fast=False)
-        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
-        self.qg_model.to(self.device)
-        self.qg_model.eval()
-
-        self.qa_evaluator = QAEvaluator()
-
-    def generate(
-        self,
-        article: str,
-        use_evaluator: bool = True,
-        num_questions: int = None,
-        answer_style: str = "all"
-    ) -> List:
-        """Takes an article and generates a set of question-and-answer pairs. If use_evaluator
-        is True, the QA pairs are ranked and filtered based on their quality. answer_style
-        should be selected from ["all", "sentences", "multiple_choice"].
-        """
-        print("Generating questions...\n")
-
-        qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
-        generated_questions = self.generate_questions_from_inputs(qg_inputs)
-
-        message = "{} questions don't match {} answers".format(
-            len(generated_questions), len(qg_answers)
-        )
-        assert len(generated_questions) == len(qg_answers), message
-
-        if use_evaluator:
-            print("Evaluating QA pairs...\n")
-            encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(
-                generated_questions, qg_answers
-            )
-            scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
-
-            if num_questions:
-                qa_list = self._get_ranked_qa_pairs(
-                    generated_questions, qg_answers, scores, num_questions
-                )
-            else:
-                qa_list = self._get_ranked_qa_pairs(
-                    generated_questions, qg_answers, scores
-                )
-        else:
-            print("Skipping evaluation step.\n")
-            qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)
-
-        return qa_list
-
-    def generate_qg_inputs(self, text: str, answer_style: str) -> Tuple[List[str], List[str]]:
-        """Given a text, returns a list of model inputs and a list of corresponding answers.
-        Model inputs take the form "answer_token <answer text> context_token <context text>",
-        where the answer is a string extracted from the text and the context is the wider text
-        surrounding the answer.
-        """
-        VALID_ANSWER_STYLES = ["all", "sentences", "multiple_choice"]
-
-        if answer_style not in VALID_ANSWER_STYLES:
-            raise ValueError(
-                "Invalid answer style {}. Please choose from {}".format(
-                    answer_style, VALID_ANSWER_STYLES
-                )
-            )
-
-        inputs = []
-        answers = []
-
-        if answer_style == "sentences" or answer_style == "all":
-            segments = self._split_into_segments(text)
-
-            for segment in segments:
-                sentences = self._split_text(segment)
-                prepped_inputs, prepped_answers = self._prepare_qg_inputs(
-                    sentences, segment
-                )
-                inputs.extend(prepped_inputs)
-                answers.extend(prepped_answers)
-
-        if answer_style == "multiple_choice" or answer_style == "all":
-            sentences = self._split_text(text)
-            prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(
-                sentences
-            )
-            inputs.extend(prepped_inputs)
-            answers.extend(prepped_answers)
-
-        return inputs, answers
-
-    def generate_questions_from_inputs(self, qg_inputs: List) -> List[str]:
-        """Given a list of concatenated answers and contexts, with the form
-        "answer_token <answer text> context_token <context text>", generates a list of
-        questions.
-        """
-        generated_questions = []
-
-        for qg_input in qg_inputs:
-            question = self._generate_question(qg_input)
-            generated_questions.append(question)
-
-        return generated_questions
-
-    def _split_text(self, text: str) -> List[str]:
-        """Splits the text into sentences, and attempts to split or truncate long sentences."""
-        MAX_SENTENCE_LEN = 128
-        sentences = re.findall(".*?[.!?]", text)
-        cut_sentences = []
-
-        for sentence in sentences:
-            if len(sentence) > MAX_SENTENCE_LEN:
-                cut_sentences.extend(re.split("[,;:)]", sentence))
-
-        # remove useless post-quote sentence fragments
-        cut_sentences = [s for s in cut_sentences if len(s.split(" ")) > 5]
-        sentences = sentences + cut_sentences
-
-        return list(set([s.strip(" ") for s in sentences]))
-
-    def _split_into_segments(self, text: str) -> List[str]:
-        """Splits a long text into segments short enough to be input into the transformer network.
-        Segments are used as context for question generation.
-        """
-        MAX_TOKENS = 490
-        paragraphs = text.split("\n")
-        tokenized_paragraphs = [
-            self.qg_tokenizer(p)["input_ids"] for p in paragraphs if len(p) > 0
-        ]
-        segments = []
-
-        while len(tokenized_paragraphs) > 0:
-            segment = []
-
-            # greedily pack whole paragraphs into the segment until the token budget is hit
-            while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
-                paragraph = tokenized_paragraphs.pop(0)
-                segment.extend(paragraph)
-            segments.append(segment)
-
-        return [self.qg_tokenizer.decode(s, skip_special_tokens=True) for s in segments]
-
-    def _prepare_qg_inputs(
-        self,
-        sentences: List[str],
-        text: str
-    ) -> Tuple[List[str], List[str]]:
-        """Uses sentences as answers and the text as context. Returns a tuple of (model inputs, answers).
-        Model inputs are "answer_token <answer text> context_token <context text>".
-        """
-        inputs = []
-        answers = []
-
-        for sentence in sentences:
-            qg_input = f"{self.ANSWER_TOKEN} {sentence} {self.CONTEXT_TOKEN} {text}"
-            inputs.append(qg_input)
-            answers.append(sentence)
-
-        return inputs, answers
-
-    def _prepare_qg_inputs_MC(self, sentences: List[str]) -> Tuple[List[str], List[str]]:
-        """Performs NER on the text, and uses extracted entities as candidate answers for multiple-choice
-        questions. Sentences are used as context, and entities as answers. Returns a tuple of (model inputs, answers).
-        Model inputs are "answer_token <answer text> context_token <context text>".
-        """
-        spacy_nlp = en_core_web_sm.load()
-        docs = list(spacy_nlp.pipe(sentences, disable=["parser"]))
-        inputs_from_text = []
-        answers_from_text = []
-
-        for doc, sentence in zip(docs, sentences):
-            entities = doc.ents
-            if entities:
-                for entity in entities:
-                    qg_input = f"{self.ANSWER_TOKEN} {entity} {self.CONTEXT_TOKEN} {sentence}"
-                    answers = self._get_MC_answers(entity, docs)
-                    inputs_from_text.append(qg_input)
-                    answers_from_text.append(answers)
-
-        return inputs_from_text, answers_from_text
-
-    def _get_MC_answers(self, correct_answer: Any, docs: Any) -> List[Mapping[str, Any]]:
-        """Finds a set of alternative answers for a multiple-choice question. Will attempt to find
-        alternatives of the same entity type as correct_answer if possible.
-        """
-        entities = []
-
-        for doc in docs:
-            entities.extend([{"text": e.text, "label_": e.label_} for e in doc.ents])
-
-        # serialise the entities so duplicates can be removed, then sort for determinism
-        entities_json = [json.dumps(kv) for kv in entities]
-        pool = sorted(set(entities_json))
-        num_choices = min(4, len(pool)) - 1  # -1 because we already have the correct answer
-
-        # add the correct answer
-        final_choices = []
-        correct_label = correct_answer.label_
-        final_choices.append({"answer": correct_answer.text, "correct": True})
-
-        # remove the correct answer from the pool
-        pool = [e for e in pool if e != json.dumps({"text": correct_answer.text, "label_": correct_answer.label_})]
-
-        # find answers with the same NER label
-        matches = [e for e in pool if correct_label in e]
-
-        # if there aren't enough matches, fill in with other random answers,
-        # excluding the matches already chosen so no option appears twice
-        if len(matches) < num_choices:
-            choices = matches
-            remaining_pool = [e for e in pool if e not in choices]
-            choices.extend(random.sample(remaining_pool, num_choices - len(choices)))
-        else:
-            choices = random.sample(sorted(matches), num_choices)
-
-        choices = [json.loads(s) for s in choices]
-
-        for choice in choices:
-            final_choices.append({"answer": choice["text"], "correct": False})
-
-        random.shuffle(final_choices)
-        return final_choices
-
-    @torch.no_grad()
-    def _generate_question(self, qg_input: str) -> str:
-        """Takes qg_input, the concatenated answer and context, and uses it to generate
-        a question sentence. The generated question is decoded and then returned.
-        """
-        encoded_input = self._encode_qg_input(qg_input)
-        output = self.qg_model.generate(input_ids=encoded_input["input_ids"])
-        question = self.qg_tokenizer.decode(
-            output[0],
-            skip_special_tokens=True
-        )
-        return question
-
-    def _encode_qg_input(self, qg_input: str) -> torch.tensor:
-        """Tokenizes a string and returns a tensor of input ids corresponding to indices of tokens in
-        the vocab.
-        """
-        return self.qg_tokenizer(
-            qg_input,
-            padding='max_length',
-            max_length=self.SEQ_LENGTH,
-            truncation=True,
-            return_tensors="pt",
-        ).to(self.device)
-
-    def _get_ranked_qa_pairs(
-        self, generated_questions: List[str], qg_answers: List[str], scores, num_questions: int = 10
-    ) -> List[Mapping[str, str]]:
-        """Ranks generated questions according to scores, and returns the top num_questions examples.
-        """
-        if num_questions > len(scores):
-            num_questions = len(scores)
-            print(
-                f"\nWas only able to generate {num_questions} questions. "
-                "For more questions, please input a longer text."
-            )
-
-        qa_list = []
-
-        for i in range(num_questions):
-            index = scores[i]
-            qa = {
-                "question": generated_questions[index].split("?")[0] + "?",
-                "answer": qg_answers[index]
-            }
-            qa_list.append(qa)
-
-        return qa_list
-
-    def _get_all_qa_pairs(self, generated_questions: List[str], qg_answers: List[str]):
-        """Formats question and answer pairs without ranking or filtering."""
-        qa_list = []
-
-        for question, answer in zip(generated_questions, qg_answers):
-            qa = {
-                "question": question.split("?")[0] + "?",
-                "answer": answer
-            }
-            qa_list.append(qa)
-
-        return qa_list
-
-
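A minimal usage sketch of the class above, assuming this module's classes are importable and the iarfmoose/t5-base-question-generator weights download on first use (the sample text and parameter values are illustrative, not from the deleted file):

```python
# Illustrative driver for QuestionGenerator (hypothetical, for reference only).
qg = QuestionGenerator()

text = (
    "Alan Turing was a British mathematician and computer scientist. "
    "He is widely considered to be the father of theoretical computer science."
)

# generate() returns a list of {"question": ..., "answer": ...} mappings;
# with answer_style="multiple_choice", "answer" is a list of option dicts.
qa_list = qg.generate(text, num_questions=3, answer_style="multiple_choice")
```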
-class QAEvaluator:
-    """Wrapper for a transformer model which evaluates the quality of question-answer pairs.
-    Given a QA pair, the model will generate a score. Scores can be used to rank and filter
-    QA pairs.
-    """
-
-    def __init__(self) -> None:
-        QAE_PRETRAINED = "iarfmoose/bert-base-cased-qa-evaluator"
-        self.SEQ_LENGTH = 512
-
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
-
-        self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
-        self.qae_model = AutoModelForSequenceClassification.from_pretrained(
-            QAE_PRETRAINED
-        )
-        self.qae_model.to(self.device)
-        self.qae_model.eval()
-
-    def encode_qa_pairs(self, questions: List[str], answers: List[str]) -> List[torch.tensor]:
-        """Takes a list of questions and a list of answers and encodes them as a list of tensors."""
-        encoded_pairs = []
-
-        for question, answer in zip(questions, answers):
-            encoded_qa = self._encode_qa(question, answer)
-            encoded_pairs.append(encoded_qa.to(self.device))
-
-        return encoded_pairs
-
-    def get_scores(self, encoded_qa_pairs: List[torch.tensor]) -> List[int]:
-        """Scores a list of encoded QA pairs and returns their indices, sorted from
-        highest- to lowest-scoring.
-        """
-        scores = {}
-
-        for i in range(len(encoded_qa_pairs)):
-            scores[i] = self._evaluate_qa(encoded_qa_pairs[i])
-
-        return [
-            k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)
-        ]
-
-    def _encode_qa(self, question: str, answer: str) -> torch.tensor:
-        """Concatenates a question and answer, and then tokenizes them. Returns a tensor of
-        input ids corresponding to indices in the vocab.
-        """
-        if type(answer) is list:
-            for a in answer:
-                if a["correct"]:
-                    correct_answer = a["answer"]
-        else:
-            correct_answer = answer
-
-        return self.qae_tokenizer(
-            text=question,
-            text_pair=correct_answer,
-            padding="max_length",
-            max_length=self.SEQ_LENGTH,
-            truncation=True,
-            return_tensors="pt",
-        )
-
-    @torch.no_grad()
-    def _evaluate_qa(self, encoded_qa_pair: torch.tensor) -> float:
-        """Takes an encoded QA pair and returns a score."""
-        output = self.qae_model(**encoded_qa_pair)
-        return output[0][0][1]
-
-
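A sketch of the QAEvaluator contract, under the same importability assumption: get_scores returns the indices of the input pairs sorted best-first rather than raw scores, which is why _get_ranked_qa_pairs above can look up questions with scores[i]:

```python
evaluator = QAEvaluator()

# Two illustrative QA pairs (invented for the example).
encoded = evaluator.encode_qa_pairs(
    ["Who wrote Hamlet?", "What is the capital of France?"],
    ["William Shakespeare", "Paris"],
)

ranked = evaluator.get_scores(encoded)  # e.g. [1, 0]: indices, best pair first
best_pair_index = ranked[0]
```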
-def print_qa(qa_list: List[Mapping[str, str]], show_answers: bool = True) -> None:
-    """Formats and prints a list of generated questions and answers."""
-
-    for i in range(len(qa_list)):
-        # wider space for 2-digit question numbers
-        space = " " * int(np.where(i < 9, 3, 4))
-
-        print(f"{i + 1}) Q: {qa_list[i]['question']}")
-
-        answer = qa_list[i]["answer"]
-
-        # print a list of multiple choice answers
-        if type(answer) is list:
-
-            if show_answers:
-                print(
-                    f"{space}A: 1. {answer[0]['answer']} "
-                    f"{np.where(answer[0]['correct'], '(correct)', '')}"
-                )
-                for j in range(1, len(answer)):
-                    print(
-                        f"{space + ' '}{j + 1}. {answer[j]['answer']} "
-                        f"{np.where(answer[j]['correct'], '(correct)', '')}"
-                    )
-
-            else:
-                print(f"{space}A: 1. {answer[0]['answer']}")
-                for j in range(1, len(answer)):
-                    print(f"{space + ' '}{j + 1}. {answer[j]['answer']}")
-
-            print("")
-
-        # print full sentence answers
-        else:
-            if show_answers:
-                print(f"{space}A: {answer}\n")
-
-
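For reference, a sketch of the qa_list shape that print_qa consumes (values invented for illustration): a full-sentence answer is a plain string, while a multiple-choice answer is a list of option dicts with exactly one marked correct:

```python
sample_qa_list = [
    # full-sentence style: "answer" is a string
    {"question": "Who was Alan Turing?", "answer": "Alan Turing was a British mathematician."},
    # multiple-choice style: "answer" is a list of options
    {
        "question": "Which country was Alan Turing from?",
        "answer": [
            {"answer": "Britain", "correct": True},
            {"answer": "France", "correct": False},
        ],
    },
]

print_qa(sample_qa_list, show_answers=True)
```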
-def main():
-    # Set the Streamlit app title
-    st.title("Question Generation using Haystack and Streamlit")
-
-    # Select the input type
-    inputs = ["Input Paragraph", "Wikipedia Examples"]
-    input_type = st.selectbox("Select an input type:", inputs)
-
-    # Initialize wiki_text as an empty string
-    wiki_text = ""
-
-    # Handle different input types
-    if input_type == "Input Paragraph":
-        # Allow the user to input a text paragraph
-        wiki_text = st.text_area("Input paragraph:", height=200)
-
-    elif input_type == "Wikipedia Examples":
-        # Define topics for selection
-        topics = ["Deep Learning", "Machine Learning"]
-        selected_topic = st.selectbox("Select a topic:", topics)
-
-        # Retrieve Wikipedia content based on the selected topic
-        if selected_topic:
-            wiki = wikipedia.page(selected_topic)
-            wiki_text = wiki.content
-
-        # Display the retrieved Wikipedia content (optional)
-        st.text_area("Retrieved Wikipedia content:", wiki_text, height=200)
-
-    # Preprocess the input text
-    wiki_text = clean_wiki_text(wiki_text)
-
-    # Allow the user to specify the number of questions to generate
-    num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
-
-    # Allow the user to specify the reader model
-    model_options = [
-        "deepset/roberta-base-squad2",
-        "deepset/roberta-base-squad2-distilled",
-        "deepset/bert-large-uncased-whole-word-masking-squad2",
-        "deepset/flan-t5-xl-squad2",
-    ]
-    model_name = st.selectbox("Select model:", model_options)
-
-    # Button to generate questions
-    if st.button("Generate Questions"):
-        document_store = InMemoryDocumentStore()
-
-        # Convert the preprocessed text into a document
-        document = {"content": wiki_text}
-        document_store.write_documents([document])
-
-        # Initialize a TfidfRetriever
-        retriever = TfidfRetriever(document_store=document_store)
-
-        # Initialize a FARMReader with the selected model
-        reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
-
-        # Build an extractive QA pipeline (note: built but never used below;
-        # question generation is handled by QuestionGenerator directly)
-        pipe = ExtractiveQAPipeline(reader, retriever)
-
-        # Initialize the QuestionGenerator
-        qg = QuestionGenerator()
-
-        # Generate multiple-choice questions
-        qa_list = qg.generate(
-            wiki_text,
-            num_questions=num_questions,
-            answer_style='multiple_choice'
-        )
-
-        # Display the generated questions and answers
-        st.header("Generated Questions and Answers:")
-        for idx, qa in enumerate(qa_list):
-            # Display the question
-            st.write(f"Question {idx + 1}: {qa['question']}")
-
-            # Display the answer options
-            if 'answer' in qa:
-                for i, option in enumerate(qa['answer']):
-                    correct_marker = "(correct)" if option["correct"] else ""
-                    st.write(f"Option {i + 1}: {option['answer']} {correct_marker}")
-
-            # Add a separator after each question-answer pair
-            st.write("-" * 40)
-
-
-# Run the Streamlit app
-if __name__ == "__main__":
-    main()
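For reference, an app.py like the one deleted here is launched from the repository root with Streamlit's CLI (assuming streamlit and the model dependencies are installed):

```
streamlit run app.py
```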