import torch
import torch.nn as nn

from .blip2 import Blip2Base


class MiniGPT4(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    def __init__(
        self,
        llama_hidden_size=5120,
        vision_dtype=torch.float32,
        vision_device=torch.device("cpu"),
        projector_dtype=torch.float32,
        projector_device=torch.device("cpu"),
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        num_query_token=32,
        max_txt_len=32,
        end_sym='\n'
    ):
        super().__init__()

        self.vision_dtype = vision_dtype
        self.vision_device = vision_device
        self.projector_dtype = projector_dtype
        self.projector_device = projector_device

        self.tokenizer = self.init_tokenizer()

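        # Build the frozen EVA-CLIP visual encoder and its post layer norm,
        # then move both to the requested vision device/dtype in eval mode.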
        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        self.visual_encoder = self.visual_encoder.eval().to(self.vision_device, dtype=self.vision_dtype)
        self.ln_vision = self.ln_vision.eval().to(self.vision_device, dtype=self.vision_dtype)
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
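        # Only the image (query) branch of the Q-Former is used here, so the
        # text-side modules (CLS head, word/position embeddings, per-layer
        # text FFN sublayers) are dropped to save memory before loading the
        # pretrained weights.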
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)

        self.Qformer = self.Qformer.eval().to(self.projector_device, dtype=self.projector_dtype)
        print('Loading Q-Former Done')

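        # Linear projection from the Q-Former hidden size into the LLaMA
        # embedding space (llama_hidden_size defaults to 5120, i.e. LLaMA-13B).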
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, llama_hidden_size
        ).to(self.projector_device, dtype=self.projector_dtype)
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym


    def encode_img(self, image):
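        """
        Encode a batch of images into LLaMA-compatible soft-prompt embeddings.

        The frozen ViT and Q-Former run under torch.no_grad(); the query
        outputs are projected by llama_proj and returned with shape
        (batch, num_query_token, llama_hidden_size) on the projector device.
        """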
        image = image.to(self.vision_device, dtype=self.vision_dtype)

        with torch.no_grad():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(self.projector_device, dtype=self.projector_dtype)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.projector_device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1).to(self.projector_device)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_llama = self.llama_proj(query_output.last_hidden_state)
        return inputs_llama
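

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original model code: it shows
    # how encode_img might be exercised on a dummy 224x224 RGB batch with the
    # default CPU/fp32 settings. It assumes network access to download the
    # EVA-CLIP ViT and Q-Former checkpoints, and, because this file uses a
    # relative import, it must be run as a module (python -m <package>.<module>).
    # In practice the input tensor would come from the BLIP-2 image
    # preprocessor rather than torch.randn.
    model = MiniGPT4()
    dummy_image = torch.randn(1, 3, 224, 224)        # (batch, channels, H, W)
    inputs_llama = model.encode_img(dummy_image)
    # One projected embedding per query token, at the LLaMA hidden width:
    print(inputs_llama.shape)                        # torch.Size([1, 32, 5120])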