Singularity666 commited on
Commit
257f974
·
1 Parent(s): d29c405

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -9
main.py CHANGED
@@ -8,7 +8,7 @@ import timm
8
  import torch
9
  import torch.nn.functional as F
10
 
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
  class CFG:
14
  debug = False
@@ -296,7 +296,7 @@ def get_image_embeddings(image):
296
  return image_embeddings
297
 
298
 
299
- def predict_caption(image, model, text_embeddings, captions, n=1):
300
  # get the image embeddings
301
  image_embeddings = get_image_embeddings(image)
302
  if image_embeddings is None:
@@ -332,10 +332,4 @@ def get_text_embeddings(valid_df):
332
  text_embeddings = model.text_projection(text_features)
333
  valid_text_embeddings.append(text_embeddings)
334
 
335
- return model, torch.cat(valid_text_embeddings)
336
-
337
- def get_alternative_caption(image, model, text_embeddings, captions, n=1):
338
- matches = predict_caption(
339
- image, model, text_embeddings, captions, n+1
340
- )
341
- return matches[-1]
 
8
  import torch
9
  import torch.nn.functional as F
10
 
11
+ device = torch.device("cpu")
12
 
13
  class CFG:
14
  debug = False
 
296
  return image_embeddings
297
 
298
 
299
+ def predict_caption(image, model, text_embeddings, captions, n=2):
300
  # get the image embeddings
301
  image_embeddings = get_image_embeddings(image)
302
  if image_embeddings is None:
 
332
  text_embeddings = model.text_projection(text_features)
333
  valid_text_embeddings.append(text_embeddings)
334
 
335
+ return model, torch.cat(valid_text_embeddings)