jwalanthi commited on
Commit
ef7044d
·
1 Parent(s): 923ec13

more bert models

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -7,15 +7,19 @@ import os
7
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
- def predict (Word, Sentence, modelname):
11
- models = {'Bert Layer 8 to Binder': ('bert-base-uncased', 'bert8_to_binder'),
12
- 'Albert Layer 8 to Binder': ('albert-xxlarge-v2', 'albert8_to_binder_opt_stop')}
 
13
  if Word not in Sentence: return "invalid input: word not in sentence"
14
- model_name = models[modelname][1]
15
- lm = cwe.CWE(models[modelname][0])
 
16
 
17
- model_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
18
- label_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".txt", use_auth_token=os.environ['TOKEN'])
 
 
19
 
20
  model = FeatureNormPredictor.load_from_checkpoint(
21
  checkpoint_path=model_path,
@@ -40,7 +44,9 @@ demo = gr.Interface(
40
  inputs=[
41
  "text",
42
  "text",
43
- gr.Radio(["Bert Layer 8 to Binder", "Albert Layer 8 to Binder"])
 
 
44
  ],
45
  outputs=["text"],
46
  )
 
7
  from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
8
 
9
 
10
+ def predict (Word, Sentence, llm, norm, layer):
11
+ models = {'BERT': 'bert-base-uncased',
12
+ 'ALBERT': 'albert-xxlarge-v2',
13
+ 'roBERTa': 'roberta-base'}
14
  if Word not in Sentence: return "invalid input: word not in sentence"
15
+ model_name_hf = llm.lower()
16
+ norm_name_hf = norm.lower()
17
+ lm = cwe.CWE(models[llm])
18
 
19
+ full_name_hf = f"jwalanthi/semantic-feature-classifiers/{model_name_hf}_models_all/{model_name_hf}_to_{norm_name_hf}_layer{layer}"
20
+
21
+ model_path = hf_hub_download(f"{full_name_hf}.ckpt", use_auth_token=os.environ['TOKEN'])
22
+ label_path = hf_hub_download(f"{full_name_hf}.txt", use_auth_token=os.environ['TOKEN'])
23
 
24
  model = FeatureNormPredictor.load_from_checkpoint(
25
  checkpoint_path=model_path,
 
44
  inputs=[
45
  "text",
46
  "text",
47
+ gr.Radio(["BERT"]),
48
+ gr.Radio("Binder", "McRae", "Buchanan"),
49
+ gr.Slider(0,12)
50
  ],
51
  outputs=["text"],
52
  )