jwalanthi commited on
Commit
1cd9844
·
1 Parent(s): 5eb0088

bert and albert

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -7,19 +7,16 @@ import os
7
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
- def predict (Word, Sentence, LM, Layer, Norm):
11
- models = {'bert': 'bert-base-uncased'}
 
12
  if Word not in Sentence: return "invalid input: word not in sentence"
13
- model_name = LM + str(Layer) + '_to_' + Norm
14
- lm = cwe.CWE(models[LM])
15
- if Layer not in range (lm.layers): return "invalid input: layer not in lm"
16
 
17
- model_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
18
- label_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".txt", use_auth_token=os.environ['TOKEN'])
19
 
20
- # labels = "These are some fake features".split(" ")
21
- # vals = np.random.randint(-10,10,(5))
22
- # return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
23
  model = FeatureNormPredictor.load_from_checkpoint(
24
  checkpoint_path=model_path,
25
  map_location=None
@@ -43,14 +40,12 @@ demo = gr.Interface(
43
  inputs=[
44
  "text",
45
  "text",
46
- gr.Radio(["bert", "albert"]),
47
- "number",
48
- gr.Radio(["binder", "mcrae", "buchanan"]),
49
  ],
50
  outputs=["text"],
51
  )
52
 
53
- demo.launch()
54
 
55
  if __name__ == "__main__":
56
  demo.launch(share=True)
 
7
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
+ def predict (Word, Sentence, modelname):
11
+ models = {'Bert Layer 8 to Binder': ('bert-base-uncased', 'bert8_to_binder'),
12
+ 'Albert Layer 8 to Binder': ('albert-xxlarge-v2', 'albert8_to_binder_opt_stop')}
13
  if Word not in Sentence: return "invalid input: word not in sentence"
14
+ model_name = models[modelname][1]
15
+ lm = cwe.CWE(models[modelname][0])
 
16
 
17
+ model_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
18
+ label_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".txt", use_auth_token=os.environ['TOKEN'])
19
 
 
 
 
20
  model = FeatureNormPredictor.load_from_checkpoint(
21
  checkpoint_path=model_path,
22
  map_location=None
 
40
  inputs=[
41
  "text",
42
  "text",
43
+ gr.Radio(["Bert Layer 8 to Binder", "Albert Layer 8 to Binder"])
 
 
44
  ],
45
  outputs=["text"],
46
  )
47
 
48
+ demo.launch(share=True)
49
 
50
  if __name__ == "__main__":
51
  demo.launch(share=True)