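# PandaGPT chat example on top of llama-cpp-python: ImageBind encodes image,
# audio, video, and thermal inputs, a linear layer projects them into LLaMA's
# embedding space, and the low-level llama_eval_embd API feeds them into the
# context alongside ordinary text tokens. A usage sketch is at the bottom.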
import os, sys
from ctypes import POINTER, c_float

import numpy as np  # needed by numpy_to_floatptr below
import torch
from torch import nn

import llama_cpp

# Clone the PandaGPT repo next to this script (if missing) and reuse its
# bundled ImageBind implementation; the ImageBind checkpoint lives alongside.
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
imagebind_ckpt_path = os.path.join(os.path.dirname(__file__), "imagebind_huge.pth")

if not os.path.exists(panda_gpt_path):
    os.system("git clone https://github.com/yxuansu/PandaGPT " + panda_gpt_path)

sys.path.insert(0, os.path.join(panda_gpt_path, "code", "model"))
from ImageBind.models import imagebind_model
from ImageBind.models.imagebind_model import ModalityType
from ImageBind import data

def numpy_to_floatptr(x):
    # Expose a numpy array as the float* pointer expected by the llama.cpp C API.
    return x.astype(np.float32).ctypes.data_as(POINTER(c_float))

class PandaGPT:
    def __init__(self, args=(), kwargs={}):
        # ImageBind "huge" produces 1024-dimensional multimodal embeddings.
        self.visual_encoder, _ = imagebind_model.imagebind_huge(
            pretrained=True, store_path=os.path.dirname(imagebind_ckpt_path)
        )
        self.visual_encoder.eval()
        # Project ImageBind's 1024-d embeddings into the 5120-d embedding
        # space of the 13B LLaMA model that PandaGPT is built on.
        self.llama_proj = nn.Linear(1024, 5120)
        self.max_tgt_len = 400
        self.model = llama_cpp.Llama(*args, **kwargs)
        self.generated_text = ""
        self.device = "cpu"

    def eval_embd(self, x):
        # Feed projected embeddings straight into the llama.cpp context,
        # bypassing tokenization, via the low-level llama_eval_embd API.
        y = numpy_to_floatptr(x.T)
        ctx = self.model.ctx
        n_past = self.model.n_tokens
        n_threads = self.model.n_threads
        llama_cpp.llama_eval_embd(ctx, y, x.shape[0], n_past, n_threads)
        self.model.n_tokens += x.shape[0]

    def eval_string(self, s):
        # Plain text goes through the usual tokenize-then-eval path.
        s = self.model.tokenize(s.encode())
        self.model.eval(s)

    def generate_with_print(self, end="###"):
        # Sample token by token, streaming text to stdout, until the
        # conversation delimiter (default "###") or max_tgt_len is reached.
        end = end.encode()
        ret = b""
        for _ in range(self.max_tgt_len):
            token = self.model.sample()
            self.model.eval([token])
            txt = self.model.detokenize([token])
            ret += txt
            print(txt.decode(errors="replace"), flush=True, end="")
            if ret.endswith(end):
                break
        return ret.decode(errors="replace")

    def load_projection(self, path):
        # Load the pretrained llama_proj weights from a PandaGPT checkpoint.
        state = torch.load(path, map_location="cpu")
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})

    def eval_inputs(self, inputs):
        # Wrap multimodal embeddings in the <Img>...</Img> markers that
        # PandaGPT was trained with.
        self.eval_string("<Img>")
        embds = self.extract_multimodal_feature(inputs)
        for i in embds:
            self.eval_embd(i)
        self.eval_string("</Img> ")

    def chat(self, question):
        return self.chat_with_image(None, question)

    def chat_with_image(self, inputs, question):
        # Follow PandaGPT's "### Human: ... ### Assistant:" prompt format.
        if self.generated_text == "":
            self.eval_string("###")
        self.eval_string(" Human: ")
        if inputs:
            self.eval_inputs(inputs)
        self.eval_string(question)
        self.eval_string("\n### Assistant:")
        ret = self.generate_with_print(end="###")
        self.generated_text += ret
        return ret

    def extract_multimodal_feature(self, inputs):
        features = []
        for key in ["image", "audio", "video", "thermal"]:
            if key + "_paths" in inputs:
                embeds = self.encode_data(key, inputs[key+"_paths"])
                features.append(embeds)
        return features

    def encode_data(self, data_type, data_paths):
        # Map each supported data type to ImageBind's modality key and to the
        # matching preprocessing helper from PandaGPT's ImageBind copy.
        type_map = {
            "image": ModalityType.VISION,
            "audio": ModalityType.AUDIO,
            "video": ModalityType.VISION,
            "thermal": ModalityType.THERMAL,
        }
        load_map = {
            "image": data.load_and_transform_vision_data,
            "audio": data.load_and_transform_audio_data,
            "video": data.load_and_transform_video_data,
            "thermal": data.load_and_transform_thermal_data
        }

        load_function = load_map[data_type]
        key = type_map[data_type]

        inputs = {key: load_function(data_paths, self.device)}
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            embeds = embeddings[key]
            embeds = self.llama_proj(embeds).cpu().numpy()
        return embeds
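

# A minimal usage sketch (hypothetical file names, not part of the PandaGPT
# release): construct the wrapper around a ggml LLaMA/Vicuna-13B model, load
# the PandaGPT projection weights, then alternate image-grounded and
# text-only turns.
if __name__ == "__main__":
    bot = PandaGPT(kwargs={"model_path": "vicuna-13b.ggml.bin", "n_ctx": 2048})
    bot.load_projection("pandagpt_pytorch_model.pt")  # hypothetical checkpoint path
    bot.chat_with_image({"image_paths": ["example.jpg"]}, "What is shown in this image?")
    bot.chat("Summarize your answer in one sentence.")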