jwalanthi commited on
Commit
4be65da
·
1 Parent(s): 966ec6e

dummy output again until figure out model upload

Browse files
Files changed (1) hide show
  1. app.py +21 -18
app.py CHANGED
@@ -13,31 +13,34 @@ def predict (Word, Sentence, LM, Layer, Norm):
13
  lm = cwe.CWE('bert-base-uncased')
14
  if Layer not in range (lm.layers): return "invalid input: layer not in lm"
15
 
16
- model = FeatureNormPredictor.load_from_checkpoint(
17
- checkpoint_path=model_name+'.ckpt',
18
- map_location=None
19
- )
20
- model.eval()
21
-
22
- with open (model_name+'.txt', "r") as file:
23
- labels = [line.rstrip() for line in file.readlines()]
24
-
25
- data = (Word, Sentence)
26
- embs = lm.extract_representation(data, layer=8)
27
- avg = embs.sum(0)/len(data)
28
- pred = torch.nn.functional.relu(model(avg))
29
- pred = pred.squeeze(0)
30
- pred_list = pred.detach().numpy().tolist()
 
 
 
31
 
32
- output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if i > 0.0]
33
- return "All Positive Predicted Values:\n"+"\n".join(output)
34
 
35
  demo = gr.Interface(
36
  fn=predict,
37
  inputs=[
38
  "text",
39
  "text",
40
- gr.Radio(["bert", "roberta", "electra"]),
41
  "number",
42
  gr.Radio(["binder", "mcrae", "buchanan"]),
43
  ],
 
13
  lm = cwe.CWE('bert-base-uncased')
14
  if Layer not in range (lm.layers): return "invalid input: layer not in lm"
15
 
16
+ labels = "These are some fake features".split(" ")
17
+ vals = np.random.randint(-10,10,(5))
18
+ print(model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0]))
19
+ # model = FeatureNormPredictor.load_from_checkpoint(
20
+ # checkpoint_path=model_name+'.ckpt',
21
+ # map_location=None
22
+ # )
23
+ # model.eval()
24
+
25
+ # with open (model_name+'.txt', "r") as file:
26
+ # labels = [line.rstrip() for line in file.readlines()]
27
+
28
+ # data = (Word, Sentence)
29
+ # embs = lm.extract_representation(data, layer=8)
30
+ # avg = embs.sum(0)/len(data)
31
+ # pred = torch.nn.functional.relu(model(avg))
32
+ # pred = pred.squeeze(0)
33
+ # pred_list = pred.detach().numpy().tolist()
34
 
35
+ # output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
36
+ # return "All Positive Predicted Values:\n"+"\n".join(output)
37
 
38
  demo = gr.Interface(
39
  fn=predict,
40
  inputs=[
41
  "text",
42
  "text",
43
+ gr.Radio(["bert", "albert"]),
44
  "number",
45
  gr.Radio(["binder", "mcrae", "buchanan"]),
46
  ],