Spaces:
Runtime error
Runtime error
Commit
·
257f974
1
Parent(s):
d29c405
Update main.py
Browse files
main.py
CHANGED
@@ -8,7 +8,7 @@ import timm
|
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
-
device = torch.device("
|
12 |
|
13 |
class CFG:
|
14 |
debug = False
|
@@ -296,7 +296,7 @@ def get_image_embeddings(image):
|
|
296 |
return image_embeddings
|
297 |
|
298 |
|
299 |
-
def predict_caption(image, model, text_embeddings, captions, n=
|
300 |
# get the image embeddings
|
301 |
image_embeddings = get_image_embeddings(image)
|
302 |
if image_embeddings is None:
|
@@ -332,10 +332,4 @@ def get_text_embeddings(valid_df):
|
|
332 |
text_embeddings = model.text_projection(text_features)
|
333 |
valid_text_embeddings.append(text_embeddings)
|
334 |
|
335 |
-
return model, torch.cat(valid_text_embeddings)
|
336 |
-
|
337 |
-
def get_alternative_caption(image, model, text_embeddings, captions, n=1):
|
338 |
-
matches = predict_caption(
|
339 |
-
image, model, text_embeddings, captions, n+1
|
340 |
-
)
|
341 |
-
return matches[-1]
|
|
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
+
device = torch.device("cpu")
|
12 |
|
13 |
class CFG:
|
14 |
debug = False
|
|
|
296 |
return image_embeddings
|
297 |
|
298 |
|
299 |
+
def predict_caption(image, model, text_embeddings, captions, n=2):
|
300 |
# get the image embeddings
|
301 |
image_embeddings = get_image_embeddings(image)
|
302 |
if image_embeddings is None:
|
|
|
332 |
text_embeddings = model.text_projection(text_features)
|
333 |
valid_text_embeddings.append(text_embeddings)
|
334 |
|
335 |
+
return model, torch.cat(valid_text_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|