Commit 9a708c4
Parent(s): 5eece2c

Update generator.py to electra small

generator.py CHANGED (+28 -27)

@@ -37,11 +37,12 @@ import streamlit as st
 def load_model():
     hfm = pickle.load(open('hfmodel.sav','rb'))
     hft = T5TokenizerFast.from_pretrained("t5-base")
-    tok = att.from_pretrained("
-    model = pickle.load(open('
-    return hfm, hft,tok, model
+    # tok = att.from_pretrained("")
+    model = pickle.load(open('electra_model.sav','rb'))
+    # return hfm, hft,tok, model
+    return hfm, hft, model

-hfmodel, hftokenizer,
+hfmodel, hftokenizer, model = load_model()

 def run_model(input_string, **generator_args):
     generator_args = {
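
The hunk above drops the separate tokenizer load and unpickles a single ELECTRA model instead. A minimal sketch of that loading pattern, assuming 'electra_model.sav' holds a pickled transformers QA model; the checkpoint name below is only a stand-in, taken from the "deepset/electra-base-squad2" comment later in the file. Since pipeline() cannot always infer a tokenizer from a bare model instance, the loader returns one alongside the model:

import pickle
from transformers import AutoTokenizer

def load_electra(path='electra_model.sav'):
    # Assumption: a transformers PreTrainedModel was pickled to this path.
    with open(path, 'rb') as f:
        model = pickle.load(f)
    # A matching tokenizer, loaded by checkpoint name; the name is taken
    # from the file's own comments and is an assumption here.
    tok = AutoTokenizer.from_pretrained('deepset/electra-base-squad2')
    return model, tok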
@@ -67,29 +68,29 @@ def run_model(input_string, **generator_args):
 # al_tokenizer = pickle.load(open('models/al_tokenizer.sav', 'rb'))
 def QA(question, context):
     # model_name="deepset/electra-base-squad2"
-
-
-
-
-
-
-
-
-    inputs = tokenizer(question, context, return_tensors="pt")
-    # Run the model, the deepset way
-    with torch.no_grad():
-        output = model(**inputs)
-    start_score = output.start_logits
-    end_score = output.end_logits
-    # Get the rel scores for the context, and calculate the most probable beginning using torch
-    start = torch.argmax(start_score)
-    end = torch.argmax(end_score)
-    # convert tokens to strings
-    # output = tokenizer.decode(input_ids[start:end+1], skip_special_tokens=True)
-    predict_answer_tokens = inputs.input_ids[0, start : end + 1]
-    output = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
-    output = string.capwords(output)
-    return f"Q. {question} \n Ans. {output}"
+    nlp = pipeline("question-answering", model=model)
+    format = {
+        'question': question,
+        'context': context
+    }
+    res = nlp(format)
+    output = f"{question}\n{string.capwords(res['answer'])}\tscore : [{res['score']}] \n"
+    return output
+    # inputs = tokenizer(question, context, return_tensors="pt")
+    # # Run the model, the deepset way
+    # with torch.no_grad():
+    #     output = model(**inputs)
+    # start_score = output.start_logits
+    # end_score = output.end_logits
+    # # Get the rel scores for the context, and calculate the most probable beginning using torch
+    # start = torch.argmax(start_score)
+    # end = torch.argmax(end_score)
+    # # convert tokens to strings
+    # # output = tokenizer.decode(input_ids[start:end+1], skip_special_tokens=True)
+    # predict_answer_tokens = inputs.input_ids[0, start : end + 1]
+    # output = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+    # output = string.capwords(output)
+    # return f"Q. {question} \n Ans. {output}"
 # QA("What was the first C program","The first program written in C was Hello World")

 def gen_question(inputs):
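
The rewritten QA() delegates answer extraction to the Hugging Face question-answering pipeline. Below is a self-contained sketch of that path which loads the checkpoint by name (so the matching tokenizer is resolved automatically) rather than from the pickled file; "deepset/electra-base-squad2" is an assumption taken from the comment in the diff:

import string
from transformers import pipeline

# Any extractive-QA checkpoint works here; this one is named in the
# diff's own comments.
nlp = pipeline('question-answering', model='deepset/electra-base-squad2')

def qa(question, context):
    # The pipeline accepts a {'question': ..., 'context': ...} dict and
    # returns a dict with 'answer', 'score', 'start', and 'end'.
    res = nlp({'question': question, 'context': context})
    return f"{question}\n{string.capwords(res['answer'])}\tscore : [{res['score']}]\n"

print(qa('What was the first C program?',
         'The first program written in C was Hello World.'))

The sketch avoids the diff's format variable name, which shadows the built-in format().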
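
The block the commit comments out is the manual extractive-QA recipe: one forward pass, argmax over the start and end logits, then decoding of that token span. A sketch of the same recipe using the Auto classes (the original's tokenizer and model bindings are not shown in the diff, so the names below are assumptions):

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

name = 'deepset/electra-base-squad2'  # assumption: stands in for the pickled model
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name)

def qa_manual(question, context):
    inputs = tokenizer(question, context, return_tensors='pt')
    with torch.no_grad():
        output = model(**inputs)
    start = torch.argmax(output.start_logits)  # most probable span start
    end = torch.argmax(output.end_logits)      # most probable span end
    span = inputs.input_ids[0, start:end + 1]
    return tokenizer.decode(span, skip_special_tokens=True)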