Spaces:
Running
Running
prediction added
Browse files
app.py
CHANGED
@@ -7,19 +7,30 @@ import numpy as np
|
|
7 |
from model import FeatureNormPredictor
|
8 |
|
9 |
|
10 |
-
def predict (
|
11 |
-
if
|
12 |
-
model_name =
|
13 |
lm = cwe.CWE('bert-base-uncased')
|
14 |
-
if
|
|
|
15 |
model = FeatureNormPredictor.load_from_checkpoint(
|
16 |
checkpoint_path=model_name+'.ckpt',
|
17 |
map_location=None
|
18 |
)
|
19 |
model.eval()
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
demo = gr.Interface(
|
25 |
fn=predict,
|
@@ -28,7 +39,7 @@ demo = gr.Interface(
|
|
28 |
"text",
|
29 |
gr.Radio(["bert", "roberta", "electra"]),
|
30 |
"number",
|
31 |
-
gr.Radio(["
|
32 |
],
|
33 |
outputs=["text"],
|
34 |
)
|
|
|
7 |
from model import FeatureNormPredictor
|
8 |
|
9 |
|
10 |
+
def predict (Word, Sentence, LM, Layer, Norm):
|
11 |
+
if Word not in Sentence: return "invalid input: word not in sentence"
|
12 |
+
model_name = LM + str(Layer) + '_to_' + Norm
|
13 |
lm = cwe.CWE('bert-base-uncased')
|
14 |
+
if Layer not in range (lm.layers): return "invalid input: layer not in lm"
|
15 |
+
|
16 |
model = FeatureNormPredictor.load_from_checkpoint(
|
17 |
checkpoint_path=model_name+'.ckpt',
|
18 |
map_location=None
|
19 |
)
|
20 |
model.eval()
|
21 |
+
|
22 |
+
with open (model_name+'.txt', "r") as file:
|
23 |
+
labels = [line.rstrip() for line in file.readlines()]
|
24 |
+
|
25 |
+
data = (Word, Sentence)
|
26 |
+
embs = lm.extract_representation(data, layer=8)
|
27 |
+
avg = embs.sum(0)/len(data)
|
28 |
+
pred = torch.nn.functional.relu(model(avg))
|
29 |
+
pred = pred.squeeze(0)
|
30 |
+
pred_list = pred.detach().numpy().tolist()
|
31 |
+
|
32 |
+
output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if i > 0.0]
|
33 |
+
return "\n".join(output)
|
34 |
|
35 |
demo = gr.Interface(
|
36 |
fn=predict,
|
|
|
39 |
"text",
|
40 |
gr.Radio(["bert", "roberta", "electra"]),
|
41 |
"number",
|
42 |
+
gr.Radio(["binder", "mcrae", "buchanan"]),
|
43 |
],
|
44 |
outputs=["text"],
|
45 |
)
|