import re

import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image


class Image2Text:
    __loaded = False

    def load(self):
        if self.__loaded:
            return

        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        # Load the BLIP captioning weights in float16 and move them to the GPU.
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
        ).to("cuda")

        self.__loaded = True

    def process(self, imageUrl: str) -> str:
        self.load()

        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor(image, return_tensors="pt").to(
            "cuda", torch.float16
        )

        # Greedy decoding; top_p has no effect when do_sample=False, so it is omitted.
        output_ids = self.model.generate(
            **inputs, do_sample=False, max_length=128
        )
        output_text = self.processor.batch_decode(output_ids)
        print(output_text)

        output_text = output_text[0]
        # Strip decoder artifacts ([SEP], stray closing tags, newlines) from the caption.
        output_text = re.sub(r"</.>|\n|\[SEP\]", "", output_text)
        return output_text
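

# Usage sketch (not part of the original module, added for illustration): the
# image URL below is a placeholder, and a CUDA device is assumed because
# load() moves the model to "cuda".
if __name__ == "__main__":
    captioner = Image2Text()
    caption = captioner.process("https://example.com/sample.jpg")
    print(caption)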