File size: 461 Bytes
5bc8f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import open_clip
import torch
from PIL import Image

# Load the CoCa ViT-L/14 captioning checkpoint (fine-tuned on MS-COCO) plus its
# inference-time preprocessing transform. create_model_and_transforms returns
# (model, train_transform, eval_transform); the train-time pipeline is unused.
model, _, transform = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)
# Per the open_clip README, models are returned in train mode by default;
# switch to inference mode so train-only behaviour (e.g. dropout) is disabled
# while generating captions.
model.eval()

def get_captions(image):
    """Generate a caption for a single image with the module-level CoCa model.

    Args:
        image: Input image, passed straight to ``transform`` (presumably a
            ``PIL.Image.Image`` -- confirm with callers).

    Returns:
        The decoded caption as a string. The ``<start_of_text>`` marker is
        removed, everything from the first ``<end_of_text>`` onward is dropped,
        and surrounding whitespace is stripped.
    """
    # Put the input on whatever device the model lives on, so this keeps
    # working if the model is moved to a GPU elsewhere.
    device = next(model.parameters()).device
    # unsqueeze(0) adds the batch dimension expected by model.generate.
    pixels = transform(image).unsqueeze(0).to(device)

    # Mixed precision only helps when the model runs on CUDA. The previous
    # torch.cuda.amp.autocast() is deprecated and had no effect on CPU tensors.
    use_amp = device.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
        generated = model.generate(pixels)

    text = open_clip.decode(generated[0])
    # Keep only the text before the end-of-sequence marker (whatever follows
    # is presumably padding) and drop the start marker.
    caption = text.split("<end_of_text>")[0].replace("<start_of_text>", "")
    # The decoded text can carry stray whitespace around the words; trim it.
    return caption.strip()