from typing import List
import torch
from PIL import Image

from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
from modules.text_generation import encode
from huggingface_hub import hf_hub_download
from .minigpt4.processor import Blip2ImageEvalProcessor
from .minigpt4.mini_gpt4 import MiniGPT4


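# Base pipeline shared by the 7B and 13B variants: preprocesses images with the
# BLIP-2 eval transform and exposes the hooks the multimodal extension expects
# (placeholder token, <Img>/</Img> markers, embedding helpers).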
class MiniGPT4_Pipeline(AbstractMultimodalPipeline):
    def __init__(self, params: dict) -> None:
        super().__init__()
        self.image_processor = Blip2ImageEvalProcessor()

    @staticmethod
    def placeholder_token_id() -> int:
        return 1

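    # MiniGPT-4 prompts wrap the image embeddings in <Img> ... </Img> markers.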
    @staticmethod
    def image_start() -> str:
        return "<Img>"

    @staticmethod
    def image_end() -> str:
        return "</Img>"

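    # The Q-Former emits 32 query embeddings per image, so each image occupies 32 token positions.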
    @staticmethod
    def num_image_embeds() -> int:
        return 32

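    # Map token ids to input embeddings via the loaded LLaMA model's embedding layer.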
    @staticmethod
    def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
        return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)

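    # Text embeddings standing in for an image slot: "<ImgContent>" is encoded and embedded as a placeholder.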
    @staticmethod
    def placeholder_embeddings() -> torch.Tensor:
        placeholders = encode("<ImgContent>", add_bos_token=False, add_special_tokens=False)[0]
        return MiniGPT4_Pipeline.embed_tokens(placeholders.to(shared.model.device, dtype=torch.int64)).to(dtype=shared.model.dtype)

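    # Batch the BLIP-2-preprocessed images and project them into the LLM's embedding space via the vision tower.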
    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        im = torch.stack([self.image_processor(image) for image in images])
        image_emb = self.vision_tower.encode_img(im)
        return image_emb.to(shared.model.device, dtype=shared.model.dtype)


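# 13B variant: llama_hidden_size=5120; downloads the published MiniGPT-4 checkpoint
# and loads it with strict=False (only the keys present in the file are applied).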
class MiniGPT4_13b_Pipeline(MiniGPT4_Pipeline):
    def __init__(self, params: dict) -> None:
        super().__init__(params)
        ckpt_path = hf_hub_download("Vision-CAIR/MiniGPT-4", "pretrained_minigpt4.pth")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.vision_tower = MiniGPT4(llama_hidden_size=5120,
                                     vision_dtype=self._get_dtype("vision_bits", params),
                                     vision_device=self._get_device("vision_device", params),
                                     projector_device=self._get_device("projector_device", params),
                                     projector_dtype=self._get_dtype("projector_bits", params))
        self.vision_tower.load_state_dict(ckpt['model'], strict=False)

    @staticmethod
    def name() -> str:
        return "minigpt4-13b"


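# 7B variant: llama_hidden_size=4096; "prerained_minigpt4_7b.pth" is the filename
# as published upstream, so it is left as-is here.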
class MiniGPT4_7b_Pipeline(MiniGPT4_Pipeline):
    def __init__(self, params: dict) -> None:
        super().__init__(params)
        ckpt_path = hf_hub_download("ckpt/minigpt4-7B", "prerained_minigpt4_7b.pth")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.vision_tower = MiniGPT4(llama_hidden_size=4096,
                                     vision_dtype=self._get_dtype("vision_bits", params),
                                     vision_device=self._get_device("vision_device", params),
                                     projector_device=self._get_device("projector_device", params),
                                     projector_dtype=self._get_dtype("projector_bits", params))
        self.vision_tower.load_state_dict(ckpt['model'], strict=False)

    @staticmethod
    def name() -> str:
        return "minigpt4-7b"