Commit
·
042dd04
1
Parent(s):
d0e6d99
Update app.py
Browse files
app.py
CHANGED
@@ -18,9 +18,9 @@ from sentence_transformers import SentenceTransformer, util
|
|
18 |
#from sklearn.metrics.pairwise import cosine_similarity
|
19 |
|
20 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
21 |
-
model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
|
22 |
|
23 |
-
|
24 |
|
25 |
#batch_size = 1
|
26 |
#scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
|
@@ -43,14 +43,14 @@ def softmax(x):
|
|
43 |
|
44 |
# Load pre-trained model
|
45 |
|
46 |
-
|
47 |
|
48 |
-
model = gr.Interface.load('huggingface/distilgpt2', output_hidden_states = True, output_attentions = True)
|
49 |
|
50 |
#model.eval()
|
51 |
-
tokenizer = gr.Interface.load('huggingface/distilgpt2')
|
52 |
|
53 |
-
|
54 |
#tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
55 |
|
56 |
|
@@ -130,8 +130,8 @@ def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
|
|
130 |
demo = gr.Interface(
|
131 |
fn=Visual_re_ranker,
|
132 |
description="Demo for Belief Revision based Caption Re-ranker with Visual Semantic Information",
|
133 |
-
|
134 |
-
outputs=[gr.Textbox(value="Language Model Score") , gr.Textbox(value="Semantic Similarity Score"), gr.Textbox(value="Belief revision score via visual context")],
|
135 |
-
|
136 |
)
|
137 |
demo.launch()
|
|
|
18 |
#from sklearn.metrics.pairwise import cosine_similarity
|
19 |
|
20 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
21 |
+
#model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
|
22 |
|
23 |
+
model_sts = SentenceTransformer('stsb-distilbert-base')
|
24 |
|
25 |
#batch_size = 1
|
26 |
#scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
|
|
|
43 |
|
44 |
# Load pre-trained model
|
45 |
|
46 |
+
model = GPT2LMHeadModel.from_pretrained('distilgpt2', output_hidden_states = True, output_attentions = True)
|
47 |
|
48 |
+
#model = gr.Interface.load('huggingface/distilgpt2', output_hidden_states = True, output_attentions = True)
|
49 |
|
50 |
#model.eval()
|
51 |
+
#tokenizer = gr.Interface.load('huggingface/distilgpt2')
|
52 |
|
53 |
+
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
54 |
#tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
55 |
|
56 |
|
|
|
130 |
demo = gr.Interface(
|
131 |
fn=Visual_re_ranker,
|
132 |
description="Demo for Belief Revision based Caption Re-ranker with Visual Semantic Information",
|
133 |
+
inputs=[gr.Textbox(value="a city street filled with traffic at night") , gr.Textbox(value="traffic"), gr.Textbox(value="0.7458009")],
|
134 |
+
#outputs=[gr.Textbox(value="Language Model Score") , gr.Textbox(value="Semantic Similarity Score"), gr.Textbox(value="Belief revision score via visual context")],
|
135 |
+
outputs="label",
|
136 |
)
|
137 |
demo.launch()
|