import json

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

# LLaVA-Instruct-150K annotations: a JSON list of {image, conversations} records.
instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)

class CustomTextDataset(Dataset):
    """Pairs each human/assistant turn in the LLaVA instruct data with its
    precomputed image embedding, skipping images that were never embedded."""

    def __init__(self, json_data, image_embedding_dict, tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        self.entries = []
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                continue

            conversations = entry['conversations']
            # Stop one turn early so conversations[i + 1] is always valid.
            for i in range(len(conversations) - 1):
                if conversations[i]['from'] == 'human':
                    # Rough character-length filter so the encoded pair is
                    # unlikely to exceed maxContext tokens.
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > self.maxContext:
                        continue
                    # Strip the '<image>' placeholder from the prompt
                    # (str.lstrip removes a character set, not a prefix,
                    # so replace() is used instead).
                    question = 'Question: ' + conversations[i]['value'].replace('<image>', '').strip()
                    answer = 'Answer: ' + conversations[i + 1]['value']
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                    })
        print(f'------------- num entries = {len(self.entries)} -----------------')

    def getEmbeddingForImage(self, image):
        # Look up the precomputed embedding; None signals a missing image.
        return self.image_embedding_dict.get(image)

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        Q_caption_tokens = self.tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = self.tokenizer.encode(entry['QnAText'], add_special_tokens=True)

        # Truncate defensively (the character filter in __init__ is only a
        # heuristic), record the true lengths, then right-pad to maxContext.
        QnA_captions_tokens = QnA_captions_tokens[:self.maxContext]
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)

        QnA_captions_tokens = QnA_captions_tokens + \
            [self.tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))

        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                # Moved to the GPU per sample; fine with num_workers=0,
                # but worker processes would need CPU tensors here.
                'image_embedding': entry['image_embedding'].to("cuda"),
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length
                }



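# img_embeddings_dict.pth is expected to map image filenames to embedding
# tensors. A hypothetical sketch of how such a dict could be built with a
# CLIP vision tower (not the original preprocessing code; the model name
# and image_files are assumptions):
#
#   from PIL import Image
#   from transformers import CLIPVisionModel, CLIPImageProcessor
#   vision = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14').eval()
#   processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
#   emb = {}
#   with torch.no_grad():
#       for name in image_files:
#           pixels = processor(images=Image.open(name), return_tensors='pt').pixel_values
#           emb[name] = vision(pixels).last_hidden_state.squeeze(0)
#   torch.save(emb, 'img_embeddings_dict.pth')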
imgEmbDict = torch.load('img_embeddings_dict.pth')

model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# phi-2's tokenizer ships without a pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict, tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)
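
# Quick sanity check - a minimal sketch, not part of the training pipeline.
# Pulls one batch and prints shapes: QnA_tokens should be (10, 512) given
# batch_size=10 and maxContext=512 above; the embedding shape depends on
# whatever vision encoder produced img_embeddings_dict.pth.
if __name__ == '__main__':
    batch = next(iter(custom_dataloader))
    print(batch['QnA_tokens'].shape)       # expected: torch.Size([10, 512])
    print(batch['image_embedding'].shape)  # depends on the saved embeddings
    print(batch['QText'][0], batch['AText'][0])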