File size: 1,496 Bytes
a9d25c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration  # , BitsAndBytesConfig
from .env_utils import get_device, low_vram_mode

device = get_device()

# Hugging Face Hub id of the BLIP-2 checkpoint; replace with a local path to
# load the weights from disk instead.
blip2_model_id = "Salesforce/blip2-opt-2.7b"
# dtype the model weights are loaded in; blip2_caption also casts its inputs
# to this dtype.
blip2_precision = torch.bfloat16

# Load the BLIP-2 processor and model from Hugging Face.
blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
# The model is loaded the same way whether or not ``low_vram_mode`` is set.
# 8-bit quantization (``BitsAndBytesConfig(load_in_8bit=True)``) would be the
# natural low-VRAM option, but it is intentionally not used because ZeroGPU
# does not support quantization.
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_model_id,
    torch_dtype=blip2_precision,
    device_map=device,
).eval()


def blip2_caption(raw_image, **generate_kwargs):
    """Generate an unconditional caption for a single image with BLIP-2.

    Args:
        raw_image: The image to caption, in any form accepted by
            ``Blip2Processor`` (e.g. a ``PIL.Image.Image``).
        **generate_kwargs: Optional keyword arguments forwarded unchanged to
            ``blip2_model.generate`` (e.g. ``max_new_tokens``). When omitted,
            the model's default generation settings are used, exactly as
            before.

    Returns:
        The decoded caption of the first generated sequence, with special
        tokens removed and leading/trailing whitespace stripped.
    """
    # No text prompt is given to the processor, so this is unconditional
    # captioning.
    inputs = blip2_processor(raw_image, return_tensors="pt")
    # Move the processed tensors to the model's device and cast the
    # floating-point ones to the model's dtype.
    inputs = inputs.to(device=device, dtype=blip2_precision)
    output_ids = blip2_model.generate(**inputs, **generate_kwargs)
    caption = blip2_processor.decode(output_ids[0], skip_special_tokens=True)
    # The OPT language decoder may leave surrounding whitespace (e.g. a
    # trailing newline) in the decoded text; strip it so callers get a clean
    # caption.
    return caption.strip()


# if __name__ == "__main__":
#     from PIL import Image

#     # Test the BLIP-2 captioning model
#     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
#     image = Image.open(image_path)
#     result = blip2_caption(image)
#     print(result)