# text-generation-webui/extensions/multimodal/pipelines/minigpt-4-pipeline/minigpt4/mini_gpt4.py
import torch
import torch.nn as nn

from .blip2 import Blip2Base


class MiniGPT4(Blip2Base):
    """
    BLIP-2 style vision encoder and Q-Former whose output is projected into the
    LLaMA embedding space (the MiniGPT-4 image front-end).
    """

    def __init__(
        self,
        llama_hidden_size=5120,
        vision_dtype=torch.float32,
        vision_device=torch.device("cpu"),
        projector_dtype=torch.float32,
        projector_device=torch.device("cpu"),
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        num_query_token=32,
        max_txt_len=32,
        end_sym='\n'
    ):
        super().__init__()
        self.vision_dtype = vision_dtype
        self.vision_device = vision_device
        self.projector_dtype = projector_dtype
        self.projector_device = projector_device

        self.tokenizer = self.init_tokenizer()

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        self.visual_encoder = self.visual_encoder.eval().to(self.vision_device, dtype=self.vision_dtype)
        self.ln_vision = self.ln_vision.eval().to(self.vision_device, dtype=self.vision_dtype)
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Only the image-to-query path of the Q-Former is used here, so the
        # text-side modules are dropped before loading the pretrained weights.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        self.Qformer = self.Qformer.eval().to(self.projector_device, dtype=self.projector_dtype)
        print('Loading Q-Former Done')

        # Linear projection from the Q-Former hidden size to the LLaMA hidden size.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, llama_hidden_size
        ).to(self.projector_device, dtype=self.projector_dtype)

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

    def encode_img(self, image):
        """Encode a batch of preprocessed images into LLaMA-space embeddings."""
        image = image.to(self.vision_device, dtype=self.vision_dtype)

        with torch.no_grad():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(self.projector_device, dtype=self.projector_dtype)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.projector_device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1).to(self.projector_device)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            # Project the query outputs into the LLaMA embedding space.
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
        return inputs_llama
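

if __name__ == "__main__":
    # Illustrative smoke test, not part of the upstream pipeline. It assumes this
    # module is imported as part of its package (so the relative import above
    # resolves) and that the default EVA-CLIP and Q-Former checkpoints can be
    # downloaded; with the default arguments everything stays on CPU in float32.
    model = MiniGPT4()

    # A zero tensor stands in for a real image preprocessed to the default
    # 224x224 resolution: batch of one, three channels.
    dummy_image = torch.zeros(1, 3, 224, 224)

    # encode_img() returns the projected query embeddings that the multimodal
    # extension later splices into the LLaMA prompt. With the defaults the shape
    # should be (1, num_query_token=32, llama_hidden_size=5120).
    embeds = model.encode_img(dummy_image)
    print(embeds.shape)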