Spaces:
Runtime error
Runtime error
Commit
·
8ee4f89
1
Parent(s):
8a463ae
Revert run_model
Browse files- generator.py +10 -20
generator.py
CHANGED
@@ -34,34 +34,29 @@ import streamlit as st
|
|
34 |
# hfmodel = pickle.load(open('models/hfmodel.sav', 'rb'))
|
35 |
|
36 |
def load_model():
|
37 |
-
hfm = pickle.load(open('
|
38 |
hft = T5TokenizerFast.from_pretrained("t5-base")
|
39 |
model = pickle.load(open('electra_model.sav','rb'))
|
40 |
tok = et.from_pretrained("mrm8488/electra-small-finetuned-squadv2")
|
41 |
# return hfm, hft,tok, model
|
42 |
return hfm, hft,tok, model
|
43 |
|
44 |
-
hfmodel, hftokenizer,tok, model = load_model()
|
45 |
|
46 |
def run_model(input_string, **generator_args):
|
47 |
generator_args = {
|
48 |
"max_length": 256,
|
49 |
"num_beams": 4,
|
50 |
"length_penalty": 1.5,
|
51 |
-
"no_repeat_ngram_size":
|
52 |
-
"early_stopping":
|
53 |
}
|
54 |
# tokenizer = att.from_pretrained("ThomasSimonini/t5-end2end-question-generation")
|
55 |
-
# output = nlp(input_string)
|
56 |
-
|
57 |
input_string = "generate questions: " + input_string + " </s>"
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
output = hftokenizer.decode(res[0], skip_special_tokens=True)
|
63 |
-
# output = output.split('</sep>')
|
64 |
-
# output = [o.strip() for o in output[:-1]]
|
65 |
return output
|
66 |
|
67 |
|
@@ -126,18 +121,13 @@ def read_file(filepath_name):
|
|
126 |
|
127 |
def create_string_for_generator(context):
|
128 |
gen_list = gen_question(context)
|
129 |
-
return gen_list
|
130 |
|
131 |
def creator(context):
|
132 |
questions = create_string_for_generator(context)
|
133 |
-
questions = questions.split('?')
|
134 |
pairs = []
|
135 |
for ques in questions:
|
136 |
-
l = len(ques)
|
137 |
-
if(l == 0):
|
138 |
-
continue
|
139 |
-
if ques[l-1] != '?':
|
140 |
-
ques = ques + '?'
|
141 |
pair = QA(ques,context)
|
142 |
print(pair)
|
143 |
pairs.append(pair)
|
|
|
34 |
# hfmodel = pickle.load(open('models/hfmodel.sav', 'rb'))
|
35 |
|
36 |
def load_model():
    """Load the question-generation and QA models with their tokenizers.

    Returns:
        tuple: ``(hfm, hft, tok, model)`` — the pickled T5
        question-generation model, its T5 tokenizer, the ELECTRA
        tokenizer, and the pickled ELECTRA QA model.
    """
    # Fix: the original passed bare open(...) handles to pickle.load and
    # never closed them; use context managers so the files are released.
    # NOTE(review): pickle.load is only safe on trusted artifacts —
    # these .sav files must be ones this project produced.
    with open('hfmodel.sav', 'rb') as f:
        hfm = pickle.load(f)
    hft = T5TokenizerFast.from_pretrained("t5-base")
    with open('electra_model.sav', 'rb') as f:
        model = pickle.load(f)
    tok = et.from_pretrained("mrm8488/electra-small-finetuned-squadv2")
    return hfm, hft, tok, model
|
43 |
|
44 |
+
# Load all models/tokenizers once at module import so later calls
# (run_model, QA) reuse the same instances.
hfmodel, hftokenizer, tok, model = load_model()
|
45 |
|
46 |
def run_model(input_string, **generator_args):
    """Generate questions from *input_string* with the T5 QG model.

    Args:
        input_string: Source passage to generate questions from.
        **generator_args: Optional overrides for the generation
            parameters (e.g. ``num_beams=8``), merged over the defaults.

    Returns:
        list[list[str]]: For each decoded output sequence, its text
        split on the ``<sep>`` question separator token.
    """
    # Bug fix: the original rebuilt generator_args from scratch, so any
    # keyword arguments the caller passed were silently discarded. Merge
    # caller overrides over the defaults instead — with no overrides the
    # behavior is identical to before.
    defaults = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }
    generator_args = {**defaults, **generator_args}
    # tokenizer = att.from_pretrained("ThomasSimonini/t5-end2end-question-generation")
    # T5 end2end-QG convention: task prefix plus explicit </s> marker.
    input_string = "generate questions: " + input_string + " </s>"
    input_ids = hftokenizer.encode(input_string, return_tensors="pt")
    res = hfmodel.generate(input_ids, **generator_args)
    output = hftokenizer.batch_decode(res, skip_special_tokens=True)
    output = [item.split("<sep>") for item in output]
    return output
|
61 |
|
62 |
|
|
|
121 |
|
122 |
def create_string_for_generator(context):
    """Run question generation on *context* and split the first decoded
    sequence into individual question strings on '? '.

    Note: splitting on '? ' drops the trailing '?' from every question
    except possibly the last — callers see the same strings as before.
    """
    generated = gen_question(context)
    first_sequence = generated[0][0]
    return first_sequence.split('? ')
|
125 |
|
126 |
def creator(context):
|
127 |
questions = create_string_for_generator(context)
|
128 |
+
# questions = questions.split('?')
|
129 |
pairs = []
|
130 |
for ques in questions:
|
|
|
|
|
|
|
|
|
|
|
131 |
pair = QA(ques,context)
|
132 |
print(pair)
|
133 |
pairs.append(pair)
|