Spaces:
Running
Running
bert and albert
Browse files
app.py
CHANGED
@@ -7,19 +7,16 @@ import os
|
|
7 |
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
|
8 |
|
9 |
|
10 |
-
def predict (Word, Sentence,
|
11 |
-
models = {'
|
|
|
12 |
if Word not in Sentence: return "invalid input: word not in sentence"
|
13 |
-
model_name =
|
14 |
-
lm = cwe.CWE(models[
|
15 |
-
if Layer not in range (lm.layers): return "invalid input: layer not in lm"
|
16 |
|
17 |
-
model_path = hf_hub_download("jwalanthi/
|
18 |
-
label_path = hf_hub_download("jwalanthi/
|
19 |
|
20 |
-
# labels = "These are some fake features".split(" ")
|
21 |
-
# vals = np.random.randint(-10,10,(5))
|
22 |
-
# return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
|
23 |
model = FeatureNormPredictor.load_from_checkpoint(
|
24 |
checkpoint_path=model_path,
|
25 |
map_location=None
|
@@ -43,14 +40,12 @@ demo = gr.Interface(
|
|
43 |
inputs=[
|
44 |
"text",
|
45 |
"text",
|
46 |
-
gr.Radio(["
|
47 |
-
"number",
|
48 |
-
gr.Radio(["binder", "mcrae", "buchanan"]),
|
49 |
],
|
50 |
outputs=["text"],
|
51 |
)
|
52 |
|
53 |
-
demo.launch()
|
54 |
|
55 |
if __name__ == "__main__":
|
56 |
demo.launch(share=True)
|
|
|
7 |
from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
|
8 |
|
9 |
|
10 |
+
def predict (Word, Sentence, modelname):
|
11 |
+
models = {'Bert Layer 8 to Binder': ('bert-base-uncased', 'bert8_to_binder'),
|
12 |
+
'Albert Layer 8 to Binder': ('albert-xxlarge-v2', 'albert8_to_binder_opt_stop')}
|
13 |
if Word not in Sentence: return "invalid input: word not in sentence"
|
14 |
+
model_name = models[modelname][1]
|
15 |
+
lm = cwe.CWE(models[modelname][0])
|
|
|
16 |
|
17 |
+
model_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
|
18 |
+
label_path = hf_hub_download("jwalanthi/semantic-feature-classifiers", model_name+".txt", use_auth_token=os.environ['TOKEN'])
|
19 |
|
|
|
|
|
|
|
20 |
model = FeatureNormPredictor.load_from_checkpoint(
|
21 |
checkpoint_path=model_path,
|
22 |
map_location=None
|
|
|
40 |
inputs=[
|
41 |
"text",
|
42 |
"text",
|
43 |
+
gr.Radio(["Bert Layer 8 to Binder", "Albert Layer 8 to Binder"])
|
|
|
|
|
44 |
],
|
45 |
outputs=["text"],
|
46 |
)
|
47 |
|
48 |
+
demo.launch(share=True)
|
49 |
|
50 |
if __name__ == "__main__":
|
51 |
demo.launch(share=True)
|