import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig
import json
from torch.utils.data import Dataset, DataLoader

instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)
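
# Each element of instruct_data is expected to follow the LLaVA-Instruct-150K layout
# (illustrative sketch of the schema this class relies on, not copied from the file):
#   {
#       "image": "000000123456.jpg",
#       "conversations": [
#           {"from": "human", "value": "<image>\nWhat is shown in the picture?"},
#           {"from": "gpt", "value": "A dog playing with a ball."}
#       ]
#   }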

class CustomTextDataset(Dataset):
    """Pairs precomputed image embeddings with tokenized Question/Answer text."""

    def __init__(self, json_data, image_embedding_dict, tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        self.entries = []
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                # No precomputed embedding for this image; skip the whole entry.
                continue
            conversations = entry['conversations']
            for i in range(len(conversations)):
                if conversations[i]['from'] == 'human' and i + 1 < len(conversations):
                    # Rough character-count filter so the tokenized Q+A stays within maxContext.
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > self.maxContext:
                        continue
                    # Drop the '<image>\n' placeholder from the human turn
                    # (str.lstrip strips individual characters, not a prefix, so use replace).
                    question = 'Question: ' + conversations[i]['value'].replace('<image>\n', '')
                    answer = 'Answer: ' + conversations[i + 1]['value']
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                    })
        print('------------- num entries = -----------------')
        print(len(self.entries))

    def getEmbeddingForImage(self, image):
        # Return the precomputed embedding for this image, or None if it is missing.
        if image in self.image_embedding_dict:
            return self.image_embedding_dict[image]
        return None

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        Q_caption_tokens = self.tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = self.tokenizer.encode(entry['QnAText'], add_special_tokens=True)
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)
        # Truncate, then right-pad to maxContext so every sample has the same length
        # and the default collate can stack the batch.
        QnA_captions_tokens = QnA_captions_tokens[:self.maxContext]
        QnA_captions_tokens = QnA_captions_tokens + \
            [self.tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))
        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                'image_embedding': entry['image_embedding'].to("cuda"),
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length
                }

# Precomputed image embeddings keyed by image file name (expected to be generated beforehand).
imgEmbDict = torch.load('img_embeddings_dict.pth')

model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# phi-2's tokenizer ships without a pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict, tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)
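
# Quick sanity check (illustrative sketch, not part of the training pipeline): pull one
# batch and inspect shapes. Assumes the embeddings in img_embeddings_dict.pth are
# fixed-size tensors so the default collate can stack them.
if __name__ == '__main__':
    batch = next(iter(custom_dataloader))
    print(batch['QnA_tokens'].shape)       # (batch_size, maxContext)
    print(batch['image_embedding'].shape)  # (batch_size, *embedding_dims)
    print(batch['QTokensLength'][:5])      # per-sample question-token lengths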