jwalanthi commited on
Commit
25004af
·
1 Parent(s): 11ca206
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -3,11 +3,12 @@ import torch
3
  from minicons import cwe
4
  from huggingface_hub import hf_hub_download
5
  import os
 
6
 
7
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
- def predict (Word, Sentence, LLM, Norm, Layer):
11
  models = {'BERT': 'bert-base-uncased',
12
  'ALBERT': 'albert-xxlarge-v2',
13
  'roBERTa': 'roberta-base'}
@@ -33,12 +34,16 @@ def predict (Word, Sentence, LLM, Norm, Layer):
33
  labels = [line.rstrip() for line in file.readlines()]
34
 
35
  data = (Sentence, Word)
36
- emb = lm.extract_representation(data, layer=8)
37
  pred = torch.nn.functional.relu(model(emb))
38
  pred = pred.squeeze(0)
39
  pred_list = pred.detach().numpy().tolist()
 
 
 
 
40
 
41
- output = [labels[i]+'\t\t\t\t\t\t\t'+str(pred_list[i]) for i in range(len(labels)) if pred_list[i] > 0.0]
42
  return "All Positive Predicted Values:\n"+"\n".join(output)
43
 
44
  demo = gr.Interface(
 
3
  from minicons import cwe
4
  from huggingface_hub import hf_hub_download
5
  import os
6
+ import pandas as pd
7
 
8
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
9
 
10
 
11
+ def predict (Sentence, Word, LLM, Norm, Layer):
12
  models = {'BERT': 'bert-base-uncased',
13
  'ALBERT': 'albert-xxlarge-v2',
14
  'roBERTa': 'roberta-base'}
 
34
  labels = [line.rstrip() for line in file.readlines()]
35
 
36
  data = (Sentence, Word)
37
+ emb = lm.extract_representation(data, layer=Layer)
38
  pred = torch.nn.functional.relu(model(emb))
39
  pred = pred.squeeze(0)
40
  pred_list = pred.detach().numpy().tolist()
41
+
42
+ df = pd.DataFrame({'feature':labels, 'value':pred_list})
43
+ df = df.sort_values('values', ascending=False)
44
+ df = df[df['values'] > 0]
45
 
46
+ output = [df['features'][i]+'\t\t\t\t\t\t\t'+str(df['values'][i]) for i in range(df.shape(0))]
47
  return "All Positive Predicted Values:\n"+"\n".join(output)
48
 
49
  demo = gr.Interface(