jwalanthi commited on
Commit
541d16e
·
1 Parent(s): ab801fa

it works locally!!

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -4,36 +4,36 @@ from minicons import cwe
4
  import pandas as pd
5
  import numpy as np
6
 
7
- from model import FeatureNormPredictor
8
 
9
 
10
  def predict (Word, Sentence, LM, Layer, Norm):
 
11
  if Word not in Sentence: return "invalid input: word not in sentence"
12
  model_name = LM + str(Layer) + '_to_' + Norm
13
- lm = cwe.CWE('bert-base-uncased')
14
  if Layer not in range (lm.layers): return "invalid input: layer not in lm"
15
 
16
- labels = "These are some fake features".split(" ")
17
- vals = np.random.randint(-10,10,(5))
18
- return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
19
- # model = FeatureNormPredictor.load_from_checkpoint(
20
- # checkpoint_path=model_name+'.ckpt',
21
- # map_location=None
22
- # )
23
- # model.eval()
24
-
25
- # with open (model_name+'.txt', "r") as file:
26
- # labels = [line.rstrip() for line in file.readlines()]
27
-
28
- # data = (Word, Sentence)
29
- # embs = lm.extract_representation(data, layer=8)
30
- # avg = embs.sum(0)/len(data)
31
- # pred = torch.nn.functional.relu(model(avg))
32
- # pred = pred.squeeze(0)
33
- # pred_list = pred.detach().numpy().tolist()
34
 
35
- # output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
36
- # return "All Positive Predicted Values:\n"+"\n".join(output)
37
 
38
  demo = gr.Interface(
39
  fn=predict,
@@ -47,4 +47,7 @@ demo = gr.Interface(
47
  outputs=["text"],
48
  )
49
 
50
- demo.launch()
 
 
 
 
4
  import pandas as pd
5
  import numpy as np
6
 
7
+ from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
  def predict (Word, Sentence, LM, Layer, Norm):
11
+ models = {'bert': 'bert-base-uncased'}
12
  if Word not in Sentence: return "invalid input: word not in sentence"
13
  model_name = LM + str(Layer) + '_to_' + Norm
14
+ lm = cwe.CWE(models[LM])
15
  if Layer not in range (lm.layers): return "invalid input: layer not in lm"
16
 
17
+ # labels = "These are some fake features".split(" ")
18
+ # vals = np.random.randint(-10,10,(5))
19
+ # return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
20
+ model = FeatureNormPredictor.load_from_checkpoint(
21
+ checkpoint_path='models/'+model_name+'.ckpt',
22
+ map_location=None
23
+ )
24
+ model.eval()
25
+
26
+ with open ('models/'+model_name+'.txt', "r") as file:
27
+ labels = [line.rstrip() for line in file.readlines()]
28
+
29
+ data = (Sentence, Word)
30
+ emb = lm.extract_representation(data, layer=8)
31
+ pred = torch.nn.functional.relu(model(emb))
32
+ pred = pred.squeeze(0)
33
+ pred_list = pred.detach().numpy().tolist()
 
34
 
35
+ output = [labels[i]+'\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
36
+ return "All Positive Predicted Values:\n"+"\n".join(output)
37
 
38
  demo = gr.Interface(
39
  fn=predict,
 
47
  outputs=["text"],
48
  )
49
 
50
+ demo.launch()
51
+
52
+ if __name__ == "__main__":
53
+ demo.launch()