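"""PandaGPT-style multimodal chat on top of llama-cpp-python.

ImageBind encodes image/audio/video/thermal inputs, a linear layer projects
the resulting 1024-d embeddings into the LLaMA embedding space, and the
projected embeddings are fed into the llama.cpp context between <Img> tags.
"""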
import os
import sys
from ctypes import POINTER, c_float

import llama_cpp
import numpy as np
import torch
from torch import nn

# Clone PandaGPT next to this script and reuse its ImageBind implementation.
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
imagebind_ckpt_path = os.path.join(os.path.dirname(__file__), "imagebind_huge.pth")

if not os.path.exists(panda_gpt_path):
    os.system("git clone https://github.com/yxuansu/PandaGPT " + panda_gpt_path)

sys.path.insert(0, os.path.join(panda_gpt_path, "code", "model"))
from ImageBind.models import imagebind_model
from ImageBind.models.imagebind_model import ModalityType
from ImageBind import data

def numpy_to_floatptr(x):
    # Expose a contiguous float32 view of the array as a C float* for llama.cpp.
    return x.astype(np.float32).ctypes.data_as(POINTER(c_float))

class PandaGPT:
    def __init__(self, args=(), kwargs=None):
        self.visual_encoder, _ = imagebind_model.imagebind_huge(pretrained=True, store_path=os.path.dirname(imagebind_ckpt_path))
        self.visual_encoder.eval()
        # Project ImageBind's 1024-d output into the 5120-d embedding space
        # of the LLaMA-13B model.
        self.llama_proj = nn.Linear(1024, 5120)
        self.max_tgt_len = 400
        self.model = llama_cpp.Llama(*args, **(kwargs or {}))
        self.generated_text = ""
        self.device = "cpu"

    def eval_embd(self, x):
        # Feed pre-computed embeddings (instead of tokens) into the llama.cpp
        # context via the low-level llama_eval_embd API.
        y = numpy_to_floatptr(x.T)
        ctx = self.model.ctx
        n_past = self.model.n_tokens
        n_threads = self.model.n_threads
        llama_cpp.llama_eval_embd(ctx, y, x.shape[0], n_past, n_threads)
        self.model.n_tokens += x.shape[0]

    def eval_string(self, s):
        # Tokenize and evaluate a plain-text prompt fragment.
        s = self.model.tokenize(s.encode())
        self.model.eval(s)

    def generate_with_print(self, end="###", hook=lambda x: print(x, flush=True, end="")):
        # Sample tokens until the turn separator (or max_tgt_len) is reached,
        # echoing each decoded chunk through the hook as it arrives.
        end = end.encode()
        ret = b""
        for _ in range(self.max_tgt_len):
            token = self.model.sample()
            self.model.eval([token])
            txt = self.model.detokenize([token])
            ret += txt
            hook(txt.decode(errors="replace"))
            if ret.endswith(end):
                break
        return ret.decode(errors="replace")

    def load_projection(self, path):
        # Load only the linear projection weights from a PandaGPT checkpoint.
        state = torch.load(path, map_location="cpu")
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"],
        })

    def eval_inputs(self, inputs):
        # Wrap the multimodal embeddings in <Img>...</Img> markers, as PandaGPT expects.
        self.eval_string("<Img>")
        embds = self.extract_multimodal_feature(inputs)
        for i in embds:
            self.eval_embd(i)
        self.eval_string("</Img> ")

    def chat(self, question):
        # Text-only turn: reuse the multimodal path with no extra inputs.
        return self.chat_with_image(None, question)

    def chat_with_image(self, inputs, question):
        self.eval_with_image(inputs, question)
        ret = self.generate_with_print(end="###")
        self.generated_text += ret
        return ret

    def generate(self, end="###"):
        # Streaming variant of generate_with_print: yield each decoded chunk
        # instead of printing it.
        end = end.encode()
        ret = b""
        for _ in range(self.max_tgt_len):
            token = self.model.sample()
            self.model.eval([token])
            txt = self.model.detokenize([token])
            ret += txt
            yield txt.decode(errors="replace")
            if ret.endswith(end):
                break

    def eval_with_image(self, inputs, question):
        # Build one PandaGPT prompt turn: "### Human: <inputs> <question>\n### Assistant:".
        if self.generated_text == "":
            self.eval_string("###")
        self.eval_string(" Human: ")
        if inputs:
            self.eval_inputs(inputs)
        self.eval_string(question)
        self.eval_string("\n### Assistant:")

    def reset(self):
        self.generated_text = ""
        self.model.reset()

    def extract_multimodal_feature(self, inputs):
        features = []
        for key in ["image", "audio", "video", "thermal"]:
            if key + "_paths" in inputs:
                embeds = self.encode_data(key, inputs[key + "_paths"])
                features.append(embeds)
        return features

    def encode_data(self, data_type, data_paths):
        # Map each input type to its ImageBind modality and loader, then run
        # the encoder and project into the LLaMA embedding space.
        type_map = {
            "image": ModalityType.VISION,
            "audio": ModalityType.AUDIO,
            "video": ModalityType.VISION,
            "thermal": ModalityType.THERMAL,
        }
        load_map = {
            "image": data.load_and_transform_vision_data,
            "audio": data.load_and_transform_audio_data,
            "video": data.load_and_transform_video_data,
            "thermal": data.load_and_transform_thermal_data
        }

        load_function = load_map[data_type]
        key = type_map[data_type]

        inputs = {key: load_function(data_paths, self.device)}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            embeds = embeddings[key]
            embeds = self.llama_proj(embeds).cpu().numpy()
        return embeds
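
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): wiring the class
# together, assuming a local ggml LLaMA-13B model plus PandaGPT projection
# weights. The file names below are placeholders, not files this repo ships.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pg = PandaGPT(kwargs={"model_path": "ggml-model-q4_0.bin", "n_ctx": 2048})
    pg.load_projection("pandagpt_llama_proj.pth")  # placeholder checkpoint path
    print(pg.chat_with_image({"image_paths": ["example.jpg"]}, "Describe the image."))
    print(pg.chat("What color dominates it?"))  # follow-up turn reuses the kept context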