Anshoo Mehra commited on
Commit
db60c4c
·
1 Parent(s): cb8c28c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import (
6
  AutoTokenizer
7
  )
8
 
 
9
  M1 = "anshoomehra/question-generation-auto-t5-v1-base-s-q"
10
  M2 = "anshoomehra/question-generation-auto-t5-v1-base-s-q-c"
11
 
@@ -14,6 +15,9 @@ M5 = "anshoomehra/question-generation-auto-hints-t5-v1-base-s-q-c"
14
 
15
  device = ['cuda' if torch.cuda.is_available() else 'cpu'][0]
16
 
 
 
 
17
  _m1 = AutoModelForSeq2SeqLM.from_pretrained(M1).to(device)
18
  _tk1 = AutoTokenizer.from_pretrained(M1, cache_dir="./cache")
19
 
@@ -48,6 +52,7 @@ def _formatQs(questions):
48
 
49
  def _generate(mode, context, hint=None, minLength=50, maxLength=500, lengthPenalty=2.0, earlyStopping=True, numReturnSequences=1, numBeams=2, noRepeatNGramSize=0, doSample=False, topK=0, topP=0, temperature=0):
50
 
 
51
  predictionM1 = None
52
  predictionM2 = None
53
  predictionM4 = None
@@ -55,7 +60,23 @@ def _generate(mode, context, hint=None, minLength=50, maxLength=500, lengthPenal
55
 
56
  if mode == 'Auto':
57
  _inputText = "question_context: " + context
58
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  _encoding = _tk1.encode(_inputText, return_tensors='pt', truncation=True, padding='max_length').to(device) # max_length=1024
60
  _outputEncoded = _m1.generate(_encoding,
61
  min_length=minLength,
@@ -119,7 +140,6 @@ def _generate(mode, context, hint=None, minLength=50, maxLength=500, lengthPenal
119
  temperature=temperature
120
  )
121
  predictionM5 = [_tk5.decode(id, clean_up_tokenization_spaces=False, skip_special_tokens=True) for id in _outputEncoded]
122
-
123
  elif mode == 'Hints':
124
  _inputText = "question_hint: " + hint + "</s>question_context: " + context
125
 
@@ -155,12 +175,13 @@ def _generate(mode, context, hint=None, minLength=50, maxLength=500, lengthPenal
155
  )
156
  predictionM5 = [_tk5.decode(id, clean_up_tokenization_spaces=False, skip_special_tokens=True) for id in _outputEncoded]
157
 
 
158
  predictionM1 = _formatQs(predictionM1)
159
  predictionM2 = _formatQs(predictionM2)
160
  predictionM4 = _formatQs(predictionM4)
161
  predictionM5 = _formatQs(predictionM5)
162
-
163
- return predictionM1, predictionM2, predictionM4, predictionM5
164
 
165
  with gr.Blocks() as demo:
166
 
@@ -199,12 +220,13 @@ with gr.Blocks() as demo:
199
  with gr.Row(variant='compact'):
200
  _predictionM2 = gr.Textbox(label="Predicted Questions - question-generation-auto-t5-v1-base-s-q-c [No Hints]")
201
  _predictionM1 = gr.Textbox(label="Predicted Questions - question-generation-auto-t5-v1-base-s-q [No Hints]")
 
202
 
203
  with gr.Row():
204
  gen_btn = gr.Button("Generate Questions")
205
  gen_btn.click(fn=_generate,
206
  inputs=[mode, context, hint, minLength, maxLength, lengthPenalty, earlyStopping, numReturnSequences, numBeams, noRepeatNGramSize, doSample, topK, topP, temperature],
207
- outputs=[_predictionM1, _predictionM2, _predictionM4, _predictionM5]
208
  )
209
 
210
  demo.launch(show_error=True)
 
6
  AutoTokenizer
7
  )
8
 
9
+ M0 = "anshoomehra/question-generation-auto-t5-v1-base-s"
10
  M1 = "anshoomehra/question-generation-auto-t5-v1-base-s-q"
11
  M2 = "anshoomehra/question-generation-auto-t5-v1-base-s-q-c"
12
 
 
15
 
16
  device = ['cuda' if torch.cuda.is_available() else 'cpu'][0]
17
 
18
+ _m0 = AutoModelForSeq2SeqLM.from_pretrained(M0).to(device)
19
+ _tk0 = AutoTokenizer.from_pretrained(M0, cache_dir="./cache")
20
+
21
  _m1 = AutoModelForSeq2SeqLM.from_pretrained(M1).to(device)
22
  _tk1 = AutoTokenizer.from_pretrained(M1, cache_dir="./cache")
23
 
 
52
 
53
  def _generate(mode, context, hint=None, minLength=50, maxLength=500, lengthPenalty=2.0, earlyStopping=True, numReturnSequences=1, numBeams=2, noRepeatNGramSize=0, doSample=False, topK=0, topP=0, temperature=0):
54
 
55
+ predictionM0 = None
56
  predictionM1 = None
57
  predictionM2 = None
58
  predictionM4 = None
 
60
 
61
  if mode == 'Auto':
62
  _inputText = "question_context: " + context
63
+
64
+ _encoding = _tk0.encode(_inputText, return_tensors='pt', truncation=True, padding='max_length').to(device) # max_length=1024
65
+ _outputEncoded = _m0.generate(_encoding,
66
+ min_length=minLength,
67
+ max_length=maxLength,
68
+ length_penalty=lengthPenalty,
69
+ early_stopping=earlyStopping,
70
+ num_return_sequences=numReturnSequences,
71
+ num_beams=numBeams,
72
+ no_repeat_ngram_size=noRepeatNGramSize,
73
+ do_sample=doSample,
74
+ top_k=topK,
75
+ top_p=topP,
76
+ temperature=temperature
77
+ )
78
+ predictionM0 = [_tk0.decode(id, clean_up_tokenization_spaces=False, skip_special_tokens=True) for id in _outputEncoded]
79
+
80
  _encoding = _tk1.encode(_inputText, return_tensors='pt', truncation=True, padding='max_length').to(device) # max_length=1024
81
  _outputEncoded = _m1.generate(_encoding,
82
  min_length=minLength,
 
140
  temperature=temperature
141
  )
142
  predictionM5 = [_tk5.decode(id, clean_up_tokenization_spaces=False, skip_special_tokens=True) for id in _outputEncoded]
 
143
  elif mode == 'Hints':
144
  _inputText = "question_hint: " + hint + "</s>question_context: " + context
145
 
 
175
  )
176
  predictionM5 = [_tk5.decode(id, clean_up_tokenization_spaces=False, skip_special_tokens=True) for id in _outputEncoded]
177
 
178
+ predictionM0 = _formatQs(predictionM0)
179
  predictionM1 = _formatQs(predictionM1)
180
  predictionM2 = _formatQs(predictionM2)
181
  predictionM4 = _formatQs(predictionM4)
182
  predictionM5 = _formatQs(predictionM5)
183
+
184
+ return predictionM5, predictionM4, predictionM2, predictionM1, predictionM0
185
 
186
  with gr.Blocks() as demo:
187
 
 
220
  with gr.Row(variant='compact'):
221
  _predictionM2 = gr.Textbox(label="Predicted Questions - question-generation-auto-t5-v1-base-s-q-c [No Hints]")
222
  _predictionM1 = gr.Textbox(label="Predicted Questions - question-generation-auto-t5-v1-base-s-q [No Hints]")
223
+ _predictionM0 = gr.Textbox(label="Predicted Questions - question-generation-auto-t5-v1-base-s-q [No Hints]")
224
 
225
  with gr.Row():
226
  gen_btn = gr.Button("Generate Questions")
227
  gen_btn.click(fn=_generate,
228
  inputs=[mode, context, hint, minLength, maxLength, lengthPenalty, earlyStopping, numReturnSequences, numBeams, noRepeatNGramSize, doSample, topK, topP, temperature],
229
+ outputs=[_predictionM5, _predictionM4, _predictionM2, _predictionM1, _predictionM0]
230
  )
231
 
232
  demo.launch(show_error=True)