Spaces:
Runtime error
Runtime error
Commit
·
3be5505
1
Parent(s):
bc93555
Update main.py
Browse files
main.py
CHANGED
@@ -8,7 +8,7 @@ import timm
|
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
-
device = torch.device("cpu")
|
12 |
|
13 |
class CFG:
|
14 |
debug = False
|
@@ -296,7 +296,7 @@ def get_image_embeddings(image):
|
|
296 |
return image_embeddings
|
297 |
|
298 |
|
299 |
-
def predict_caption(image, model, text_embeddings, captions, n=
|
300 |
# get the image embeddings
|
301 |
image_embeddings = get_image_embeddings(image)
|
302 |
if image_embeddings is None:
|
@@ -332,4 +332,4 @@ def get_text_embeddings(valid_df):
|
|
332 |
text_embeddings = model.text_projection(text_features)
|
333 |
valid_text_embeddings.append(text_embeddings)
|
334 |
|
335 |
-
return model, torch.cat(valid_text_embeddings)
|
|
|
8 |
import torch
|
9 |
import torch.nn.functional as F
|
10 |
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
class CFG:
|
14 |
debug = False
|
|
|
296 |
return image_embeddings
|
297 |
|
298 |
|
299 |
+
def predict_caption(image, model, text_embeddings, captions, n=1):
|
300 |
# get the image embeddings
|
301 |
image_embeddings = get_image_embeddings(image)
|
302 |
if image_embeddings is None:
|
|
|
332 |
text_embeddings = model.text_projection(text_features)
|
333 |
valid_text_embeddings.append(text_embeddings)
|
334 |
|
335 |
+
return model, torch.cat(valid_text_embeddings)
|