Commit
·
4a5d667
1
Parent(s):
291f55b
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
from doctest import OutputChecker
|
3 |
import sys
|
4 |
import argparse
|
5 |
-
|
6 |
import re
|
7 |
import os
|
8 |
import gradio as gr
|
@@ -19,7 +19,7 @@ import requests
|
|
19 |
#from sklearn.metrics.pairwise import cosine_similarity
|
20 |
|
21 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
22 |
-
|
23 |
|
24 |
#SentenceTransformer('stsb-distilbert-base', device=device)
|
25 |
|
@@ -108,8 +108,8 @@ def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
|
|
108 |
caption = caption
|
109 |
visual_context_label= visual_context_label
|
110 |
visual_context_prob = visual_context_prob
|
111 |
-
caption_emb =
|
112 |
-
visual_context_label_emb =
|
113 |
|
114 |
|
115 |
sim = cosine_scores = util.pytorch_cos_sim(caption_emb, visual_context_label_emb)
|
|
|
2 |
from doctest import OutputChecker
|
3 |
import sys
|
4 |
import argparse
|
5 |
+
import torch
|
6 |
import re
|
7 |
import os
|
8 |
import gradio as gr
|
|
|
19 |
#from sklearn.metrics.pairwise import cosine_similarity
|
20 |
|
21 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
22 |
+
model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
|
23 |
|
24 |
#SentenceTransformer('stsb-distilbert-base', device=device)
|
25 |
|
|
|
108 |
caption = caption
|
109 |
visual_context_label= visual_context_label
|
110 |
visual_context_prob = visual_context_prob
|
111 |
+
caption_emb = model_sts.encode(caption, convert_to_tensor=True)
|
112 |
+
visual_context_label_emb = model_sts.encode(visual_context_label, convert_to_tensor=True)
|
113 |
|
114 |
|
115 |
sim = cosine_scores = util.pytorch_cos_sim(caption_emb, visual_context_label_emb)
|