jwalanthi commited on
Commit
f50f2ed
·
1 Parent(s): 73d65b8

prediction added

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -7,19 +7,30 @@ import numpy as np
7
  from model import FeatureNormPredictor
8
 
9
 
10
- def predict (word, sentence, lm_name, layer, norm):
11
- if word not in sentence: return "invalid input: word not in sentence"
12
- model_name = lm_name + str(layer) + '_to_' + norm
13
  lm = cwe.CWE('bert-base-uncased')
14
- if layer not in range (lm.layers): return "invalid input: layer not in lm"
 
15
  model = FeatureNormPredictor.load_from_checkpoint(
16
  checkpoint_path=model_name+'.ckpt',
17
  map_location=None
18
  )
19
  model.eval()
20
- inputs = [word, sentence, lm_name, str(layer), norm]
21
- outputs = [input+'\t'+str(np.random.randint(0,100, size=1)[0]) for input in inputs]
22
- return "\n".join(outputs)
 
 
 
 
 
 
 
 
 
 
23
 
24
  demo = gr.Interface(
25
  fn=predict,
@@ -28,7 +39,7 @@ demo = gr.Interface(
28
  "text",
29
  gr.Radio(["bert", "roberta", "electra"]),
30
  "number",
31
- gr.Radio(["Binder", "McRae", "Buchanan"]),
32
  ],
33
  outputs=["text"],
34
  )
 
7
  from model import FeatureNormPredictor
8
 
9
 
10
+ def predict (Word, Sentence, LM, Layer, Norm):
11
+ if Word not in Sentence: return "invalid input: word not in sentence"
12
+ model_name = LM + str(Layer) + '_to_' + Norm
13
  lm = cwe.CWE('bert-base-uncased')
14
+ if Layer not in range (lm.layers): return "invalid input: layer not in lm"
15
+
16
  model = FeatureNormPredictor.load_from_checkpoint(
17
  checkpoint_path=model_name+'.ckpt',
18
  map_location=None
19
  )
20
  model.eval()
21
+
22
+ with open (model_name+'.txt', "r") as file:
23
+ labels = [line.rstrip() for line in file.readlines()]
24
+
25
+ data = (Word, Sentence)
26
+ embs = lm.extract_representation(data, layer=8)
27
+ avg = embs.sum(0)/len(data)
28
+ pred = torch.nn.functional.relu(model(avg))
29
+ pred = pred.squeeze(0)
30
+ pred_list = pred.detach().numpy().tolist()
31
+
32
+ output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if i > 0.0]
33
+ return "\n".join(output)
34
 
35
  demo = gr.Interface(
36
  fn=predict,
 
39
  "text",
40
  gr.Radio(["bert", "roberta", "electra"]),
41
  "number",
42
+ gr.Radio(["binder", "mcrae", "buchanan"]),
43
  ],
44
  outputs=["text"],
45
  )